diff --git a/src/api/types.rs b/src/api/types.rs index f56f0a1..1548d42 100644 --- a/src/api/types.rs +++ b/src/api/types.rs @@ -7,6 +7,8 @@ #![allow(clippy::use_self)] use std::fmt; +use std::num::ParseIntError; +use std::str::FromStr; use serde::{de, ser, Deserialize, Serialize}; use serde_json::Value; @@ -308,8 +310,8 @@ impl Snowflake { pub const MAX: Self = Snowflake(i64::MAX as u64); } -impl Serialize for Snowflake { - fn serialize(&self, serializer: S) -> Result { +impl fmt::Display for Snowflake { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Convert u64 to base36 string let mut n = self.0; let mut result = String::with_capacity(13); @@ -318,7 +320,34 @@ impl Serialize for Snowflake { result.insert(0, c); n /= 36; } - result.serialize(serializer) + f.write_str(&result) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ParseSnowflakeError { + #[error("invalid length: expected 13 bytes, got {0}")] + InvalidLength(usize), + #[error("{0}")] + ParseIntError(#[from] ParseIntError), +} + +impl FromStr for Snowflake { + type Err = ParseSnowflakeError; + + fn from_str(s: &str) -> Result { + // Convert base36 string to u64 + if s.len() != 13 { + return Err(ParseSnowflakeError::InvalidLength(s.len())); + } + let n = u64::from_str_radix(s, 36)?; + Ok(Snowflake(n)) + } +} + +impl Serialize for Snowflake { + fn serialize(&self, serializer: S) -> Result { + format!("{}", self).serialize(serializer) } } @@ -332,13 +361,12 @@ impl de::Visitor<'_> for SnowflakeVisitor { } fn visit_str(self, v: &str) -> Result { - // Convert base36 string to u64 - if v.len() != 13 { - return Err(E::invalid_length(v.len(), &self)); - } - let n = u64::from_str_radix(v, 36) - .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &self))?; - Ok(Snowflake(n)) + v.parse().map_err(|e| match e { + ParseSnowflakeError::InvalidLength(len) => E::invalid_length(len, &self), + ParseSnowflakeError::ParseIntError(_) => { + E::invalid_value(de::Unexpected::Str(v), &self) + } + }) } }