From 8b68f2cb262aee247b332bbc67aebe0218060d28 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Mon, 4 Jul 2022 10:38:58 -0700 Subject: [PATCH] Borrow field identifiers from deserializer, instead of input --- src/serdes.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++- src/serdes/array.rs | 79 ++++++++++++++++++-------------------- src/serdes/slice.rs | 56 +++++++++------------------ src/serdes/utils.rs | 43 ++++++++++++++++----- 4 files changed, 182 insertions(+), 89 deletions(-) diff --git a/src/serdes.rs b/src/serdes.rs index dc454c09..5e35ff64 100644 --- a/src/serdes.rs +++ b/src/serdes.rs @@ -5,8 +5,21 @@ mod array; mod slice; mod utils; -/// A list of fields in the `BitSeq` and `BitArr` transport format. -static FIELDS: &[&str] = &["order", "head", "bits", "data"]; +use core::{ + any, + fmt::{ + self, + Formatter, + }, + marker::PhantomData, +}; + +use serde::de::{ + Deserialize, + Deserializer, + Unexpected, + Visitor, +}; /// A result of serialization. type Result = core::result::Result< @@ -14,6 +27,82 @@ type Result = core::result::Result< ::Error, >; +/// A list of fields in the `BitSeq` and `BitArr` transport format. +static FIELDS: &[&str] = &["order", "head", "bits", "data"]; + +enum Field { + Order, + Head, + Bits, + Data, +} + +struct FieldVisitor; + +impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> core::result::Result + where D: Deserializer<'de> { + deserializer.deserialize_identifier(FieldVisitor) + } +} + +impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, fmt: &mut Formatter) -> fmt::Result { + fmt.write_str("field_identifier") + } + + fn visit_str(self, value: &str) -> core::result::Result + where E: serde::de::Error { + match value { + "order" => Ok(Field::Order), + "head" => Ok(Field::Head), + "bits" => Ok(Field::Bits), + "data" => Ok(Field::Data), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } +} + +/// A zero-sized type that deserializes from any string as long as it is equal +/// to `any::type_name::()`. +struct TypeName(PhantomData); + +impl TypeName { + fn new() -> Self { + TypeName(PhantomData) + } +} + +impl<'de, O> Deserialize<'de> for TypeName { + fn deserialize(deserializer: D) -> core::result::Result + where D: Deserializer<'de> { + deserializer.deserialize_str(Self::new()) + } +} + +impl<'de, O> Visitor<'de> for TypeName { + type Value = Self; + + fn expecting(&self, fmt: &mut Formatter) -> fmt::Result { + write!(fmt, "the string {:?}", any::type_name::()) + } + + fn visit_str(self, value: &str) -> core::result::Result + where E: serde::de::Error { + if value == any::type_name::() { + Ok(self) + } + else { + Err(serde::de::Error::invalid_value( + Unexpected::Str(value), + &self, + )) + } + } +} + #[cfg(test)] mod tests { use serde::{ diff --git a/src/serdes/array.rs b/src/serdes/array.rs index 666ff325..f06cfe21 100644 --- a/src/serdes/array.rs +++ b/src/serdes/array.rs @@ -6,7 +6,6 @@ use core::{ self, Formatter, }, - marker::PhantomData, }; use serde::{ @@ -28,6 +27,8 @@ use serde::{ use super::{ utils::Array, + Field, + TypeName, FIELDS, }; use crate::{ @@ -88,11 +89,7 @@ where fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { deserializer - .deserialize_struct( - "BitArr", - FIELDS, - BitArrVisitor::<'de, T, O, 1>::THIS, - ) + .deserialize_struct("BitArr", FIELDS, BitArrVisitor::::THIS) .map(|BitArray { data: [elem], .. }| BitArray::new(elem)) } } @@ -108,22 +105,19 @@ where deserializer.deserialize_struct( "BitArr", FIELDS, - BitArrVisitor::<'de, T, O, N>::THIS, + BitArrVisitor::::THIS, ) } } /// Assists in deserialization of a static `BitArr`. -struct BitArrVisitor<'de, T, O, const N: usize> +struct BitArrVisitor where T: BitStore, O: BitOrder, - Array: Deserialize<'de>, { - /// This produces a bit-array value during its work. - typ: PhantomData>, /// The deserialized bit-ordering string. - order: Option<&'de str>, + order: Option>, /// The deserialized head-bit index. This must be zero; it is used for /// consistency with `BitSeq` and to carry `T::Mem` information. head: Option>, @@ -133,7 +127,7 @@ where data: Option>, } -impl<'de, T, O, const N: usize> BitArrVisitor<'de, T, O, N> +impl<'de, T, O, const N: usize> BitArrVisitor where T: BitStore, O: BitOrder, @@ -141,7 +135,6 @@ where { /// A new visitor in its ready condition. const THIS: Self = Self { - typ: PhantomData, order: None, head: None, bits: None, @@ -151,15 +144,11 @@ where /// Attempts to assemble deserialized components into an output value. fn assemble(mut self) -> Result, E> where E: Error { - let order = - self.order.take().ok_or_else(|| E::missing_field("order"))?; + self.order.take().ok_or_else(|| E::missing_field("order"))?; let head = self.head.take().ok_or_else(|| E::missing_field("head"))?; let bits = self.bits.take().ok_or_else(|| E::missing_field("bits"))?; let data = self.data.take().ok_or_else(|| E::missing_field("data"))?; - if order != any::type_name::() { - return Err(E::invalid_type(Unexpected::Str(order), &self)); - } if head != BitIdx::MIN { return Err(E::invalid_value( Unexpected::Unsigned(head.into_inner() as u64), @@ -175,7 +164,7 @@ where } } -impl<'de, T, O, const N: usize> Visitor<'de> for BitArrVisitor<'de, T, O, N> +impl<'de, T, O, const N: usize> Visitor<'de> for BitArrVisitor where T: BitStore, O: BitOrder, @@ -217,32 +206,28 @@ where fn visit_map(mut self, mut map: V) -> Result where V: MapAccess<'de> { - while let Some(key) = map.next_key::<&'de str>()? { + while let Some(key) = map.next_key()? { match key { - "order" => { + Field::Order => { if self.order.replace(map.next_value()?).is_some() { return Err(::duplicate_field("order")); } }, - "head" => { + Field::Head => { if self.head.replace(map.next_value()?).is_some() { return Err(::duplicate_field("head")); } }, - "bits" => { + Field::Bits => { if self.bits.replace(map.next_value()?).is_some() { return Err(::duplicate_field("bits")); } }, - "data" => { + Field::Data => { if self.data.replace(map.next_value()?).is_some() { return Err(::duplicate_field("data")); } }, - f => { - let _ = map.next_value::<()>(); - return Err(::unknown_field(f, FIELDS)); - }, } } @@ -279,6 +264,10 @@ mod tests { let array3 = serde_json::from_str::(&json)?; assert_eq!(array, array3); + let json_value = serde_json::to_value(&array)?; + let array4 = serde_json::from_value::(json_value)?; + assert_eq!(array, array4); + type BA2 = BitArray; let array = BA2::new(44203); @@ -290,6 +279,10 @@ mod tests { let array3 = serde_json::from_str::(&json)?; assert_eq!(array, array3); + let json_value = serde_json::to_value(&array)?; + let array4 = serde_json::from_value::(json_value)?; + assert_eq!(array, array4); + Ok(()) } @@ -341,9 +334,21 @@ mod tests { #[cfg(feature = "alloc")] fn errors() { type BA = BitArr!(for 8, in u8, Msb0); - let tokens = &mut [ + let mut tokens = vec![ Token::Seq { len: Some(4) }, Token::BorrowedStr(any::type_name::()), + ]; + + assert_de_tokens_error::( + &tokens, + &format!( + "invalid value: string \"{}\", expected the string \"{}\"", + any::type_name::(), + any::type_name::(), + ), + ); + + tokens.extend([ Token::Seq { len: Some(2) }, Token::U8(8), Token::U8(0), @@ -353,24 +358,18 @@ mod tests { Token::U8(0), Token::TupleEnd, Token::SeqEnd, - ]; - - assert_de_tokens_error::( - tokens, - "invalid type: string \"bitvec::order::Msb0\", expected a \ - `BitArray<[u8; 1], bitvec::order::Lsb0>`", - ); + ]); tokens[6] = Token::U64(7); assert_de_tokens_error::( - tokens, + &tokens, "invalid length 7, expected a `BitArray<[u8; 1], \ bitvec::order::Msb0>`", ); tokens[4] = Token::U8(1); assert_de_tokens_error::( - tokens, + &tokens, "invalid value: integer `1`, expected `BitArray` must have a \ head-bit of `0`", ); @@ -382,8 +381,6 @@ mod tests { len: 2, }, Token::BorrowedStr("placeholder"), - Token::Unit, - Token::StructEnd, ], &format!( "unknown field `placeholder`, expected one of `{}`", diff --git a/src/serdes/slice.rs b/src/serdes/slice.rs index c97ee1ff..395f1095 100644 --- a/src/serdes/slice.rs +++ b/src/serdes/slice.rs @@ -18,7 +18,6 @@ use serde::{ Error, MapAccess, SeqAccess, - Unexpected, Visitor, }, ser::{ @@ -29,7 +28,11 @@ use serde::{ }; use wyz::comu::Const; -use super::FIELDS; +use super::{ + Field, + TypeName, + FIELDS, +}; #[cfg(feature = "alloc")] use crate::{ boxed::BitBox, @@ -102,7 +105,7 @@ where O: BitOrder deserializer.deserialize_struct( "BitSeq", FIELDS, - BitSeqVisitor::<'de, u8, O, &'de [u8], Self, _>::new( + BitSeqVisitor::::new( |data, head, bits| unsafe { BitSpan::new(data.as_ptr().into_address(), head, bits) .map(|span| BitSpan::into_bitslice_ref(span)) @@ -138,7 +141,7 @@ where deserializer.deserialize_struct( "BitSeq", FIELDS, - BitSeqVisitor::<'de, T, O, Vec, Self, _>::new( + BitSeqVisitor::, Self, _>::new( |vec, head, bits| unsafe { let addr = vec.as_ptr().into_address(); let mut bv = BitVec::try_from_vec(vec).map_err(|_| { @@ -155,19 +158,16 @@ where } /// Assists in deserialization of a dynamic `BitSeq`. -struct BitSeqVisitor<'de, T, O, In, Out, Func> +struct BitSeqVisitor where - T: 'de + BitStore, + T: BitStore, O: BitOrder, - In: Deserialize<'de>, Func: FnOnce(In, BitIdx, usize) -> Result>, { - /// This produces a bit-slice reference during its work, - typ: PhantomData<&'de BitSlice>, /// As well as a final output value. out: PhantomData>>, /// The deserialized bit-ordering string. - order: Option<&'de str>, + order: Option>, /// The deserialized head-bit index. head: Option>, /// The deserialized bit-count. @@ -179,7 +179,7 @@ where func: Func, } -impl<'de, T, O, In, Out, Func> BitSeqVisitor<'de, T, O, In, Out, Func> +impl<'de, T, O, In, Out, Func> BitSeqVisitor where T: 'de + BitStore, O: BitOrder, @@ -189,7 +189,6 @@ where /// Creates a new visitor with a given transform functor. fn new(func: Func) -> Self { Self { - typ: PhantomData, out: PhantomData, order: None, head: None, @@ -202,21 +201,17 @@ where /// Attempts to assemble deserialized components into an output value. fn assemble(mut self) -> Result where E: Error { - let order = - self.order.take().ok_or_else(|| E::missing_field("order"))?; + self.order.take().ok_or_else(|| E::missing_field("order"))?; let head = self.head.take().ok_or_else(|| E::missing_field("head"))?; let bits = self.bits.take().ok_or_else(|| E::missing_field("bits"))?; let data = self.data.take().ok_or_else(|| E::missing_field("data"))?; - if order != any::type_name::() { - return Err(E::invalid_type(Unexpected::Str(order), &self)); - } (self.func)(data, head, bits as usize).map_err(|_| todo!()) } } impl<'de, T, O, In, Out, Func> Visitor<'de> - for BitSeqVisitor<'de, T, O, In, Out, Func> + for BitSeqVisitor where T: 'de + BitStore, O: BitOrder, @@ -258,32 +253,28 @@ where fn visit_map(mut self, mut map: V) -> Result where V: MapAccess<'de> { - while let Some(key) = map.next_key::<&'de str>()? { + while let Some(key) = map.next_key()? { match key { - "order" => { + Field::Order => { if self.order.replace(map.next_value()?).is_some() { return Err(::duplicate_field("order")); } }, - "head" => { + Field::Head => { if self.head.replace(map.next_value()?).is_some() { return Err(::duplicate_field("head")); } }, - "bits" => { + Field::Bits => { if self.bits.replace(map.next_value()?).is_some() { return Err(::duplicate_field("bits")); } }, - "data" => { + Field::Data => { if self.data.replace(map.next_value()?).is_some() { return Err(::duplicate_field("data")); } }, - f => { - let _ = map.next_value::<()>(); - return Err(::unknown_field(f, FIELDS)); - }, } } @@ -370,16 +361,9 @@ mod tests { &[ Token::Seq { len: Some(4) }, Token::BorrowedStr(any::type_name::()), - Token::Seq { len: Some(2) }, - Token::U8(8), - Token::U8(1), - Token::SeqEnd, - Token::U64(9), - Token::BorrowedBytes(&[0x3C, 0xA5]), - Token::SeqEnd, ], &format!( - "invalid type: string \"{}\", expected a `BitSlice`", + "invalid value: string \"{}\", expected the string \"{}\"", any::type_name::(), any::type_name::(), ), @@ -392,8 +376,6 @@ mod tests { len: 1, }, Token::BorrowedStr("unknown"), - Token::BorrowedStr("field"), - Token::StructEnd, ], &format!( "unknown field `unknown`, expected one of `{}`", diff --git a/src/serdes/utils.rs b/src/serdes/utils.rs index d63cd7fa..a947ab15 100644 --- a/src/serdes/utils.rs +++ b/src/serdes/utils.rs @@ -45,6 +45,37 @@ use crate::{ /// Fields used in the `BitIdx` transport format. static FIELDS: &[&str] = &["width", "index"]; +enum Field { + Width, + Index, +} + +struct FieldVisitor; + +impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + deserializer.deserialize_identifier(FieldVisitor) + } +} + +impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, fmt: &mut Formatter) -> fmt::Result { + fmt.write_str("field identifier") + } + + fn visit_str(self, value: &str) -> Result + where E: serde::de::Error { + match value { + "width" => Ok(Field::Width), + "index" => Ok(Field::Index), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } +} + impl Serialize for BitIdx where R: BitRegister { @@ -246,22 +277,18 @@ where R: BitRegister let mut width = None; let mut index = None; - while let Some(key) = map.next_key::<&'de str>()? { + while let Some(key) = map.next_key()? { match key { - "width" => { + Field::Width => { if width.replace(map.next_value::()?).is_some() { return Err(::duplicate_field("width")); } }, - "index" => { + Field::Index => { if index.replace(map.next_value::()?).is_some() { return Err(::duplicate_field("index")); } }, - f => { - let _ = map.next_value::<()>(); - return Err(::unknown_field(f, FIELDS)); - }, } } @@ -362,8 +389,6 @@ mod tests { len: 1, }, Token::BorrowedStr("unknown"), - Token::BorrowedStr("field"), - Token::StructEnd, ], "unknown field `unknown`, expected `width` or `index`", );