Skip to content

Commit

Permalink
feat!: remove strict length check on tuple deserialization (#35)
Browse files Browse the repository at this point in the history
This should be handled by #[serde(deny_unknown_fields)] so it can also be used
non-strict (for custom visitors/deserializers). Unfortunately this isn't
supported upstream in serde yet, so we should work on that.

Co-authored-by: Steven Allen <[email protected]>
  • Loading branch information
rvagg and Stebalien authored Feb 13, 2025
1 parent 8930ea8 commit 89c6fcf
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 47 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ edition = "2018"

[dependencies]
cbor4ii = { version = "0.2.14", default-features = false, features = ["use_alloc"] }
ipld-core = { version = "0.4.0", default-features = false, features = ["serde"] }
ipld-core = { version = "0.4.2", default-features = false, features = ["serde"] }
scopeguard = "1.1.0"
serde = { version = "1.0.164", default-features = false, features = ["alloc"] }

[dev-dependencies]
serde_derive = { version = "1.0.164", default-features = false }
serde_bytes = { version = "0.11.9", default-features = false, features = ["alloc"]}
serde-transcode = "1.1.1"
const-hex = "1.14.0"
serde_tuple = "1.1.0"

[features]
default = ["codec", "std"]
Expand Down
1 change: 0 additions & 1 deletion examples/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::convert::{TryFrom, TryInto};
use ipld_core::{cid::Cid, ipld::Ipld};
use serde::{de, Deserialize};
use serde_bytes::ByteBuf;
use serde_derive::Deserialize;
use serde_ipld_dagcbor::from_slice;

/// The CID `bafkreibme22gw2h7y2h7tg2fhqotaqjucnbc24deqo72b6mkl2egezxhvy` encoded as CBOR
Expand Down
98 changes: 59 additions & 39 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,50 @@ impl<'de, R: dec::Read<'de>> Deserializer<R> {
Err(error) => Err(error),
}
}

fn visit_seq<V>(
&mut self,
name: &'static str,
visitor: V,
) -> Result<V::Value, DecodeError<R::Error>>
where
V: Visitor<'de>,
{
let mut de = self.try_step()?;
let mut seq = Accessor::array(&mut de)?;
let value = seq.len;
let res = visitor.visit_seq(&mut seq)?;
match seq.len {
0 => Ok(res),
remaining => Err(DecodeError::RequireLength {
name,
expect: value - remaining,
value,
}),
}
}

fn visit_map<V>(
&mut self,
name: &'static str,
visitor: V,
) -> Result<V::Value, DecodeError<R::Error>>
where
V: Visitor<'de>,
{
let mut de = self.try_step()?;
let mut map = Accessor::map(&mut de)?;
let value = map.len;
let res = visitor.visit_map(&mut map)?;
match map.len {
0 => Ok(res),
remaining => Err(DecodeError::RequireLength {
name,
expect: value - remaining,
value,
}),
}
}
}

macro_rules! deserialize_type {
Expand All @@ -162,7 +206,7 @@ macro_rules! deserialize_type {
};
}

impl<'de, 'a, R: dec::Read<'de>> serde::Deserializer<'de> for &'a mut Deserializer<R> {
impl<'de, R: dec::Read<'de>> serde::Deserializer<'de> for &mut Deserializer<R> {
type Error = DecodeError<R::Error>;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -343,55 +387,49 @@ impl<'de, 'a, R: dec::Read<'de>> serde::Deserializer<'de> for &'a mut Deserializ
where
V: Visitor<'de>,
{
let mut de = self.try_step()?;
let seq = Accessor::array(&mut de)?;
visitor.visit_seq(seq)
self.visit_seq("array", visitor)
}

#[inline]
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let mut de = self.try_step()?;
let seq = Accessor::tuple(&mut de, len)?;
visitor.visit_seq(seq)
self.visit_seq("tuple", visitor)
}

#[inline]
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_tuple(len, visitor)
self.visit_seq(name, visitor)
}

#[inline]
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let mut de = self.try_step()?;
let map = Accessor::map(&mut de)?;
visitor.visit_map(map)
self.visit_map("map", visitor)
}

#[inline]
fn deserialize_struct<V>(
self,
_name: &'static str,
name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_map(visitor)
self.visit_map(name, visitor)
}

#[inline]
Expand Down Expand Up @@ -439,7 +477,7 @@ struct Accessor<'a, R> {

impl<'de, 'a, R: dec::Read<'de>> Accessor<'a, R> {
#[inline]
pub fn array(de: &'a mut Deserializer<R>) -> Result<Accessor<'a, R>, DecodeError<R::Error>> {
fn array(de: &'a mut Deserializer<R>) -> Result<Accessor<'a, R>, DecodeError<R::Error>> {
let array_start = dec::ArrayStart::decode(&mut de.reader)?;
array_start.0.map_or_else(
|| Err(DecodeError::IndefiniteSize),
Expand All @@ -448,25 +486,7 @@ impl<'de, 'a, R: dec::Read<'de>> Accessor<'a, R> {
}

#[inline]
pub fn tuple(
de: &'a mut Deserializer<R>,
len: usize,
) -> Result<Accessor<'a, R>, DecodeError<R::Error>> {
let array_start = dec::ArrayStart::decode(&mut de.reader)?;

if array_start.0 == Some(len) {
Ok(Accessor { de, len })
} else {
Err(DecodeError::RequireLength {
name: "tuple",
expect: len,
value: array_start.0.unwrap_or(0),
})
}
}

#[inline]
pub fn map(de: &'a mut Deserializer<R>) -> Result<Accessor<'a, R>, DecodeError<R::Error>> {
fn map(de: &'a mut Deserializer<R>) -> Result<Accessor<'a, R>, DecodeError<R::Error>> {
let map_start = dec::MapStart::decode(&mut de.reader)?;
map_start.0.map_or_else(
|| Err(DecodeError::IndefiniteSize),
Expand All @@ -475,7 +495,7 @@ impl<'de, 'a, R: dec::Read<'de>> Accessor<'a, R> {
}
}

impl<'de, 'a, R> de::SeqAccess<'de> for Accessor<'a, R>
impl<'de, R> de::SeqAccess<'de> for Accessor<'_, R>
where
R: dec::Read<'de>,
{
Expand All @@ -500,7 +520,7 @@ where
}
}

impl<'de, 'a, R: dec::Read<'de>> de::MapAccess<'de> for Accessor<'a, R> {
impl<'de, R: dec::Read<'de>> de::MapAccess<'de> for Accessor<'_, R> {
type Error = DecodeError<R::Error>;

#[inline]
Expand Down Expand Up @@ -570,7 +590,7 @@ where
}
}

impl<'de, 'a, R> de::VariantAccess<'de> for EnumAccessor<'a, R>
impl<'de, R> de::VariantAccess<'de> for EnumAccessor<'_, R>
where
R: dec::Read<'de>,
{
Expand Down
5 changes: 3 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl<E: fmt::Debug> From<cbor4ii::EncodeError<E>> for EncodeError<E> {
}

/// A decoding error.
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum DecodeError<E> {
/// Custom error message.
Msg(String),
Expand Down Expand Up @@ -101,7 +101,8 @@ pub enum DecodeError<E> {
/// Type name (e.g. "bytes", "str").
name: &'static str,
},
/// Length wasn't large enough.
/// Length wasn't large enough. This error comes after attempting to consume the entirety of a
/// item with a known length and failing to do so.
RequireLength {
/// Type name.
name: &'static str,
Expand Down
Loading

0 comments on commit 89c6fcf

Please sign in to comment.