Add TokioVault

This commit is contained in:
Joscha 2022-12-11 19:08:41 +01:00
parent 56e25defce
commit b763fb62de
3 changed files with 159 additions and 5 deletions

View file

@ -3,5 +3,10 @@ name = "vault"
version = "0.0.0" version = "0.0.0"
edition = "2021" edition = "2021"
[features]
tokio = ["dep:tokio"]
[dependencies] [dependencies]
rusqlite = "0.28.0" rusqlite = "0.28.0"
thiserror = "1.0.37"
tokio = { version = "1.23.0", features = ["sync"], optional = true }

View file

@ -1,10 +1,9 @@
#[cfg(feature = "tokio")]
pub mod tokio;
use rusqlite::Connection; use rusqlite::Connection;
pub trait DbExecute { pub trait Action {
fn run(self, conn: &mut Connection) -> rusqlite::Result<()>;
}
pub trait DbQuery {
type Result; type Result;
fn run(self, conn: &mut Connection) -> rusqlite::Result<Self::Result>; fn run(self, conn: &mut Connection) -> rusqlite::Result<Self::Result>;
} }

150
src/tokio.rs Normal file
View file

@ -0,0 +1,150 @@
use std::{any::Any, result, thread};
use rusqlite::{Connection, Transaction};
use tokio::sync::{mpsc, oneshot};
use crate::Action;
/// Wrapper trait around [`Action`] that turns `Box<Self>` into a `Self` and the
/// action's return type into `Box<dyn Any + Send>`.
///
/// This way, the trait that users of this crate interact with is kept simpler.
trait ActionWrapper {
fn run(self: Box<Self>, conn: &mut Connection) -> rusqlite::Result<Box<dyn Any + Send>>;
}
impl<T: Action> ActionWrapper for T
where
T::Result: Send + 'static,
{
fn run(self: Box<Self>, conn: &mut Connection) -> rusqlite::Result<Box<dyn Any + Send>> {
let result = (*self).run(conn)?;
Ok(Box::new(result))
}
}
/// Command to be sent via the mpsc channel to the vault thread.
enum Command {
Action(
Box<dyn ActionWrapper + Send>,
oneshot::Sender<rusqlite::Result<Box<dyn Any + Send>>>,
),
Stop(oneshot::Sender<()>),
}
/// Error that can occur during execution of an [`Action`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The vault thread has stopped.
#[error("vault thread has stopped")]
Stopped,
/// A [`rusqlite::Error`] occurred while running the action.
#[error("{0}")]
Rusqlite(#[from] rusqlite::Error),
}
pub type Result<R> = result::Result<R, Error>;
/// A single database migration.
///
/// It receives a [`Transaction`] to perform database operations in, and its
/// index in the migration array. The latter might be useful for logging.
///
/// The transaction spans all migrations currently being performed. If any
/// single migration fails, all migrations are rolled back and the database is
/// unchanged.
///
/// The migration does not need to update the `user_version` or commit the
/// transaction.
pub type Migration = fn(&mut Transaction, usize) -> rusqlite::Result<()>;
fn migrate(conn: &mut Connection, migrations: &[Migration]) -> rusqlite::Result<()> {
let mut tx = conn.transaction()?;
let user_version: usize =
tx.query_row("SELECT * FROM pragma_user_version", [], |r| r.get(0))?;
let total = migrations.len();
assert!(user_version <= total, "malformed database schema");
for (i, migration) in migrations.iter().enumerate().skip(user_version) {
migration(&mut tx, i)?;
}
tx.pragma_update(None, "user_version", total)?;
tx.commit()
}
fn run(mut conn: Connection, mut rx: mpsc::UnboundedReceiver<Command>) {
while let Some(command) = rx.blocking_recv() {
match command {
Command::Action(action, tx) => {
let result = action.run(&mut conn);
let _ = tx.send(result);
}
Command::Stop(tx) => {
drop(conn);
drop(tx);
break;
}
}
}
}
#[derive(Clone)]
pub struct TokioVault {
tx: mpsc::UnboundedSender<Command>,
}
impl TokioVault {
/// Launch a new thread to run database queries on, and return a
/// [`TokioVault`] for communication with that thread.
///
/// It is recommended to set a few pragmas before calling this function, for
/// example:
/// - `journal_mode` to `"wal"`
/// - `foreign_keys` to `true`
/// - `trusted_schema` to `false`
pub fn launch(
mut conn: Connection,
migrations: &[Migration],
prepare: impl FnOnce(&mut Connection) -> rusqlite::Result<()>,
) -> rusqlite::Result<Self> {
migrate(&mut conn, migrations)?;
prepare(&mut conn)?;
let (tx, rx) = mpsc::unbounded_channel();
thread::spawn(move || run(conn, rx));
Ok(Self { tx })
}
/// Execute an [`Action`] and return the result.
pub async fn execute<A>(&self, action: A) -> Result<A::Result>
where
A: Action + Send + 'static,
A::Result: Send,
{
let (tx, rx) = oneshot::channel();
self.tx
.send(Command::Action(Box::new(action), tx))
.map_err(|_| Error::Stopped)?;
let result = rx.await.map_err(|_| Error::Stopped)??;
// The ActionWrapper runs Action::run, which returns Action::Result. It
// then wraps this into Any, which we're now trying to downcast again to
// Action::Result. This should always work.
let result = result.downcast().unwrap();
Ok(*result)
}
/// Stop the vault thread.
///
/// Returns when the vault has been stopped successfully.
pub async fn stop(&self) {
let (tx, rx) = oneshot::channel();
let _ = self.tx.send(Command::Stop(tx));
let _ = rx.await;
}
}