diff --git a/src/server.rs b/src/server.rs index 91169aa..232f40c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -57,10 +57,10 @@ async fn open_db(db_path: &Path) -> sqlx::Result { } #[derive(Clone)] -pub(self) struct Repo(Arc); +pub struct Repo(Arc); #[derive(Clone)] -pub(self) struct BenchRepo(Arc); +pub struct BenchRepo(Arc); #[derive(Clone, FromRef)] pub struct Server { diff --git a/src/server/web.rs b/src/server/web.rs index 5e16627..0385c59 100644 --- a/src/server/web.rs +++ b/src/server/web.rs @@ -12,7 +12,13 @@ use axum_extra::routing::RouterExt; use crate::{config::Config, somehow}; -use self::admin::queue::post_admin_queue_add; +use self::{ + admin::queue::post_admin_queue_add, + api::worker::{ + get_api_worker_bench_repo_by_hash_tree_tar_gz, get_api_worker_repo_by_hash_tree_tar_gz, + post_api_worker_status, + }, +}; use super::Server; @@ -53,8 +59,10 @@ pub async fn run(server: Server) -> somehow::Result<()> { .route("/queue/", get(queue::get)) .route("/queue/inner", get(queue::get_inner)) .route("/worker/:name", get(worker::get)) + .typed_get(get_api_worker_bench_repo_by_hash_tree_tar_gz) + .typed_get(get_api_worker_repo_by_hash_tree_tar_gz) .typed_post(post_admin_queue_add) - .merge(api::router(&server)) + .typed_post(post_api_worker_status) .fallback(get(r#static::static_handler)) .with_state(server.clone()); diff --git a/src/server/web/api.rs b/src/server/web/api.rs index 5d04b4e..2c8b839 100644 --- a/src/server/web/api.rs +++ b/src/server/web/api.rs @@ -1,246 +1 @@ -mod auth; -mod stream; - -use std::sync::{Arc, Mutex}; - -use axum::{ - body::StreamBody, - extract::{Path, State}, - headers::{authorization::Basic, Authorization}, - http::StatusCode, - http::{header, HeaderValue}, - response::{IntoResponse, Response}, - routing::{get, post}, - Json, Router, TypedHeader, -}; -use gix::{ObjectId, ThreadSafeRepository}; -use sqlx::{Acquire, SqlitePool}; -use time::OffsetDateTime; -use tracing::debug; - -use crate::{ - config::Config, - server::{ - workers::{WorkerInfo, Workers}, - BenchRepo, Repo, Server, - }, - shared::{BenchMethod, FinishedRun, ServerResponse, WorkerRequest}, - somehow, -}; - -async fn save_work(finished: FinishedRun, db: &SqlitePool) -> somehow::Result<()> { - let mut tx = db.begin().await?; - let conn = tx.acquire().await?; - - let end = OffsetDateTime::now_utc(); - let bench_method = match finished.run.bench_method { - BenchMethod::Internal => "internal".to_string(), - BenchMethod::Repo { hash } => format!("bench repo, hash {hash}"), - }; - - sqlx::query!( - "\ - INSERT INTO runs ( \ - id, \ - hash, \ - bench_method, \ - start, \ - end, \ - exit_code \ - ) \ - VALUES (?, ?, ?, ?, ?, ?) \ - ", - finished.run.id, - finished.run.hash, - bench_method, - finished.run.start, - end, - finished.exit_code, - ) - .execute(&mut *conn) - .await?; - - for (name, measurement) in finished.measurements { - sqlx::query!( - "\ - INSERT INTO run_measurements ( \ - id, \ - name, \ - value, \ - stddev, \ - unit, \ - direction \ - ) \ - VALUES (?, ?, ?, ?, ?, ?) \ - ", - finished.run.id, - name, - measurement.value, - measurement.stddev, - measurement.unit, - measurement.direction, - ) - .execute(&mut *conn) - .await?; - } - - for (idx, (source, text)) in finished.output.into_iter().enumerate() { - // Hopefully we won't need more than 4294967296 output chunks per run :P - let idx = idx as u32; - sqlx::query!( - "\ - INSERT INTO run_output ( \ - id, \ - idx, \ - source, \ - text \ - ) \ - VALUES (?, ?, ?, ?) \ - ", - finished.run.id, - idx, - source, - text, - ) - .execute(&mut *conn) - .await?; - } - - // The thing has been done :D - sqlx::query!("DELETE FROM queue WHERE hash = ?", finished.run.hash) - .execute(&mut *conn) - .await?; - - tx.commit().await?; - Ok(()) -} - -async fn post_status( - State(config): State<&'static Config>, - State(db): State, - State(bench_repo): State>, - State(workers): State>>, - auth: Option>>, - Json(request): Json, -) -> somehow::Result { - let name = match auth::authenticate(config, auth) { - Ok(name) => name, - Err(response) => return Ok(response), - }; - - if let Some(run) = request.submit_run { - save_work(run, &db).await?; - } - - // Fetch queue - let queue = sqlx::query_scalar!( - "\ - SELECT hash FROM queue \ - ORDER BY priority DESC, unixepoch(date) DESC, hash ASC \ - " - ) - .fetch_all(&db) - .await?; - - // Fetch bench method - let bench_method = match bench_repo { - Some(bench_repo) => BenchMethod::Repo { - hash: bench_repo.0.to_thread_local().head_id()?.to_string(), - }, - None => BenchMethod::Internal, - }; - - // Update internal state - let (work, abort_work) = { - let mut guard = workers.lock().unwrap(); - guard.clean(); - if !guard.verify(&name, &request.secret) { - return Ok((StatusCode::UNAUTHORIZED, "invalid secret").into_response()); - } - guard.update( - name.clone(), - WorkerInfo::new(request.secret, OffsetDateTime::now_utc(), request.status), - ); - let work = match request.request_run { - true => guard.find_and_reserve_run(&name, &queue, bench_method), - false => None, - }; - let abort_work = guard.should_abort_work(&name, &queue); - (work, abort_work) - }; - - debug!("Received status update from {name}"); - Ok(Json(ServerResponse { - run: work, - abort_run: abort_work, - }) - .into_response()) -} - -fn stream_response(repo: Arc, id: ObjectId) -> impl IntoResponse { - ( - [ - ( - header::CONTENT_TYPE, - HeaderValue::from_static("application/gzip"), - ), - ( - header::CONTENT_DISPOSITION, - HeaderValue::from_static("attachment; filename=\"tree.tar.gz\""), - ), - ], - StreamBody::new(stream::tar_and_gzip(repo, id)), - ) -} - -async fn get_repo( - State(config): State<&'static Config>, - State(repo): State>, - auth: Option>>, - Path(hash): Path, -) -> somehow::Result { - let _name = match auth::authenticate(config, auth) { - Ok(name) => name, - Err(response) => return Ok(response), - }; - - let Some(repo) = repo else { - return Ok(StatusCode::NOT_FOUND.into_response()); - }; - - let id = hash.parse::()?; - Ok(stream_response(repo.0, id).into_response()) -} - -async fn get_bench_repo( - State(config): State<&'static Config>, - State(bench_repo): State>, - auth: Option>>, - Path(hash): Path, -) -> somehow::Result { - let _name = match auth::authenticate(config, auth) { - Ok(name) => name, - Err(response) => return Ok(response), - }; - - let Some(bench_repo) = bench_repo else { - return Ok(StatusCode::NOT_FOUND.into_response()); - }; - - let id = hash.parse::()?; - Ok(stream_response(bench_repo.0, id).into_response()) -} - -pub fn router(server: &Server) -> Router { - if server.repo.is_none() { - return Router::new(); - } - - Router::new() - .route("/api/worker/status", post(post_status)) - .route("/api/worker/repo/:hash/tree.tar.gz", get(get_repo)) - .route( - "/api/worker/bench_repo/:hash/tree.tar.gz", - get(get_bench_repo), - ) -} +pub mod worker; diff --git a/src/server/web/api/worker.rs b/src/server/web/api/worker.rs new file mode 100644 index 0000000..4549d17 --- /dev/null +++ b/src/server/web/api/worker.rs @@ -0,0 +1,250 @@ +mod auth; +mod stream; + +use std::sync::{Arc, Mutex}; + +use axum::{ + body::StreamBody, + extract::State, + headers::{authorization::Basic, Authorization}, + http::StatusCode, + http::{header, HeaderValue}, + response::{IntoResponse, Response}, + Json, TypedHeader, +}; +use axum_extra::routing::TypedPath; +use gix::{ObjectId, ThreadSafeRepository}; +use serde::Deserialize; +use sqlx::{Acquire, SqlitePool}; +use time::OffsetDateTime; +use tracing::debug; + +use crate::{ + config::Config, + server::{ + workers::{WorkerInfo, Workers}, + BenchRepo, Repo, + }, + shared::{BenchMethod, FinishedRun, ServerResponse, WorkerRequest}, + somehow, +}; + +async fn save_work(finished: FinishedRun, db: &SqlitePool) -> somehow::Result<()> { + let mut tx = db.begin().await?; + let conn = tx.acquire().await?; + + let end = OffsetDateTime::now_utc(); + let bench_method = match finished.run.bench_method { + BenchMethod::Internal => "internal".to_string(), + BenchMethod::Repo { hash } => format!("bench repo, hash {hash}"), + }; + + sqlx::query!( + "\ + INSERT INTO runs ( \ + id, \ + hash, \ + bench_method, \ + start, \ + end, \ + exit_code \ + ) \ + VALUES (?, ?, ?, ?, ?, ?) \ + ", + finished.run.id, + finished.run.hash, + bench_method, + finished.run.start, + end, + finished.exit_code, + ) + .execute(&mut *conn) + .await?; + + for (name, measurement) in finished.measurements { + sqlx::query!( + "\ + INSERT INTO run_measurements ( \ + id, \ + name, \ + value, \ + stddev, \ + unit, \ + direction \ + ) \ + VALUES (?, ?, ?, ?, ?, ?) \ + ", + finished.run.id, + name, + measurement.value, + measurement.stddev, + measurement.unit, + measurement.direction, + ) + .execute(&mut *conn) + .await?; + } + + for (idx, (source, text)) in finished.output.into_iter().enumerate() { + // Hopefully we won't need more than 4294967296 output chunks per run :P + let idx = idx as u32; + sqlx::query!( + "\ + INSERT INTO run_output ( \ + id, \ + idx, \ + source, \ + text \ + ) \ + VALUES (?, ?, ?, ?) \ + ", + finished.run.id, + idx, + source, + text, + ) + .execute(&mut *conn) + .await?; + } + + // The thing has been done :D + sqlx::query!("DELETE FROM queue WHERE hash = ?", finished.run.hash) + .execute(&mut *conn) + .await?; + + tx.commit().await?; + Ok(()) +} + +#[derive(Deserialize, TypedPath)] +#[typed_path("/api/worker/status")] +pub struct PathApiWorkerStatus {} + +pub async fn post_api_worker_status( + _path: PathApiWorkerStatus, + State(config): State<&'static Config>, + State(db): State, + State(bench_repo): State>, + State(workers): State>>, + auth: Option>>, + Json(request): Json, +) -> somehow::Result { + let name = match auth::authenticate(config, auth) { + Ok(name) => name, + Err(response) => return Ok(response), + }; + + if let Some(run) = request.submit_run { + save_work(run, &db).await?; + } + + // Fetch queue + let queue = sqlx::query_scalar!( + "\ + SELECT hash FROM queue \ + ORDER BY priority DESC, unixepoch(date) DESC, hash ASC \ + " + ) + .fetch_all(&db) + .await?; + + // Fetch bench method + let bench_method = match bench_repo { + Some(bench_repo) => BenchMethod::Repo { + hash: bench_repo.0.to_thread_local().head_id()?.to_string(), + }, + None => BenchMethod::Internal, + }; + + // Update internal state + let (work, abort_work) = { + let mut guard = workers.lock().unwrap(); + guard.clean(); + if !guard.verify(&name, &request.secret) { + return Ok((StatusCode::UNAUTHORIZED, "invalid secret").into_response()); + } + guard.update( + name.clone(), + WorkerInfo::new(request.secret, OffsetDateTime::now_utc(), request.status), + ); + let work = match request.request_run { + true => guard.find_and_reserve_run(&name, &queue, bench_method), + false => None, + }; + let abort_work = guard.should_abort_work(&name, &queue); + (work, abort_work) + }; + + debug!("Received status update from {name}"); + Ok(Json(ServerResponse { + run: work, + abort_run: abort_work, + }) + .into_response()) +} + +fn stream_response(repo: Arc, id: ObjectId) -> impl IntoResponse { + ( + [ + ( + header::CONTENT_TYPE, + HeaderValue::from_static("application/gzip"), + ), + ( + header::CONTENT_DISPOSITION, + HeaderValue::from_static("attachment; filename=\"tree.tar.gz\""), + ), + ], + StreamBody::new(stream::tar_and_gzip(repo, id)), + ) +} + +#[derive(Deserialize, TypedPath)] +#[typed_path("/api/worker/repo/:hash/tree.tar.gz")] +pub struct PathApiWorkerRepoByHashTreeTarGz { + hash: String, +} + +pub async fn get_api_worker_repo_by_hash_tree_tar_gz( + path: PathApiWorkerRepoByHashTreeTarGz, + State(config): State<&'static Config>, + State(repo): State>, + auth: Option>>, +) -> somehow::Result { + let _name = match auth::authenticate(config, auth) { + Ok(name) => name, + Err(response) => return Ok(response), + }; + + let Some(repo) = repo else { + return Ok(StatusCode::NOT_FOUND.into_response()); + }; + + let id = path.hash.parse::()?; + Ok(stream_response(repo.0, id).into_response()) +} + +#[derive(Deserialize, TypedPath)] +#[typed_path("/api/worker/bench_repo/:hash/tree.tar.gz")] +pub struct PathApiWorkerBenchRepoByHashTreeTarGz { + hash: String, +} + +pub async fn get_api_worker_bench_repo_by_hash_tree_tar_gz( + path: PathApiWorkerBenchRepoByHashTreeTarGz, + State(config): State<&'static Config>, + State(bench_repo): State>, + auth: Option>>, +) -> somehow::Result { + let _name = match auth::authenticate(config, auth) { + Ok(name) => name, + Err(response) => return Ok(response), + }; + + let Some(bench_repo) = bench_repo else { + return Ok(StatusCode::NOT_FOUND.into_response()); + }; + + let id = path.hash.parse::()?; + Ok(stream_response(bench_repo.0, id).into_response()) +} diff --git a/src/server/web/api/auth.rs b/src/server/web/api/worker/auth.rs similarity index 100% rename from src/server/web/api/auth.rs rename to src/server/web/api/worker/auth.rs diff --git a/src/server/web/api/stream.rs b/src/server/web/api/worker/stream.rs similarity index 100% rename from src/server/web/api/stream.rs rename to src/server/web/api/worker/stream.rs