diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b59147..bf33087 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Procedure when bumping the version number: ### Added - `serde` feature +- `serde::from_row_via_index` ### Changed - **(breaking)** diff --git a/src/lib.rs b/src/lib.rs index bf7ecf9..edc1392 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,12 +9,17 @@ // Clippy lints #![warn(clippy::use_self)] +#[cfg(feature = "serde")] +pub mod serde; pub mod simple; #[cfg(feature = "tokio")] pub mod tokio; use rusqlite::{Connection, Transaction}; +#[cfg(feature = "serde")] +pub use self::serde::*; + /// An action that can be performed on a [`Connection`]. /// /// Both commands and queries are considered actions. Commands usually have a diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..a7eb5c1 --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,247 @@ +use std::{error, fmt, str::Utf8Error}; + +use rusqlite::{ + types::{FromSqlError, ValueRef}, + Row, +}; +use serde::{ + de::{ + self, value::BorrowedStrDeserializer, DeserializeSeed, Deserializer, MapAccess, SeqAccess, + Visitor, + }, + forward_to_deserialize_any, Deserialize, +}; + +#[derive(Debug)] +enum Error { + ExpectedTupleLikeBaseType, + Utf8(Utf8Error), + Rusqlite(rusqlite::Error), + Custom(String), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ExpectedTupleLikeBaseType => write!(f, "expected tuple-like base type"), + Self::Utf8(err) => err.fmt(f), + Self::Rusqlite(err) => err.fmt(f), + Self::Custom(msg) => msg.fmt(f), + } + } +} + +impl error::Error for Error {} + +impl de::Error for Error { + fn custom(msg: T) -> Self { + Self::Custom(msg.to_string()) + } +} + +impl From for Error { + fn from(value: Utf8Error) -> Self { + Self::Utf8(value) + } +} + +impl From for Error { + fn from(value: rusqlite::Error) -> Self { + Self::Rusqlite(value) + } +} + +struct ValueRefDeserializer<'de> { + value: ValueRef<'de>, +} + +impl<'de> Deserializer<'de> for ValueRefDeserializer<'de> { + type Error = Error; + + forward_to_deserialize_any! { + i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf + unit unit_struct seq tuple tuple_struct map struct identifier + ignored_any + } + + fn deserialize_any>(self, visitor: V) -> Result { + match self.value { + ValueRef::Null => visitor.visit_unit(), + ValueRef::Integer(v) => visitor.visit_i64(v), + ValueRef::Real(v) => visitor.visit_f64(v), + ValueRef::Text(v) => visitor.visit_borrowed_str(std::str::from_utf8(v)?), + ValueRef::Blob(v) => visitor.visit_borrowed_bytes(v), + } + } + + fn deserialize_bool>(self, visitor: V) -> Result { + match self.value { + ValueRef::Integer(0) => visitor.visit_bool(false), + ValueRef::Integer(_) => visitor.visit_bool(true), + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_option>(self, visitor: V) -> Result { + match self.value { + ValueRef::Null => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + fn deserialize_newtype_struct>( + self, + _name: &'static str, + visitor: V, + ) -> Result { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum>( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result { + match self.value { + ValueRef::Text(v) => { + let v = BorrowedStrDeserializer::new(std::str::from_utf8(v)?); + v.deserialize_enum(name, variants, visitor) + } + _ => self.deserialize_any(visitor), + } + } +} + +struct IndexedRowDeserializer<'de, 'stmt> { + row: &'de Row<'stmt>, +} + +impl<'de> Deserializer<'de> for IndexedRowDeserializer<'de, '_> { + type Error = Error; + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct map enum identifier ignored_any + } + + fn deserialize_any>(self, _visitor: V) -> Result { + Err(Error::ExpectedTupleLikeBaseType) + } + + fn deserialize_newtype_struct>( + self, + _name: &'static str, + visitor: V, + ) -> Result { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq>(self, visitor: V) -> Result { + visitor.visit_seq(IndexedRowSeq::new(self.row)) + } + + fn deserialize_tuple>( + self, + _len: usize, + visitor: V, + ) -> Result { + self.deserialize_seq(visitor) + } + + fn deserialize_tuple_struct>( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result { + self.deserialize_seq(visitor) + } + + fn deserialize_struct>( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result { + visitor.visit_map(IndexedRowMap::new(self.row, fields)) + } +} + +struct IndexedRowSeq<'de, 'stmt> { + row: &'de Row<'stmt>, + next_index: usize, +} + +impl<'de, 'stmt> IndexedRowSeq<'de, 'stmt> { + fn new(row: &'de Row<'stmt>) -> Self { + Self { row, next_index: 0 } + } +} + +impl<'de> SeqAccess<'de> for IndexedRowSeq<'de, '_> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.row.get_ref(self.next_index) { + Ok(value) => { + self.next_index += 1; + seed.deserialize(ValueRefDeserializer { value }).map(Some) + } + Err(rusqlite::Error::InvalidColumnIndex(_)) => Ok(None), + Err(err) => Err(err)?, + } + } +} + +struct IndexedRowMap<'de, 'stmt> { + row: &'de Row<'stmt>, + fields: &'static [&'static str], + next_index: usize, +} + +impl<'de, 'stmt> IndexedRowMap<'de, 'stmt> { + fn new(row: &'de Row<'stmt>, fields: &'static [&'static str]) -> Self { + Self { + row, + fields, + next_index: 0, + } + } +} + +impl<'de> MapAccess<'de> for IndexedRowMap<'de, '_> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if let Some(key) = self.fields.get(self.next_index) { + self.next_index += 1; + seed.deserialize(BorrowedStrDeserializer::new(key)) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + let value = self.row.get_ref(self.next_index - 1)?; + seed.deserialize(ValueRefDeserializer { value }) + } +} + +pub fn from_row_via_index<'de, T>(row: &'de Row<'_>) -> rusqlite::Result +where + T: Deserialize<'de>, +{ + T::deserialize(IndexedRowDeserializer { row }) + .map_err(|err| FromSqlError::Other(Box::new(err)).into()) +}