diff --git a/CHANGELOG.md b/CHANGELOG.md index bf33087..b73cf50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Procedure when bumping the version number: ### Added - `serde` feature - `serde::from_row_via_index` +- `serde::from_row_via_name` ### Changed - **(breaking)** diff --git a/src/serde.rs b/src/serde.rs index a7eb5c1..d58640d 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -15,6 +15,7 @@ use serde::{ #[derive(Debug)] enum Error { ExpectedTupleLikeBaseType, + ExpectedStructLikeBaseType, Utf8(Utf8Error), Rusqlite(rusqlite::Error), Custom(String), @@ -24,6 +25,7 @@ impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::ExpectedTupleLikeBaseType => write!(f, "expected tuple-like base type"), + Self::ExpectedStructLikeBaseType => write!(f, "expected struct-like base type"), Self::Utf8(err) => err.fmt(f), Self::Rusqlite(err) => err.fmt(f), Self::Custom(msg) => msg.fmt(f), @@ -245,3 +247,79 @@ where T::deserialize(IndexedRowDeserializer { row }) .map_err(|err| FromSqlError::Other(Box::new(err)).into()) } + +struct NamedRowDeserializer<'de, 'stmt> { + row: &'de Row<'stmt>, +} + +impl<'de> Deserializer<'de> for NamedRowDeserializer<'de, '_> { + type Error = Error; + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct + map enum identifier ignored_any + } + + fn deserialize_any>(self, _visitor: V) -> Result { + Err(Error::ExpectedStructLikeBaseType) + } + + fn deserialize_struct>( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result { + visitor.visit_map(NamedRowMap::new(self.row, fields)) + } +} + +struct NamedRowMap<'de, 'stmt> { + row: &'de Row<'stmt>, + fields: &'static [&'static str], + next_index: usize, +} + +impl<'de, 'stmt> NamedRowMap<'de, 'stmt> { + fn new(row: &'de Row<'stmt>, fields: &'static [&'static str]) -> Self { + Self { + row, + fields, + next_index: 0, + } + } +} + +impl<'de> MapAccess<'de> for NamedRowMap<'de, '_> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if let Some(key) = self.fields.get(self.next_index) { + self.next_index += 1; + seed.deserialize(BorrowedStrDeserializer::new(key)) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + let value = self.row.get_ref(self.next_index - 1)?; + seed.deserialize(ValueRefDeserializer { value }) + } +} + +pub fn from_row_via_name<'de, T>(row: &'de Row<'_>) -> rusqlite::Result +where + T: Deserialize<'de>, +{ + T::deserialize(NamedRowDeserializer { row }) + .map_err(|err| FromSqlError::Other(Box::new(err)).into()) +}