diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cf4f23..71ff42a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Procedure when bumping the version number: + 1. Update dependencies in a separate commit 2. Set version number in `Cargo.toml` 3. Add new section in this changelog @@ -13,6 +14,45 @@ Procedure when bumping the version number: ## Unreleased +### Changed + +- **(breaking)** Bumped `rusqlite` dependency from `0.32` to `0.33` + +## v0.5.0 - 2024-09-04 + +### Changed + +- **(breaking)** Bumped `rusqlite` dependency from `0.31` to `0.32` + +## v0.4.0 - 2024-02-23 + +### Changed + +- **(breaking)** Bumped `rusqlite` dependency from `0.30` to `0.31` + +## v0.3.0 - 2023-12-26 + +### Changed + +- **(breaking)** Bumped `rusqlite` dependency from `0.29` to `0.30` + +## v0.2.0 - 2023-05-14 + +### Added + +- `serde` feature +- `serde::from_row_via_index` +- `serde::from_row_via_name` + +### Changed + +- **(breaking)** + Error handling of `Action`s is now more complex but more powerful. In + particular, `Action`s can now return almost arbitrary errors without nesting + `Result`s like before. +- **(breaking)** Renamed `Action::Result` to `Action::Output` +- **(breaking)** Bumped `rusqlite` dependency from `0.28` to `0.29` + ## v0.1.0 - 2023-02-12 Initial release diff --git a/Cargo.toml b/Cargo.toml index 3a3b297..80d87a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,13 @@ [package] name = "vault" -version = "0.1.0" +version = "0.5.0" edition = "2021" [features] +serde = ["dep:serde"] tokio = ["dep:tokio"] [dependencies] -rusqlite = "0.28.0" -tokio = { version = "1.25.0", features = ["sync"], optional = true } +rusqlite = "0.33.0" +serde = { version = "1.0.209", optional = true } +tokio = { version = "1.40.0", features = ["sync"], optional = true } diff --git a/src/lib.rs b/src/lib.rs index 5d65e1a..edc1392 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,12 +9,17 @@ // Clippy lints #![warn(clippy::use_self)] +#[cfg(feature = "serde")] +pub mod serde; pub mod simple; #[cfg(feature = "tokio")] pub mod tokio; use rusqlite::{Connection, Transaction}; +#[cfg(feature = "serde")] +pub use self::serde::*; + /// An action that can be performed on a [`Connection`]. /// /// Both commands and queries are considered actions. Commands usually have a @@ -23,8 +28,9 @@ use rusqlite::{Connection, Transaction}; /// Actions are usually passed to a vault which will then execute them and /// return the result. The way in which this occurs depends on the vault. pub trait Action { - type Result; - fn run(self, conn: &mut Connection) -> rusqlite::Result; + type Output; + type Error; + fn run(self, conn: &mut Connection) -> Result; } /// A single database migration. diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..d58640d --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,325 @@ +use std::{error, fmt, str::Utf8Error}; + +use rusqlite::{ + types::{FromSqlError, ValueRef}, + Row, +}; +use serde::{ + de::{ + self, value::BorrowedStrDeserializer, DeserializeSeed, Deserializer, MapAccess, SeqAccess, + Visitor, + }, + forward_to_deserialize_any, Deserialize, +}; + +#[derive(Debug)] +enum Error { + ExpectedTupleLikeBaseType, + ExpectedStructLikeBaseType, + Utf8(Utf8Error), + Rusqlite(rusqlite::Error), + Custom(String), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ExpectedTupleLikeBaseType => write!(f, "expected tuple-like base type"), + Self::ExpectedStructLikeBaseType => write!(f, "expected struct-like base type"), + Self::Utf8(err) => err.fmt(f), + Self::Rusqlite(err) => err.fmt(f), + Self::Custom(msg) => msg.fmt(f), + } + } +} + +impl error::Error for Error {} + +impl de::Error for Error { + fn custom(msg: T) -> Self { + Self::Custom(msg.to_string()) + } +} + +impl From for Error { + fn from(value: Utf8Error) -> Self { + Self::Utf8(value) + } +} + +impl From for Error { + fn from(value: rusqlite::Error) -> Self { + Self::Rusqlite(value) + } +} + +struct ValueRefDeserializer<'de> { + value: ValueRef<'de>, +} + +impl<'de> Deserializer<'de> for ValueRefDeserializer<'de> { + type Error = Error; + + forward_to_deserialize_any! { + i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf + unit unit_struct seq tuple tuple_struct map struct identifier + ignored_any + } + + fn deserialize_any>(self, visitor: V) -> Result { + match self.value { + ValueRef::Null => visitor.visit_unit(), + ValueRef::Integer(v) => visitor.visit_i64(v), + ValueRef::Real(v) => visitor.visit_f64(v), + ValueRef::Text(v) => visitor.visit_borrowed_str(std::str::from_utf8(v)?), + ValueRef::Blob(v) => visitor.visit_borrowed_bytes(v), + } + } + + fn deserialize_bool>(self, visitor: V) -> Result { + match self.value { + ValueRef::Integer(0) => visitor.visit_bool(false), + ValueRef::Integer(_) => visitor.visit_bool(true), + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_option>(self, visitor: V) -> Result { + match self.value { + ValueRef::Null => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + fn deserialize_newtype_struct>( + self, + _name: &'static str, + visitor: V, + ) -> Result { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum>( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result { + match self.value { + ValueRef::Text(v) => { + let v = BorrowedStrDeserializer::new(std::str::from_utf8(v)?); + v.deserialize_enum(name, variants, visitor) + } + _ => self.deserialize_any(visitor), + } + } +} + +struct IndexedRowDeserializer<'de, 'stmt> { + row: &'de Row<'stmt>, +} + +impl<'de> Deserializer<'de> for IndexedRowDeserializer<'de, '_> { + type Error = Error; + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct map enum identifier ignored_any + } + + fn deserialize_any>(self, _visitor: V) -> Result { + Err(Error::ExpectedTupleLikeBaseType) + } + + fn deserialize_newtype_struct>( + self, + _name: &'static str, + visitor: V, + ) -> Result { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq>(self, visitor: V) -> Result { + visitor.visit_seq(IndexedRowSeq::new(self.row)) + } + + fn deserialize_tuple>( + self, + _len: usize, + visitor: V, + ) -> Result { + self.deserialize_seq(visitor) + } + + fn deserialize_tuple_struct>( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result { + self.deserialize_seq(visitor) + } + + fn deserialize_struct>( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result { + visitor.visit_map(IndexedRowMap::new(self.row, fields)) + } +} + +struct IndexedRowSeq<'de, 'stmt> { + row: &'de Row<'stmt>, + next_index: usize, +} + +impl<'de, 'stmt> IndexedRowSeq<'de, 'stmt> { + fn new(row: &'de Row<'stmt>) -> Self { + Self { row, next_index: 0 } + } +} + +impl<'de> SeqAccess<'de> for IndexedRowSeq<'de, '_> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.row.get_ref(self.next_index) { + Ok(value) => { + self.next_index += 1; + seed.deserialize(ValueRefDeserializer { value }).map(Some) + } + Err(rusqlite::Error::InvalidColumnIndex(_)) => Ok(None), + Err(err) => Err(err)?, + } + } +} + +struct IndexedRowMap<'de, 'stmt> { + row: &'de Row<'stmt>, + fields: &'static [&'static str], + next_index: usize, +} + +impl<'de, 'stmt> IndexedRowMap<'de, 'stmt> { + fn new(row: &'de Row<'stmt>, fields: &'static [&'static str]) -> Self { + Self { + row, + fields, + next_index: 0, + } + } +} + +impl<'de> MapAccess<'de> for IndexedRowMap<'de, '_> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if let Some(key) = self.fields.get(self.next_index) { + self.next_index += 1; + seed.deserialize(BorrowedStrDeserializer::new(key)) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + let value = self.row.get_ref(self.next_index - 1)?; + seed.deserialize(ValueRefDeserializer { value }) + } +} + +pub fn from_row_via_index<'de, T>(row: &'de Row<'_>) -> rusqlite::Result +where + T: Deserialize<'de>, +{ + T::deserialize(IndexedRowDeserializer { row }) + .map_err(|err| FromSqlError::Other(Box::new(err)).into()) +} + +struct NamedRowDeserializer<'de, 'stmt> { + row: &'de Row<'stmt>, +} + +impl<'de> Deserializer<'de> for NamedRowDeserializer<'de, '_> { + type Error = Error; + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct + map enum identifier ignored_any + } + + fn deserialize_any>(self, _visitor: V) -> Result { + Err(Error::ExpectedStructLikeBaseType) + } + + fn deserialize_struct>( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result { + visitor.visit_map(NamedRowMap::new(self.row, fields)) + } +} + +struct NamedRowMap<'de, 'stmt> { + row: &'de Row<'stmt>, + fields: &'static [&'static str], + next_index: usize, +} + +impl<'de, 'stmt> NamedRowMap<'de, 'stmt> { + fn new(row: &'de Row<'stmt>, fields: &'static [&'static str]) -> Self { + Self { + row, + fields, + next_index: 0, + } + } +} + +impl<'de> MapAccess<'de> for NamedRowMap<'de, '_> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if let Some(key) = self.fields.get(self.next_index) { + self.next_index += 1; + seed.deserialize(BorrowedStrDeserializer::new(key)) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + let value = self.row.get_ref(self.next_index - 1)?; + seed.deserialize(ValueRefDeserializer { value }) + } +} + +pub fn from_row_via_name<'de, T>(row: &'de Row<'_>) -> rusqlite::Result +where + T: Deserialize<'de>, +{ + T::deserialize(NamedRowDeserializer { row }) + .map_err(|err| FromSqlError::Other(Box::new(err)).into()) +} diff --git a/src/simple.rs b/src/simple.rs index e800073..8dcedd0 100644 --- a/src/simple.rs +++ b/src/simple.rs @@ -34,13 +34,13 @@ impl SimpleVault { /// /// The `prepare` parameter allows access to the database after all /// migrations have occurred. This parameter could be replaced by executing + /// an [`Action`] performing the same operations. /// /// It is recommended to set a few pragmas before calling this function, for /// example: /// - `journal_mode` to `"wal"` /// - `foreign_keys` to `true` /// - `trusted_schema` to `false` - /// an [`Action`] performing the same operations. pub fn new_and_prepare( mut conn: Connection, migrations: &[Migration], @@ -52,11 +52,7 @@ impl SimpleVault { } /// Execute an [`Action`] and return the result. - pub fn execute(&mut self, action: A) -> rusqlite::Result - where - A: Action + Send + 'static, - A::Result: Send, - { + pub fn execute(&mut self, action: A) -> Result { action.run(&mut self.0) } } diff --git a/src/tokio.rs b/src/tokio.rs index 822ce49..564f151 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -1,6 +1,6 @@ //! A vault for use with [`tokio`]. -use std::{any::Any, error, fmt, result, thread}; +use std::{any::Any, error, fmt, thread}; use rusqlite::Connection; use tokio::sync::{mpsc, oneshot}; @@ -12,16 +12,25 @@ use crate::{Action, Migration}; /// /// This way, the trait that users of this crate interact with is kept simpler. trait ActionWrapper { - fn run(self: Box, conn: &mut Connection) -> rusqlite::Result>; + fn run( + self: Box, + conn: &mut Connection, + ) -> Result, Box>; } impl ActionWrapper for T where - T::Result: Send + 'static, + T::Output: Send + 'static, + T::Error: Send + 'static, { - fn run(self: Box, conn: &mut Connection) -> rusqlite::Result> { - let result = (*self).run(conn)?; - Ok(Box::new(result)) + fn run( + self: Box, + conn: &mut Connection, + ) -> Result, Box> { + match (*self).run(conn) { + Ok(result) => Ok(Box::new(result)), + Err(err) => Err(Box::new(err)), + } } } @@ -29,46 +38,38 @@ where enum Command { Action( Box, - oneshot::Sender>>, + oneshot::Sender, Box>>, ), Stop(oneshot::Sender<()>), } /// Error that can occur during execution of an [`Action`]. #[derive(Debug)] -pub enum Error { +pub enum Error { /// The vault's thread has been stopped and its sqlite connection closed. Stopped, - /// A [`rusqlite::Error`] occurred while running the action. - Rusqlite(rusqlite::Error), + /// An error was returned by the [`Action`]. + Action(E), } -impl fmt::Display for Error { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Stopped => "vault has been stopped".fmt(f), - Self::Rusqlite(err) => err.fmt(f), + Self::Action(err) => err.fmt(f), } } } -impl error::Error for Error { +impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { Self::Stopped => None, - Self::Rusqlite(err) => Some(err), + Self::Action(err) => err.source(), } } } -impl From for Error { - fn from(value: rusqlite::Error) -> Self { - Self::Rusqlite(value) - } -} - -pub type Result = result::Result; - fn run(mut conn: Connection, mut rx: mpsc::UnboundedReceiver) { while let Some(command) = rx.blocking_recv() { match command { @@ -132,24 +133,34 @@ impl TokioVault { } /// Execute an [`Action`] and return the result. - pub async fn execute(&self, action: A) -> Result + pub async fn execute(&self, action: A) -> Result> where A: Action + Send + 'static, - A::Result: Send, + A::Output: Send, + A::Error: Send, { let (tx, rx) = oneshot::channel(); self.tx .send(Command::Action(Box::new(action), tx)) .map_err(|_| Error::Stopped)?; - let result = rx.await.map_err(|_| Error::Stopped)??; + let result = rx.await.map_err(|_| Error::Stopped)?; - // The ActionWrapper runs Action::run, which returns Action::Result. It - // then wraps this into Any, which we're now trying to downcast again to - // Action::Result. This should always work. - let result = result.downcast().unwrap(); - - Ok(*result) + // The ActionWrapper runs Action::run, which returns + // Result. It then wraps the + // Action::Result and Action::Error into Any, which we're now trying to + // downcast again to Action::Result and Action::Error. This should + // always work. + match result { + Ok(result) => { + let result = *result.downcast::().unwrap(); + Ok(result) + } + Err(err) => { + let err = *err.downcast::().unwrap(); + Err(Error::Action(err)) + } + } } /// Stop the vault's thread and close its sqlite connection.