From 7051378e86ab77c239026991b4695327f5a66393 Mon Sep 17 00:00:00 2001 From: Serban Iorga Date: Tue, 23 Jul 2024 14:27:15 +0300 Subject: [PATCH] Implement DecodeWithMemTracking for basic types --- src/codec.rs | 73 +++++++++++++++++++++++++++++++++++++++++---- src/mem_tracking.rs | 4 +++ 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/codec.rs b/src/codec.rs index 901748b3..b3e76e79 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -44,6 +44,7 @@ use crate::{ }, compact::Compact, encode_like::EncodeLike, + mem_tracking::DecodeWithMemTracking, DecodeFinished, Error, }; @@ -445,6 +446,7 @@ impl Input for BytesCursor { return Err("Not enough data to fill buffer".into()); } + self.on_before_alloc_mem(length)?; Ok(self.bytes.split_to(length)) } } @@ -482,6 +484,9 @@ impl Decode for bytes::Bytes { } } +#[cfg(feature = "bytes")] +impl DecodeWithMemTracking for bytes::Bytes {} + impl Encode for X where T: Encode + ?Sized, @@ -543,6 +548,7 @@ impl WrapperTypeDecode for Box { // TODO: Use `Box::new_uninit` once that's stable. let layout = core::alloc::Layout::new::>(); + input.on_before_alloc_mem(layout.size())?; let ptr: *mut MaybeUninit = if layout.size() == 0 { core::ptr::NonNull::dangling().as_ptr() } else { @@ -581,6 +587,8 @@ impl WrapperTypeDecode for Box { } } +impl DecodeWithMemTracking for Box {} + impl WrapperTypeDecode for Rc { type Wrapped = T; @@ -593,6 +601,9 @@ impl WrapperTypeDecode for Rc { } } +// `Rc` uses `Box::::decode()` internally, so it supports `DecodeWithMemTracking`. +impl DecodeWithMemTracking for Rc {} + #[cfg(target_has_atomic = "ptr")] impl WrapperTypeDecode for Arc { type Wrapped = T; @@ -606,6 +617,9 @@ impl WrapperTypeDecode for Arc { } } +// `Arc` uses `Box::::decode()` internally, so it supports `DecodeWithMemTracking`. +impl DecodeWithMemTracking for Arc {} + impl Decode for X where T: Decode + Into, @@ -695,6 +709,8 @@ impl Decode for Result { } } +impl DecodeWithMemTracking for Result {} + /// Shim type because we can't do a specialised implementation for `Option` directly. #[derive(Eq, PartialEq, Clone, Copy)] pub struct OptionBool(pub Option); @@ -732,6 +748,8 @@ impl Decode for OptionBool { } } +impl DecodeWithMemTracking for OptionBool {} + impl, U: Encode> EncodeLike> for Option {} impl Encode for Option { @@ -768,6 +786,8 @@ impl Decode for Option { } } +impl DecodeWithMemTracking for Option {} + macro_rules! impl_for_non_zero { ( $( $name:ty ),* $(,)? ) => { $( @@ -797,6 +817,8 @@ macro_rules! impl_for_non_zero { .ok_or_else(|| Error::from("cannot create non-zero number from 0")) } } + + impl DecodeWithMemTracking for $name {} )* } } @@ -1000,6 +1022,8 @@ impl Decode for [T; N] { } } +impl DecodeWithMemTracking for [T; N] {} + impl, U: Encode, const N: usize> EncodeLike<[U; N]> for [T; N] {} impl Encode for str { @@ -1029,6 +1053,11 @@ where } } +impl<'a, T: ToOwned + DecodeWithMemTracking> DecodeWithMemTracking for Cow<'a, T> where + Cow<'a, T>: Decode +{ +} + impl EncodeLike for PhantomData {} impl Encode for PhantomData { @@ -1041,12 +1070,16 @@ impl Decode for PhantomData { } } +impl DecodeWithMemTracking for PhantomData where PhantomData: Decode {} + impl Decode for String { fn decode(input: &mut I) -> Result { Self::from_utf8(Vec::decode(input)?).map_err(|_| "Invalid utf8 sequence".into()) } } +impl DecodeWithMemTracking for String {} + /// Writes the compact encoding of `len` do `dest`. pub(crate) fn compact_encode_len_to( dest: &mut W, @@ -1072,9 +1105,13 @@ impl Encode for [T] { } } -fn decode_vec_chunked(len: usize, mut decode_chunk: F) -> Result, Error> +fn decode_vec_chunked( + input: &mut I, + len: usize, + mut decode_chunk: F, +) -> Result, Error> where - F: FnMut(&mut Vec, usize) -> Result<(), Error>, + F: FnMut(&mut I, &mut Vec, usize) -> Result<(), Error>, { const { assert!(MAX_PREALLOCATION >= mem::size_of::()) } // we have to account for the fact that `mem::size_of::` can be 0 for types like `()` @@ -1085,9 +1122,10 @@ where let mut num_undecoded_items = len; while num_undecoded_items > 0 { let chunk_len = chunk_len.min(num_undecoded_items); + input.on_before_alloc_mem(chunk_len.saturating_mul(mem::size_of::()))?; decoded_vec.reserve_exact(chunk_len); - decode_chunk(&mut decoded_vec, chunk_len)?; + decode_chunk(input, &mut decoded_vec, chunk_len)?; num_undecoded_items -= chunk_len; } @@ -1115,7 +1153,7 @@ where } } - decode_vec_chunked(len, |decoded_vec, chunk_len| { + decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| { let decoded_vec_len = decoded_vec.len(); let decoded_vec_size = decoded_vec_len * mem::size_of::(); unsafe { @@ -1133,7 +1171,7 @@ where I: Input, { input.descend_ref()?; - let vec = decode_vec_chunked(len, |decoded_vec, chunk_len| { + let vec = decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| { for _ in 0..chunk_len { decoded_vec.push(T::decode(input)?); } @@ -1185,6 +1223,8 @@ impl Decode for Vec { } } +impl DecodeWithMemTracking for Vec {} + macro_rules! impl_codec_through_iterator { ($( $type:ident @@ -1212,13 +1252,20 @@ macro_rules! impl_codec_through_iterator { fn decode(input: &mut I) -> Result { >::decode(input).and_then(move |Compact(len)| { input.descend_ref()?; - let result = Result::from_iter((0..len).map(|_| Decode::decode(input))); + let result = Result::from_iter((0..len).map(|_| { + input.on_before_alloc_mem(0 $( + mem::size_of::<$generics>() )*)?; + Decode::decode(input) + })); input.ascend_ref(); result }) } } + impl<$( $generics: DecodeWithMemTracking ),*> DecodeWithMemTracking + for $type<$( $generics, )*> + where $type<$( $generics, )*>: Decode {} + impl<$( $impl_like_generics )*> EncodeLike<$type<$( $type_like_generics ),*>> for $type<$( $generics ),*> {} impl<$( $impl_like_generics )*> EncodeLike<&[( $( $type_like_generics, )* )]> @@ -1265,6 +1312,8 @@ impl Decode for VecDeque { } } +impl DecodeWithMemTracking for VecDeque {} + impl EncodeLike for () {} impl Encode for () { @@ -1445,6 +1494,8 @@ macro_rules! impl_endians { Some(mem::size_of::<$t>()) } } + + impl DecodeWithMemTracking for $t {} )* } } macro_rules! impl_one_byte { @@ -1470,6 +1521,8 @@ macro_rules! impl_one_byte { Ok(input.read_byte()? as $t) } } + + impl DecodeWithMemTracking for $t {} )* } } @@ -1505,6 +1558,8 @@ impl Decode for bool { } } +impl DecodeWithMemTracking for bool {} + impl Encode for Duration { fn size_hint(&self) -> usize { mem::size_of::() + mem::size_of::() @@ -1529,6 +1584,8 @@ impl Decode for Duration { } } +impl DecodeWithMemTracking for Duration {} + impl EncodeLike for Duration {} impl Encode for Range @@ -1555,6 +1612,8 @@ where } } +impl DecodeWithMemTracking for Range {} + impl Encode for RangeInclusive where T: Encode, @@ -1579,6 +1638,8 @@ where } } +impl DecodeWithMemTracking for RangeInclusive {} + #[cfg(test)] mod tests { use super::*; diff --git a/src/mem_tracking.rs b/src/mem_tracking.rs index 61f35c0a..5080d8f7 100644 --- a/src/mem_tracking.rs +++ b/src/mem_tracking.rs @@ -14,7 +14,11 @@ // limitations under the License. use crate::Decode; +use impl_trait_for_tuples::impl_for_tuples; /// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input` /// while decoding. pub trait DecodeWithMemTracking: Decode {} + +#[impl_for_tuples(18)] +impl DecodeWithMemTracking for Tuple {}