Move worker api to server::web::api::worker

This commit is contained in:
Joscha 2023-08-13 16:04:48 +02:00
parent 087ecfd783
commit 88d9a1f818
6 changed files with 263 additions and 250 deletions

View file

@ -12,7 +12,13 @@ use axum_extra::routing::RouterExt;
use crate::{config::Config, somehow};
use self::admin::queue::post_admin_queue_add;
use self::{
admin::queue::post_admin_queue_add,
api::worker::{
get_api_worker_bench_repo_by_hash_tree_tar_gz, get_api_worker_repo_by_hash_tree_tar_gz,
post_api_worker_status,
},
};
use super::Server;
@ -53,8 +59,10 @@ pub async fn run(server: Server) -> somehow::Result<()> {
.route("/queue/", get(queue::get))
.route("/queue/inner", get(queue::get_inner))
.route("/worker/:name", get(worker::get))
.typed_get(get_api_worker_bench_repo_by_hash_tree_tar_gz)
.typed_get(get_api_worker_repo_by_hash_tree_tar_gz)
.typed_post(post_admin_queue_add)
.merge(api::router(&server))
.typed_post(post_api_worker_status)
.fallback(get(r#static::static_handler))
.with_state(server.clone());

View file

@ -1,246 +1 @@
mod auth;
mod stream;
use std::sync::{Arc, Mutex};
use axum::{
body::StreamBody,
extract::{Path, State},
headers::{authorization::Basic, Authorization},
http::StatusCode,
http::{header, HeaderValue},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router, TypedHeader,
};
use gix::{ObjectId, ThreadSafeRepository};
use sqlx::{Acquire, SqlitePool};
use time::OffsetDateTime;
use tracing::debug;
use crate::{
config::Config,
server::{
workers::{WorkerInfo, Workers},
BenchRepo, Repo, Server,
},
shared::{BenchMethod, FinishedRun, ServerResponse, WorkerRequest},
somehow,
};
/// Persist a finished benchmark run: the run row, all of its measurements and
/// output chunks, and removal of the corresponding queue entry — all inside a
/// single transaction so a partial write never becomes visible.
async fn save_work(finished: FinishedRun, db: &SqlitePool) -> somehow::Result<()> {
let mut tx = db.begin().await?;
let conn = tx.acquire().await?;
// The end timestamp is taken server-side, not reported by the worker.
let end = OffsetDateTime::now_utc();
// Flatten the bench method into a human-readable string column.
let bench_method = match finished.run.bench_method {
BenchMethod::Internal => "internal".to_string(),
BenchMethod::Repo { hash } => format!("bench repo, hash {hash}"),
};
// Record the run itself.
sqlx::query!(
"\
INSERT INTO runs ( \
id, \
hash, \
bench_method, \
start, \
end, \
exit_code \
) \
VALUES (?, ?, ?, ?, ?, ?) \
",
finished.run.id,
finished.run.hash,
bench_method,
finished.run.start,
end,
finished.exit_code,
)
.execute(&mut *conn)
.await?;
// One row per named measurement reported by the worker.
for (name, measurement) in finished.measurements {
sqlx::query!(
"\
INSERT INTO run_measurements ( \
id, \
name, \
value, \
stddev, \
unit, \
direction \
) \
VALUES (?, ?, ?, ?, ?, ?) \
",
finished.run.id,
name,
measurement.value,
measurement.stddev,
measurement.unit,
measurement.direction,
)
.execute(&mut *conn)
.await?;
}
// Output chunks keep their original order via the `idx` column.
for (idx, (source, text)) in finished.output.into_iter().enumerate() {
// Hopefully we won't need more than 4294967296 output chunks per run :P
let idx = idx as u32;
sqlx::query!(
"\
INSERT INTO run_output ( \
id, \
idx, \
source, \
text \
) \
VALUES (?, ?, ?, ?) \
",
finished.run.id,
idx,
source,
text,
)
.execute(&mut *conn)
.await?;
}
// The thing has been done :D
sqlx::query!("DELETE FROM queue WHERE hash = ?", finished.run.hash)
.execute(&mut *conn)
.await?;
tx.commit().await?;
Ok(())
}
/// Handle a worker status update: store any finished run the worker submitted,
/// refresh the worker's liveness info, and reply with optional new work plus an
/// abort flag for work no longer in the queue.
async fn post_status(
State(config): State<&'static Config>,
State(db): State<SqlitePool>,
State(bench_repo): State<Option<BenchRepo>>,
State(workers): State<Arc<Mutex<Workers>>>,
auth: Option<TypedHeader<Authorization<Basic>>>,
Json(request): Json<WorkerRequest>,
) -> somehow::Result<Response> {
// Basic-auth gate; on failure the prepared error response is returned as-is.
let name = match auth::authenticate(config, auth) {
Ok(name) => name,
Err(response) => return Ok(response),
};
// NOTE(review): the submitted run is persisted before the per-worker secret
// check further down — confirm that ordering is intended.
if let Some(run) = request.submit_run {
save_work(run, &db).await?;
}
// Fetch queue
let queue = sqlx::query_scalar!(
"\
SELECT hash FROM queue \
ORDER BY priority DESC, unixepoch(date) DESC, hash ASC \
"
)
.fetch_all(&db)
.await?;
// Fetch bench method
let bench_method = match bench_repo {
Some(bench_repo) => BenchMethod::Repo {
hash: bench_repo.0.to_thread_local().head_id()?.to_string(),
},
None => BenchMethod::Internal,
};
// Update internal state
// The mutex guard is confined to this block so it is dropped before the
// response is built.
let (work, abort_work) = {
let mut guard = workers.lock().unwrap();
guard.clean();
if !guard.verify(&name, &request.secret) {
return Ok((StatusCode::UNAUTHORIZED, "invalid secret").into_response());
}
guard.update(
name.clone(),
WorkerInfo::new(request.secret, OffsetDateTime::now_utc(), request.status),
);
// Only hand out a run if the worker asked for one.
let work = match request.request_run {
true => guard.find_and_reserve_run(&name, &queue, bench_method),
false => None,
};
let abort_work = guard.should_abort_work(&name, &queue);
(work, abort_work)
};
debug!("Received status update from {name}");
Ok(Json(ServerResponse {
run: work,
abort_run: abort_work,
})
.into_response())
}
/// Stream the git tree at `id` from `repo` as a gzipped tarball download.
fn stream_response(repo: Arc<ThreadSafeRepository>, id: ObjectId) -> impl IntoResponse {
(
[
(
header::CONTENT_TYPE,
HeaderValue::from_static("application/gzip"),
),
// Suggest a filename so browsers/workers save it as tree.tar.gz.
(
header::CONTENT_DISPOSITION,
HeaderValue::from_static("attachment; filename=\"tree.tar.gz\""),
),
],
// Tar and gzip the tree as a stream rather than buffering it whole.
StreamBody::new(stream::tar_and_gzip(repo, id)),
)
}
/// Serve the main repo's tree at the given commit hash to an authenticated
/// worker as a gzipped tarball.
async fn get_repo(
State(config): State<&'static Config>,
State(repo): State<Option<Repo>>,
auth: Option<TypedHeader<Authorization<Basic>>>,
Path(hash): Path<String>,
) -> somehow::Result<Response> {
// Authentication is required, but the worker's name is not used here.
let _name = match auth::authenticate(config, auth) {
Ok(name) => name,
Err(response) => return Ok(response),
};
// Without a configured repo there is nothing to serve.
let Some(repo) = repo else {
return Ok(StatusCode::NOT_FOUND.into_response());
};
// An unparsable hash surfaces as an error via `?`.
let id = hash.parse::<ObjectId>()?;
Ok(stream_response(repo.0, id).into_response())
}
/// Serve the bench repo's tree at the given commit hash to an authenticated
/// worker as a gzipped tarball.
async fn get_bench_repo(
State(config): State<&'static Config>,
State(bench_repo): State<Option<BenchRepo>>,
auth: Option<TypedHeader<Authorization<Basic>>>,
Path(hash): Path<String>,
) -> somehow::Result<Response> {
// Authentication is required, but the worker's name is not used here.
let _name = match auth::authenticate(config, auth) {
Ok(name) => name,
Err(response) => return Ok(response),
};
// Without a configured bench repo there is nothing to serve.
let Some(bench_repo) = bench_repo else {
return Ok(StatusCode::NOT_FOUND.into_response());
};
// An unparsable hash surfaces as an error via `?`.
let id = hash.parse::<ObjectId>()?;
Ok(stream_response(bench_repo.0, id).into_response())
}
/// Build the worker-facing API router.
///
/// NOTE(review): when no main repo is configured this returns an empty router,
/// which also disables the status and bench-repo endpoints — confirm that is
/// intended rather than only disabling the repo download route.
pub fn router(server: &Server) -> Router<Server> {
if server.repo.is_none() {
return Router::new();
}
Router::new()
.route("/api/worker/status", post(post_status))
.route("/api/worker/repo/:hash/tree.tar.gz", get(get_repo))
.route(
"/api/worker/bench_repo/:hash/tree.tar.gz",
get(get_bench_repo),
)
}
pub mod worker;

View file

@ -0,0 +1,250 @@
mod auth;
mod stream;
use std::sync::{Arc, Mutex};
use axum::{
body::StreamBody,
extract::State,
headers::{authorization::Basic, Authorization},
http::StatusCode,
http::{header, HeaderValue},
response::{IntoResponse, Response},
Json, TypedHeader,
};
use axum_extra::routing::TypedPath;
use gix::{ObjectId, ThreadSafeRepository};
use serde::Deserialize;
use sqlx::{Acquire, SqlitePool};
use time::OffsetDateTime;
use tracing::debug;
use crate::{
config::Config,
server::{
workers::{WorkerInfo, Workers},
BenchRepo, Repo,
},
shared::{BenchMethod, FinishedRun, ServerResponse, WorkerRequest},
somehow,
};
/// Persist a finished benchmark run: the run row, all of its measurements and
/// output chunks, and removal of the corresponding queue entry — all inside a
/// single transaction so a partial write never becomes visible.
async fn save_work(finished: FinishedRun, db: &SqlitePool) -> somehow::Result<()> {
let mut tx = db.begin().await?;
let conn = tx.acquire().await?;
// The end timestamp is taken server-side, not reported by the worker.
let end = OffsetDateTime::now_utc();
// Flatten the bench method into a human-readable string column.
let bench_method = match finished.run.bench_method {
BenchMethod::Internal => "internal".to_string(),
BenchMethod::Repo { hash } => format!("bench repo, hash {hash}"),
};
// Record the run itself.
sqlx::query!(
"\
INSERT INTO runs ( \
id, \
hash, \
bench_method, \
start, \
end, \
exit_code \
) \
VALUES (?, ?, ?, ?, ?, ?) \
",
finished.run.id,
finished.run.hash,
bench_method,
finished.run.start,
end,
finished.exit_code,
)
.execute(&mut *conn)
.await?;
// One row per named measurement reported by the worker.
for (name, measurement) in finished.measurements {
sqlx::query!(
"\
INSERT INTO run_measurements ( \
id, \
name, \
value, \
stddev, \
unit, \
direction \
) \
VALUES (?, ?, ?, ?, ?, ?) \
",
finished.run.id,
name,
measurement.value,
measurement.stddev,
measurement.unit,
measurement.direction,
)
.execute(&mut *conn)
.await?;
}
// Output chunks keep their original order via the `idx` column.
for (idx, (source, text)) in finished.output.into_iter().enumerate() {
// Hopefully we won't need more than 4294967296 output chunks per run :P
let idx = idx as u32;
sqlx::query!(
"\
INSERT INTO run_output ( \
id, \
idx, \
source, \
text \
) \
VALUES (?, ?, ?, ?) \
",
finished.run.id,
idx,
source,
text,
)
.execute(&mut *conn)
.await?;
}
// The thing has been done :D
sqlx::query!("DELETE FROM queue WHERE hash = ?", finished.run.hash)
.execute(&mut *conn)
.await?;
tx.commit().await?;
Ok(())
}
/// Typed route for `POST /api/worker/status`.
#[derive(Deserialize, TypedPath)]
#[typed_path("/api/worker/status")]
pub struct PathApiWorkerStatus {}
/// `POST /api/worker/status` — handle a worker status update: store any
/// finished run the worker submitted, refresh the worker's liveness info, and
/// reply with optional new work plus an abort flag for stale work.
pub async fn post_api_worker_status(
_path: PathApiWorkerStatus,
State(config): State<&'static Config>,
State(db): State<SqlitePool>,
State(bench_repo): State<Option<BenchRepo>>,
State(workers): State<Arc<Mutex<Workers>>>,
auth: Option<TypedHeader<Authorization<Basic>>>,
Json(request): Json<WorkerRequest>,
) -> somehow::Result<Response> {
// Basic-auth gate; on failure the prepared error response is returned as-is.
let name = match auth::authenticate(config, auth) {
Ok(name) => name,
Err(response) => return Ok(response),
};
// NOTE(review): the submitted run is persisted before the per-worker secret
// check further down — confirm that ordering is intended.
if let Some(run) = request.submit_run {
save_work(run, &db).await?;
}
// Fetch queue
let queue = sqlx::query_scalar!(
"\
SELECT hash FROM queue \
ORDER BY priority DESC, unixepoch(date) DESC, hash ASC \
"
)
.fetch_all(&db)
.await?;
// Fetch bench method
let bench_method = match bench_repo {
Some(bench_repo) => BenchMethod::Repo {
hash: bench_repo.0.to_thread_local().head_id()?.to_string(),
},
None => BenchMethod::Internal,
};
// Update internal state
// The mutex guard is confined to this block so it is dropped before the
// response is built.
let (work, abort_work) = {
let mut guard = workers.lock().unwrap();
guard.clean();
if !guard.verify(&name, &request.secret) {
return Ok((StatusCode::UNAUTHORIZED, "invalid secret").into_response());
}
guard.update(
name.clone(),
WorkerInfo::new(request.secret, OffsetDateTime::now_utc(), request.status),
);
// Only hand out a run if the worker asked for one.
let work = match request.request_run {
true => guard.find_and_reserve_run(&name, &queue, bench_method),
false => None,
};
let abort_work = guard.should_abort_work(&name, &queue);
(work, abort_work)
};
debug!("Received status update from {name}");
Ok(Json(ServerResponse {
run: work,
abort_run: abort_work,
})
.into_response())
}
/// Stream the git tree at `id` from `repo` as a gzipped tarball download.
fn stream_response(repo: Arc<ThreadSafeRepository>, id: ObjectId) -> impl IntoResponse {
    // Archive and compress as a stream rather than buffering the whole tree.
    let body = StreamBody::new(stream::tar_and_gzip(repo, id));
    // Content type plus a suggested filename for the download.
    let headers = [
        (
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/gzip"),
        ),
        (
            header::CONTENT_DISPOSITION,
            HeaderValue::from_static("attachment; filename=\"tree.tar.gz\""),
        ),
    ];
    (headers, body)
}
/// Typed route for `GET /api/worker/repo/:hash/tree.tar.gz`.
#[derive(Deserialize, TypedPath)]
#[typed_path("/api/worker/repo/:hash/tree.tar.gz")]
pub struct PathApiWorkerRepoByHashTreeTarGz {
// Commit hash of the tree to download; parsed into an ObjectId by the handler.
hash: String,
}
/// `GET /api/worker/repo/:hash/tree.tar.gz` — serve the main repo's tree at
/// the given commit hash to an authenticated worker as a gzipped tarball.
pub async fn get_api_worker_repo_by_hash_tree_tar_gz(
    path: PathApiWorkerRepoByHashTreeTarGz,
    State(config): State<&'static Config>,
    State(repo): State<Option<Repo>>,
    auth: Option<TypedHeader<Authorization<Basic>>>,
) -> somehow::Result<Response> {
    // Authentication is required, but the worker's name is not used here.
    if let Err(response) = auth::authenticate(config, auth) {
        return Ok(response);
    }
    // Without a configured repo there is nothing to serve.
    let repo = match repo {
        Some(repo) => repo,
        None => return Ok(StatusCode::NOT_FOUND.into_response()),
    };
    // An unparsable hash surfaces as an error via `?`.
    let id: ObjectId = path.hash.parse()?;
    Ok(stream_response(repo.0, id).into_response())
}
/// Typed route for `GET /api/worker/bench_repo/:hash/tree.tar.gz`.
#[derive(Deserialize, TypedPath)]
#[typed_path("/api/worker/bench_repo/:hash/tree.tar.gz")]
pub struct PathApiWorkerBenchRepoByHashTreeTarGz {
// Commit hash of the tree to download; parsed into an ObjectId by the handler.
hash: String,
}
/// `GET /api/worker/bench_repo/:hash/tree.tar.gz` — serve the bench repo's
/// tree at the given commit hash to an authenticated worker as a gzipped
/// tarball.
pub async fn get_api_worker_bench_repo_by_hash_tree_tar_gz(
    path: PathApiWorkerBenchRepoByHashTreeTarGz,
    State(config): State<&'static Config>,
    State(bench_repo): State<Option<BenchRepo>>,
    auth: Option<TypedHeader<Authorization<Basic>>>,
) -> somehow::Result<Response> {
    // Authentication is required, but the worker's name is not used here.
    if let Err(response) = auth::authenticate(config, auth) {
        return Ok(response);
    }
    // Without a configured bench repo there is nothing to serve.
    let bench_repo = match bench_repo {
        Some(bench_repo) => bench_repo,
        None => return Ok(StatusCode::NOT_FOUND.into_response()),
    };
    // An unparsable hash surfaces as an error via `?`.
    let id: ObjectId = path.hash.parse()?;
    Ok(stream_response(bench_repo.0, id).into_response())
}