From 5b1b387522a441b7f73184f2ce71cf24e3f6ac8f Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Wed, 24 Jul 2024 14:43:08 +0300 Subject: [PATCH] Fix `Bytes` decoding --- src/codec.rs | 58 +++++++++++++++++++++++++++++-------------- src/depth_limit.rs | 12 +++++++++ src/lib.rs | 2 +- src/mem_tracking.rs | 8 ++++++ tests/mem_tracking.rs | 14 +++++++++++ 5 files changed, 75 insertions(+), 19 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index b3e76e79..65783693 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -93,14 +93,14 @@ pub trait Input { /// !INTERNAL USE ONLY! /// - /// Decodes a `bytes::Bytes`. + /// Used when decoding a `bytes::Bytes` from a `BytesCursor` input. #[cfg(feature = "bytes")] #[doc(hidden)] - fn scale_internal_decode_bytes(&mut self) -> Result + fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor> where Self: Sized, { - Vec::::decode(self).map(bytes::Bytes::from) + None } } @@ -414,12 +414,32 @@ mod feature_wrapper_bytes { impl EncodeLike for Vec {} } +/// `Input` implementation optimized for decoding `bytes::Bytes`. #[cfg(feature = "bytes")] -struct BytesCursor { +pub struct BytesCursor { bytes: bytes::Bytes, position: usize, } +#[cfg(feature = "bytes")] +impl BytesCursor { + /// Create a new instance of `BytesCursor`. + pub fn new(bytes: bytes::Bytes) -> Self { + Self { bytes, position: 0 } + } + + fn decode_bytes_with_len(&mut self, length: usize) -> Result { + bytes::Buf::advance(&mut self.bytes, self.position); + self.position = 0; + + if length > self.bytes.len() { + return Err("Not enough data to fill buffer".into()); + } + + Ok(self.bytes.split_to(length)) + } +} + #[cfg(feature = "bytes")] impl Input for BytesCursor { fn remaining_len(&mut self) -> Result, Error> { @@ -436,18 +456,11 @@ impl Input for BytesCursor { Ok(()) } - fn scale_internal_decode_bytes(&mut self) -> Result { - let length = >::decode(self)?.0 as usize; - - bytes::Buf::advance(&mut self.bytes, self.position); - self.position = 0; - - if length > self.bytes.len() { - return Err("Not enough data to fill buffer".into()); - } - - self.on_before_alloc_mem(length)?; - Ok(self.bytes.split_to(length)) + fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor> + where + Self: Sized, + { + Some(self) } } @@ -473,14 +486,23 @@ where // However, if `T` doesn't contain any `Bytes` then this extra allocation is // technically unnecessary, and we can avoid it by tracking the position ourselves // and treating the underlying `Bytes` as a fancy `&[u8]`. - let mut input = BytesCursor { bytes, position: 0 }; + let mut input = BytesCursor::new(bytes); T::decode(&mut input) } #[cfg(feature = "bytes")] impl Decode for bytes::Bytes { fn decode(input: &mut I) -> Result { - input.scale_internal_decode_bytes() + let len = >::decode(input)?.0 as usize; + if input.__private_bytes_cursor().is_some() { + input.on_before_alloc_mem(len)?; + } + + if let Some(bytes_cursor) = input.__private_bytes_cursor() { + bytes_cursor.decode_bytes_with_len(len) + } else { + decode_vec_with_len::(input, len).map(bytes::Bytes::from) + } } } diff --git a/src/depth_limit.rs b/src/depth_limit.rs index 2af17843..7d4affb4 100644 --- a/src/depth_limit.rs +++ b/src/depth_limit.rs @@ -64,6 +64,18 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> { self.input.ascend_ref(); self.depth -= 1; } + + fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> { + self.input.on_before_alloc_mem(size) + } + + #[cfg(feature = "bytes")] + fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor> + where + Self: Sized, + { + self.input.__private_bytes_cursor() + } } impl DecodeLimit for T { diff --git a/src/lib.rs b/src/lib.rs index d29673dd..6bea868b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,4 +130,4 @@ pub use max_encoded_len::MaxEncodedLen; pub use parity_scale_codec_derive::MaxEncodedLen; #[cfg(feature = "bytes")] -pub use self::codec::decode_from_bytes; +pub use self::codec::{decode_from_bytes, BytesCursor}; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 2042b57b..409baa9c 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -75,4 +75,12 @@ impl<'a, I: Input> Input for MemTrackingInput<'a, I> { Ok(()) } + + #[cfg(feature = "bytes")] + fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor> + where + Self: Sized, + { + self.input.__private_bytes_cursor() + } } diff --git a/tests/mem_tracking.rs b/tests/mem_tracking.rs index b9b994ae..354c7d9a 100644 --- a/tests/mem_tracking.rs +++ b/tests/mem_tracking.rs @@ -83,6 +83,20 @@ fn decode_complex_objects_works() { assert!(decode_object(Box::new(Rc::new(vec![String::from("test")])), usize::MAX, 60).is_ok()); } +#[cfg(feature = "bytes")] +#[test] +fn decode_bytes_from_bytes_works() { + use parity_scale_codec::Decode; + + let obj = ([0u8; 100], Box::new(0u8), bytes::Bytes::from(vec![0u8; 50])); + let encoded_bytes = obj.encode(); + let mut bytes_cursor = parity_scale_codec::BytesCursor::new(bytes::Bytes::from(encoded_bytes)); + let mut input = MemTrackingInput::new(&mut bytes_cursor, usize::MAX); + let decoded_obj = <([u8; 100], Box, bytes::Bytes)>::decode(&mut input).unwrap(); + assert_eq!(&decoded_obj, &obj); + assert_eq!(input.used_mem(), 51); +} + #[test] fn decode_complex_derived_struct_works() { #[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)]