Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Borrow field identifiers from deserializer, instead of input #185

Merged
merged 1 commit into from
Jul 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions src/serdes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,104 @@ mod array;
mod slice;
mod utils;

/// A list of fields in the `BitSeq` and `BitArr` transport format.
static FIELDS: &[&str] = &["order", "head", "bits", "data"];
use core::{
any,
fmt::{
self,
Formatter,
},
marker::PhantomData,
};

use serde::de::{
Deserialize,
Deserializer,
Unexpected,
Visitor,
};

/// A result of serialization.
type Result<S> = core::result::Result<
<S as serde::Serializer>::Ok,
<S as serde::Serializer>::Error,
>;

/// A list of fields in the `BitSeq` and `BitArr` transport format.
static FIELDS: &[&str] = &["order", "head", "bits", "data"];

enum Field {
Order,
Head,
Bits,
Data,
}

struct FieldVisitor;

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
where D: Deserializer<'de> {
deserializer.deserialize_identifier(FieldVisitor)
}
}

impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
fmt.write_str("field_identifier")
}

fn visit_str<E>(self, value: &str) -> core::result::Result<Self::Value, E>
where E: serde::de::Error {
match value {
"order" => Ok(Field::Order),
"head" => Ok(Field::Head),
"bits" => Ok(Field::Bits),
"data" => Ok(Field::Data),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}

/// A zero-sized type that deserializes from any string as long as it is equal
/// to `any::type_name::<O>()`.
struct TypeName<O>(PhantomData<O>);

impl<O> TypeName<O> {
fn new() -> Self {
TypeName(PhantomData)
}
}

impl<'de, O> Deserialize<'de> for TypeName<O> {
fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
where D: Deserializer<'de> {
deserializer.deserialize_str(Self::new())
}
}

impl<'de, O> Visitor<'de> for TypeName<O> {
type Value = Self;

fn expecting(&self, fmt: &mut Formatter) -> fmt::Result {
write!(fmt, "the string {:?}", any::type_name::<O>())
}

fn visit_str<E>(self, value: &str) -> core::result::Result<Self::Value, E>
where E: serde::de::Error {
if value == any::type_name::<O>() {
Ok(self)
}
else {
Err(serde::de::Error::invalid_value(
Unexpected::Str(value),
&self,
))
}
}
}

#[cfg(test)]
mod tests {
use serde::{
Expand Down
79 changes: 38 additions & 41 deletions src/serdes/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use core::{
self,
Formatter,
},
marker::PhantomData,
};

use serde::{
Expand All @@ -28,6 +27,8 @@ use serde::{

use super::{
utils::Array,
Field,
TypeName,
FIELDS,
};
use crate::{
Expand Down Expand Up @@ -88,11 +89,7 @@ where
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> {
deserializer
.deserialize_struct(
"BitArr",
FIELDS,
BitArrVisitor::<'de, T, O, 1>::THIS,
)
.deserialize_struct("BitArr", FIELDS, BitArrVisitor::<T, O, 1>::THIS)
.map(|BitArray { data: [elem], .. }| BitArray::new(elem))
}
}
Expand All @@ -108,22 +105,19 @@ where
deserializer.deserialize_struct(
"BitArr",
FIELDS,
BitArrVisitor::<'de, T, O, N>::THIS,
BitArrVisitor::<T, O, N>::THIS,
)
}
}

/// Assists in deserialization of a static `BitArr`.
struct BitArrVisitor<'de, T, O, const N: usize>
struct BitArrVisitor<T, O, const N: usize>
where
T: BitStore,
O: BitOrder,
Array<T, N>: Deserialize<'de>,
{
/// This produces a bit-array value during its work.
typ: PhantomData<BitArray<T, O>>,
/// The deserialized bit-ordering string.
order: Option<&'de str>,
order: Option<TypeName<O>>,
/// The deserialized head-bit index. This must be zero; it is used for
/// consistency with `BitSeq` and to carry `T::Mem` information.
head: Option<BitIdx<T::Mem>>,
Expand All @@ -133,15 +127,14 @@ where
data: Option<Array<T, N>>,
}

impl<'de, T, O, const N: usize> BitArrVisitor<'de, T, O, N>
impl<'de, T, O, const N: usize> BitArrVisitor<T, O, N>
where
T: BitStore,
O: BitOrder,
Array<T, N>: Deserialize<'de>,
{
/// A new visitor in its ready condition.
const THIS: Self = Self {
typ: PhantomData,
order: None,
head: None,
bits: None,
Expand All @@ -151,15 +144,11 @@ where
/// Attempts to assemble deserialized components into an output value.
fn assemble<E>(mut self) -> Result<BitArray<[T; N], O>, E>
where E: Error {
let order =
self.order.take().ok_or_else(|| E::missing_field("order"))?;
self.order.take().ok_or_else(|| E::missing_field("order"))?;
let head = self.head.take().ok_or_else(|| E::missing_field("head"))?;
let bits = self.bits.take().ok_or_else(|| E::missing_field("bits"))?;
let data = self.data.take().ok_or_else(|| E::missing_field("data"))?;

if order != any::type_name::<O>() {
return Err(E::invalid_type(Unexpected::Str(order), &self));
}
if head != BitIdx::MIN {
return Err(E::invalid_value(
Unexpected::Unsigned(head.into_inner() as u64),
Expand All @@ -175,7 +164,7 @@ where
}
}

impl<'de, T, O, const N: usize> Visitor<'de> for BitArrVisitor<'de, T, O, N>
impl<'de, T, O, const N: usize> Visitor<'de> for BitArrVisitor<T, O, N>
where
T: BitStore,
O: BitOrder,
Expand Down Expand Up @@ -217,32 +206,28 @@ where

fn visit_map<V>(mut self, mut map: V) -> Result<Self::Value, V::Error>
where V: MapAccess<'de> {
while let Some(key) = map.next_key::<&'de str>()? {
while let Some(key) = map.next_key()? {
match key {
"order" => {
Field::Order => {
if self.order.replace(map.next_value()?).is_some() {
return Err(<V::Error>::duplicate_field("order"));
}
},
"head" => {
Field::Head => {
if self.head.replace(map.next_value()?).is_some() {
return Err(<V::Error>::duplicate_field("head"));
}
},
"bits" => {
Field::Bits => {
if self.bits.replace(map.next_value()?).is_some() {
return Err(<V::Error>::duplicate_field("bits"));
}
},
"data" => {
Field::Data => {
if self.data.replace(map.next_value()?).is_some() {
return Err(<V::Error>::duplicate_field("data"));
}
},
f => {
let _ = map.next_value::<()>();
return Err(<V::Error>::unknown_field(f, FIELDS));
},
}
}

Expand Down Expand Up @@ -279,6 +264,10 @@ mod tests {
let array3 = serde_json::from_str::<BA>(&json)?;
assert_eq!(array, array3);

let json_value = serde_json::to_value(&array)?;
let array4 = serde_json::from_value::<BA>(json_value)?;
assert_eq!(array, array4);

type BA2 = BitArray<u16, Msb0>;
let array = BA2::new(44203);

Expand All @@ -290,6 +279,10 @@ mod tests {
let array3 = serde_json::from_str::<BA2>(&json)?;
assert_eq!(array, array3);

let json_value = serde_json::to_value(&array)?;
let array4 = serde_json::from_value::<BA2>(json_value)?;
assert_eq!(array, array4);

Ok(())
}

Expand Down Expand Up @@ -341,9 +334,21 @@ mod tests {
#[cfg(feature = "alloc")]
fn errors() {
type BA = BitArr!(for 8, in u8, Msb0);
let tokens = &mut [
let mut tokens = vec![
Token::Seq { len: Some(4) },
Token::BorrowedStr(any::type_name::<Msb0>()),
];

assert_de_tokens_error::<BitArr!(for 8, in u8, Lsb0)>(
&tokens,
&format!(
"invalid value: string \"{}\", expected the string \"{}\"",
any::type_name::<Msb0>(),
any::type_name::<Lsb0>(),
),
);

tokens.extend([
Token::Seq { len: Some(2) },
Token::U8(8),
Token::U8(0),
Expand All @@ -353,24 +358,18 @@ mod tests {
Token::U8(0),
Token::TupleEnd,
Token::SeqEnd,
];

assert_de_tokens_error::<BitArr!(for 8, in u8, Lsb0)>(
tokens,
"invalid type: string \"bitvec::order::Msb0\", expected a \
`BitArray<[u8; 1], bitvec::order::Lsb0>`",
);
]);

tokens[6] = Token::U64(7);
assert_de_tokens_error::<BA>(
tokens,
&tokens,
"invalid length 7, expected a `BitArray<[u8; 1], \
bitvec::order::Msb0>`",
);

tokens[4] = Token::U8(1);
assert_de_tokens_error::<BA>(
tokens,
&tokens,
"invalid value: integer `1`, expected `BitArray` must have a \
head-bit of `0`",
);
Expand All @@ -382,8 +381,6 @@ mod tests {
len: 2,
},
Token::BorrowedStr("placeholder"),
Token::Unit,
Token::StructEnd,
],
&format!(
"unknown field `placeholder`, expected one of `{}`",
Expand Down
Loading