diff --git a/interpreter/Cargo.toml b/interpreter/Cargo.toml index df13fb6..af89b05 100644 --- a/interpreter/Cargo.toml +++ b/interpreter/Cargo.toml @@ -11,7 +11,7 @@ categories = ["compilers"] [dependencies] cel-parser = { path = "../parser", version = "0.7.1 " } thiserror = "1.0.40" -chrono = { version = "0.4.26", default-features = false, features = ["alloc"] } +chrono = { version = "0.4.26", default-features = false, features = ["alloc", "serde"] } nom = "7.1.3" paste = "1.0.14" serde = "1.0.196" diff --git a/interpreter/src/ser.rs b/interpreter/src/ser.rs index ac92279..4c41fb9 100644 --- a/interpreter/src/ser.rs +++ b/interpreter/src/ser.rs @@ -6,7 +6,7 @@ use crate::{objects::Key, Value}; use chrono::FixedOffset; use serde::{ - ser::{self, Impossible}, + ser::{self, Impossible, SerializeStruct}, Serialize, }; use std::{collections::HashMap, fmt::Display, iter::FromIterator, sync::Arc}; @@ -17,10 +17,6 @@ pub struct KeySerializer; /// A wrapper Duration type which allows conversion to [Value::Duration] for /// types using automatic conversion with [serde::Serialize]. /// -/// It is only recommended to use this type with the cel_interpreter -/// [serde::Serializer] implementation, as it may produce unexpected output -/// with other Serializers. -/// /// # Examples /// /// ``` @@ -53,10 +49,11 @@ pub struct Duration(pub chrono::Duration); impl Duration { // Since serde can't natively represent durations, we serialize a special - // struct+field to indicate we want to rebuild the duration in the result. - const SECONDS_FIELD: &str = "$__cel_private_duration_secs"; - const NANOS_FIELD: &str = "$__cel_private_duration_nanos"; + // newtype to indicate we want to rebuild the duration in the result, while + // remaining compatible with most other Serializer implemenations. const NAME: &str = "$__cel_private_Duration"; + const SECS_FIELD: &str = "secs"; + const NANOS_FIELD: &str = "nanos"; } impl From for chrono::Duration { @@ -76,22 +73,28 @@ impl ser::Serialize for Duration { where S: ser::Serializer, { - use serde::ser::SerializeStruct; - - let mut s = serializer.serialize_struct(Self::NAME, 1)?; - s.serialize_field(Self::NANOS_FIELD, &self.0.subsec_nanos())?; - s.serialize_field(Self::SECONDS_FIELD, &self.0.num_seconds())?; - s.end() + // chrono::Duration's Serialize impl isn't stable yet and relies on + // private fields, so attempt to mimic serde's default impl for std + // Duration. + struct DurationProxy(chrono::Duration); + impl Serialize for DurationProxy { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + let mut s = serializer.serialize_struct("Duration", 2)?; + s.serialize_field(Duration::SECS_FIELD, &self.0.num_seconds())?; + s.serialize_field(Duration::NANOS_FIELD, &self.0.subsec_nanos())?; + s.end() + } + } + serializer.serialize_newtype_struct(Self::NAME, &DurationProxy(self.0)) } } /// A wrapper Timestamp type which allows conversion to [Value::Timestamp] for /// types using automatic conversion with [serde::Serialize]. /// -/// It is only recommended to use this type with the cel_interpreter -/// [serde::Serializer] implementation, as it may produce unexpected output -/// with other Serializers. -/// /// # Examples /// /// ``` @@ -125,9 +128,9 @@ impl ser::Serialize for Duration { pub struct Timestamp(pub chrono::DateTime); impl Timestamp { - // Since serde can't natively represent durations, we serialize a special - // struct+field to indicate we want to rebuild the duration in the result. - const FIELD: &str = "$__cel_private_timestamp"; + // Since serde can't natively represent timestamps, we serialize a special + // newtype to indicate we want to rebuild the timestamp in the result, + // while remaining compatible with most other Serializer implemenations. const NAME: &str = "$__cel_private_Timestamp"; } @@ -148,11 +151,7 @@ impl ser::Serialize for Timestamp { where S: ser::Serializer, { - use serde::ser::SerializeStruct; - - let mut s = serializer.serialize_struct(Self::NAME, 1)?; - s.serialize_field(Self::FIELD, &self.0.to_rfc3339())?; - s.end() + serializer.serialize_newtype_struct(Self::NAME, &self.0) } } @@ -285,11 +284,15 @@ impl ser::Serializer for Serializer { self.serialize_str(variant) } - fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result where T: ?Sized + Serialize, { - value.serialize(self) + match name { + Duration::NAME => value.serialize(TimeSerializer::Duration), + Timestamp::NAME => value.serialize(TimeSerializer::Timestamp), + _ => value.serialize(self), + } } fn serialize_newtype_variant( @@ -338,26 +341,13 @@ impl ser::Serializer for Serializer { fn serialize_map(self, _len: Option) -> Result { Ok(SerializeMap { - kind: SerializeMapKind::Map(HashMap::new()), + map: HashMap::new(), next_key: None, }) } - fn serialize_struct(self, name: &'static str, len: usize) -> Result { - match name { - Duration::NAME => Ok(SerializeMap { - kind: SerializeMapKind::Duration { - seconds: 0, - nanos: 0, - }, - next_key: None, - }), - Timestamp::NAME => Ok(SerializeMap { - kind: SerializeMapKind::Timestamp(Arc::new(String::new())), - next_key: None, - }), - _ => self.serialize_map(Some(len)), - } + fn serialize_struct(self, _name: &'static str, len: usize) -> Result { + self.serialize_map(Some(len)) } fn serialize_struct_variant( @@ -384,21 +374,21 @@ pub struct SerializeTupleVariant { } pub struct SerializeMap { - kind: SerializeMapKind, + map: HashMap, next_key: Option, } -enum SerializeMapKind { - Map(HashMap), - Duration { seconds: i64, nanos: i32 }, - Timestamp(Arc), -} - pub struct SerializeStructVariant { name: String, map: HashMap, } +#[derive(Debug, Default)] +struct SerializeTimestamp { + secs: i64, + nanos: i32, +} + impl ser::SerializeSeq for SerializeVec { type Ok = Value; type Error = SerializationError; @@ -482,10 +472,7 @@ impl ser::SerializeMap for SerializeMap { where T: ?Sized + Serialize, { - let SerializeMapKind::Map(map) = &mut self.kind else { - unreachable!(); - }; - map.insert( + self.map.insert( self.next_key.clone().ok_or_else(|| { SerializationError::InvalidKey( "serialize_value called before serialize_key".to_string(), @@ -497,16 +484,7 @@ impl ser::SerializeMap for SerializeMap { } fn end(self) -> Result { - match self.kind { - SerializeMapKind::Map(map) => Ok(map.into()), - SerializeMapKind::Duration { seconds, nanos } => Ok(chrono::Duration::seconds(seconds) - .checked_add(&chrono::Duration::nanoseconds(nanos.into())) - .unwrap() - .into()), - SerializeMapKind::Timestamp(raw) => Ok(chrono::DateTime::parse_from_rfc3339(&raw) - .map_err(|e| SerializationError::SerdeError(e.to_string()))? - .into()), - } + Ok(self.map.into()) } } @@ -518,50 +496,7 @@ impl ser::SerializeStruct for SerializeMap { where T: ?Sized + Serialize, { - match &mut self.kind { - SerializeMapKind::Map(_) => serde::ser::SerializeMap::serialize_entry(self, key, value), - SerializeMapKind::Duration { seconds, nanos } => match key { - Duration::SECONDS_FIELD => { - let Value::Int(val) = value.serialize(Serializer)? else { - return Err(SerializationError::SerdeError( - "invalid type of value in timestamp struct".to_owned(), - )); - }; - *seconds = val; - Ok(()) - } - Duration::NANOS_FIELD => { - let Value::Int(val) = value.serialize(Serializer)? else { - return Err(SerializationError::SerdeError( - "invalid type of value in timestamp struct".to_owned(), - )); - }; - *nanos = val.try_into().map_err(|_| { - SerializationError::SerdeError( - "timestamp struct nanos field is invalid".to_owned(), - ) - })?; - Ok(()) - } - _ => Err(SerializationError::SerdeError( - "invalid field in duration struct".to_owned(), - )), - }, - SerializeMapKind::Timestamp(raw) => { - if key != Timestamp::FIELD { - return Err(SerializationError::SerdeError( - "invalid field in timestamp struct".to_owned(), - )); - } - let Value::String(val) = value.serialize(Serializer)? else { - return Err(SerializationError::SerdeError( - "invalid type of value in timestamp struct".to_owned(), - )); - }; - *raw = val; - Ok(()) - } - } + serde::ser::SerializeMap::serialize_entry(self, key, value) } fn end(self) -> Result { @@ -588,6 +523,54 @@ impl ser::SerializeStructVariant for SerializeStructVariant { } } +impl ser::SerializeStruct for SerializeTimestamp { + type Ok = Value; + type Error = SerializationError; + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> std::result::Result<(), Self::Error> + where + T: Serialize, + { + match key { + Duration::SECS_FIELD => { + let Value::Int(val) = value.serialize(Serializer)? else { + return Err(SerializationError::SerdeError( + "invalid type of value in timestamp struct".to_owned(), + )); + }; + self.secs = val; + Ok(()) + } + Duration::NANOS_FIELD => { + let Value::Int(val) = value.serialize(Serializer)? else { + return Err(SerializationError::SerdeError( + "invalid type of value in timestamp struct".to_owned(), + )); + }; + self.nanos = val.try_into().map_err(|_| { + SerializationError::SerdeError( + "timestamp struct nanos field is invalid".to_owned(), + ) + })?; + Ok(()) + } + _ => Err(SerializationError::SerdeError( + "invalid field in duration struct".to_owned(), + )), + } + } + + fn end(self) -> std::result::Result { + Ok(chrono::Duration::seconds(self.secs) + .checked_add(&chrono::Duration::nanoseconds(self.nanos.into())) + .unwrap() + .into()) + } +} + impl ser::Serializer for KeySerializer { type Ok = Key; type Error = SerializationError; @@ -777,6 +760,181 @@ impl ser::Serializer for KeySerializer { } } +#[derive(Debug)] +enum TimeSerializer { + Duration, + Timestamp, +} + +impl ser::Serializer for TimeSerializer { + type Ok = Value; + type Error = SerializationError; + + type SerializeStruct = SerializeTimestamp; + + // Should never be used, so just reuse existing. + type SerializeSeq = SerializeVec; + type SerializeTuple = SerializeVec; + type SerializeTupleStruct = SerializeVec; + type SerializeTupleVariant = SerializeTupleVariant; + type SerializeMap = SerializeMap; + type SerializeStructVariant = SerializeStructVariant; + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + assert!(matches!(self, Self::Duration { .. })); + assert_eq!(name, "Duration"); + assert_eq!(len, 2); + Ok(SerializeTimestamp::default()) + } + + fn serialize_str(self, v: &str) -> Result { + assert!(matches!(self, Self::Timestamp)); + Ok(chrono::DateTime::parse_from_rfc3339(v) + .map_err(|e| SerializationError::SerdeError(e.to_string()))? + .into()) + } + + fn serialize_bool(self, _v: bool) -> Result { + unreachable!() + } + + fn serialize_i8(self, _v: i8) -> Result { + unreachable!() + } + + fn serialize_i16(self, _v: i16) -> Result { + unreachable!() + } + + fn serialize_i32(self, _v: i32) -> Result { + unreachable!() + } + + fn serialize_i64(self, _v: i64) -> Result { + unreachable!() + } + + fn serialize_u8(self, _v: u8) -> Result { + unreachable!() + } + + fn serialize_u16(self, _v: u16) -> Result { + unreachable!() + } + + fn serialize_u32(self, _v: u32) -> Result { + unreachable!() + } + + fn serialize_u64(self, _v: u64) -> Result { + unreachable!() + } + + fn serialize_f32(self, _v: f32) -> Result { + unreachable!() + } + + fn serialize_f64(self, _v: f64) -> Result { + unreachable!() + } + + fn serialize_char(self, _v: char) -> Result { + unreachable!() + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + unreachable!() + } + + fn serialize_none(self) -> Result { + unreachable!() + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + unreachable!() + } + + fn serialize_unit(self) -> Result { + unreachable!() + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + unreachable!() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unreachable!() + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + unreachable!() + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + unreachable!() + } + + fn serialize_seq(self, _len: Option) -> Result { + unreachable!() + } + + fn serialize_tuple(self, _len: usize) -> Result { + unreachable!() + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unreachable!() + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!() + } + + fn serialize_map(self, _len: Option) -> Result { + unreachable!() + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!() + } +} + #[cfg(test)] mod tests { use super::{Duration, Timestamp}; @@ -1062,4 +1220,29 @@ mod tests { .into(); assert_eq!(durations, expected); } + + #[test] + fn test_duration_json() { + // Test that Durations serialize correctly with serde_json. + let durations = [ + Duration(chrono::Duration::milliseconds(1527)), + // Let's test chrono::Duration's particular handling around math + // and negatives. + chrono::Duration::milliseconds(-1527).into(), + (chrono::Duration::seconds(1) - chrono::Duration::nanoseconds(1000000001)).into(), + (chrono::Duration::seconds(-1) + chrono::Duration::nanoseconds(1000000001)).into(), + ]; + let expect = r#"[{"secs":1,"nanos":527000000},{"secs":-1,"nanos":-527000000},{"secs":0,"nanos":-1},{"secs":0,"nanos":1}]"#; + let actual = serde_json::to_string(&durations).unwrap(); + assert_eq!(actual, expect); + } + + #[test] + fn test_timestamp_json() { + // Test that Durations serialize correctly with serde_json. + let timestamp = chrono::DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z").unwrap(); + let expect = r#""2025-01-01T00:00:00Z""#; + let actual = serde_json::to_string(×tamp).unwrap(); + assert_eq!(actual, expect); + } }