diff --git a/Cargo.toml b/Cargo.toml index 29a8bb5..37d8edc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,5 +3,10 @@ name = "vault" version = "0.0.0" edition = "2021" +[features] +tokio = ["dep:tokio"] + [dependencies] rusqlite = "0.28.0" +thiserror = "1.0.37" +tokio = { version = "1.23.0", features = ["sync"], optional = true } diff --git a/src/lib.rs b/src/lib.rs index 8f47352..86c09be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,9 @@ +#[cfg(feature = "tokio")] +pub mod tokio; + use rusqlite::Connection; -pub trait DbExecute { - fn run(self, conn: &mut Connection) -> rusqlite::Result<()>; -} - -pub trait DbQuery { +pub trait Action { type Result; fn run(self, conn: &mut Connection) -> rusqlite::Result; } diff --git a/src/tokio.rs b/src/tokio.rs new file mode 100644 index 0000000..9920757 --- /dev/null +++ b/src/tokio.rs @@ -0,0 +1,150 @@ +use std::{any::Any, result, thread}; + +use rusqlite::{Connection, Transaction}; +use tokio::sync::{mpsc, oneshot}; + +use crate::Action; + +/// Wrapper trait around [`Action`] that turns `Box` into a `Self` and the +/// action's return type into `Box`. +/// +/// This way, the trait that users of this crate interact with is kept simpler. +trait ActionWrapper { + fn run(self: Box, conn: &mut Connection) -> rusqlite::Result>; +} + +impl ActionWrapper for T +where + T::Result: Send + 'static, +{ + fn run(self: Box, conn: &mut Connection) -> rusqlite::Result> { + let result = (*self).run(conn)?; + Ok(Box::new(result)) + } +} + +/// Command to be sent via the mpsc channel to the vault thread. +enum Command { + Action( + Box, + oneshot::Sender>>, + ), + Stop(oneshot::Sender<()>), +} + +/// Error that can occur during execution of an [`Action`]. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// The vault thread has stopped. + #[error("vault thread has stopped")] + Stopped, + + /// A [`rusqlite::Error`] occurred while running the action. + #[error("{0}")] + Rusqlite(#[from] rusqlite::Error), +} + +pub type Result = result::Result; + +/// A single database migration. +/// +/// It receives a [`Transaction`] to perform database operations in, and its +/// index in the migration array. The latter might be useful for logging. +/// +/// The transaction spans all migrations currently being performed. If any +/// single migration fails, all migrations are rolled back and the database is +/// unchanged. +/// +/// The migration does not need to update the `user_version` or commit the +/// transaction. +pub type Migration = fn(&mut Transaction, usize) -> rusqlite::Result<()>; + +fn migrate(conn: &mut Connection, migrations: &[Migration]) -> rusqlite::Result<()> { + let mut tx = conn.transaction()?; + + let user_version: usize = + tx.query_row("SELECT * FROM pragma_user_version", [], |r| r.get(0))?; + + let total = migrations.len(); + assert!(user_version <= total, "malformed database schema"); + for (i, migration) in migrations.iter().enumerate().skip(user_version) { + migration(&mut tx, i)?; + } + + tx.pragma_update(None, "user_version", total)?; + tx.commit() +} + +fn run(mut conn: Connection, mut rx: mpsc::UnboundedReceiver) { + while let Some(command) = rx.blocking_recv() { + match command { + Command::Action(action, tx) => { + let result = action.run(&mut conn); + let _ = tx.send(result); + } + Command::Stop(tx) => { + drop(conn); + drop(tx); + break; + } + } + } +} + +#[derive(Clone)] +pub struct TokioVault { + tx: mpsc::UnboundedSender, +} + +impl TokioVault { + /// Launch a new thread to run database queries on, and return a + /// [`TokioVault`] for communication with that thread. + /// + /// It is recommended to set a few pragmas before calling this function, for + /// example: + /// - `journal_mode` to `"wal"` + /// - `foreign_keys` to `true` + /// - `trusted_schema` to `false` + pub fn launch( + mut conn: Connection, + migrations: &[Migration], + prepare: impl FnOnce(&mut Connection) -> rusqlite::Result<()>, + ) -> rusqlite::Result { + migrate(&mut conn, migrations)?; + prepare(&mut conn)?; + + let (tx, rx) = mpsc::unbounded_channel(); + thread::spawn(move || run(conn, rx)); + Ok(Self { tx }) + } + + /// Execute an [`Action`] and return the result. + pub async fn execute(&self, action: A) -> Result + where + A: Action + Send + 'static, + A::Result: Send, + { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Command::Action(Box::new(action), tx)) + .map_err(|_| Error::Stopped)?; + + let result = rx.await.map_err(|_| Error::Stopped)??; + + // The ActionWrapper runs Action::run, which returns Action::Result. It + // then wraps this into Any, which we're now trying to downcast again to + // Action::Result. This should always work. + let result = result.downcast().unwrap(); + + Ok(*result) + } + + /// Stop the vault thread. + /// + /// Returns when the vault has been stopped successfully. + pub async fn stop(&self) { + let (tx, rx) = oneshot::channel(); + let _ = self.tx.send(Command::Stop(tx)); + let _ = rx.await; + } +}