Skip to content

Commit

Permalink
Implement DecodeWithMemTracking for basic types
Browse files Browse the repository at this point in the history
  • Loading branch information
serban300 committed Jul 23, 2024
1 parent 2bbd882 commit 7051378
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
73 changes: 67 additions & 6 deletions src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::{
},
compact::Compact,
encode_like::EncodeLike,
mem_tracking::DecodeWithMemTracking,
DecodeFinished, Error,
};

Expand Down Expand Up @@ -445,6 +446,7 @@ impl Input for BytesCursor {
return Err("Not enough data to fill buffer".into());
}

self.on_before_alloc_mem(length)?;
Ok(self.bytes.split_to(length))
}
}
Expand Down Expand Up @@ -482,6 +484,9 @@ impl Decode for bytes::Bytes {
}
}

#[cfg(feature = "bytes")]
impl DecodeWithMemTracking for bytes::Bytes {}

impl<T, X> Encode for X
where
T: Encode + ?Sized,
Expand Down Expand Up @@ -543,6 +548,7 @@ impl<T> WrapperTypeDecode for Box<T> {
// TODO: Use `Box::new_uninit` once that's stable.
let layout = core::alloc::Layout::new::<MaybeUninit<T>>();

input.on_before_alloc_mem(layout.size())?;
let ptr: *mut MaybeUninit<T> = if layout.size() == 0 {
core::ptr::NonNull::dangling().as_ptr()
} else {
Expand Down Expand Up @@ -581,6 +587,8 @@ impl<T> WrapperTypeDecode for Box<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Box<T> {}

impl<T> WrapperTypeDecode for Rc<T> {
type Wrapped = T;

Expand All @@ -593,6 +601,9 @@ impl<T> WrapperTypeDecode for Rc<T> {
}
}

// `Rc<T>` uses `Box::<T>::decode()` internally, so it supports `DecodeWithMemTracking`.
impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Rc<T> {}

#[cfg(target_has_atomic = "ptr")]
impl<T> WrapperTypeDecode for Arc<T> {
type Wrapped = T;
Expand All @@ -606,6 +617,9 @@ impl<T> WrapperTypeDecode for Arc<T> {
}
}

// `Arc<T>` uses `Box::<T>::decode()` internally, so it supports `DecodeWithMemTracking`.
impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Arc<T> {}

impl<T, X> Decode for X
where
T: Decode + Into<X>,
Expand Down Expand Up @@ -695,6 +709,8 @@ impl<T: Decode, E: Decode> Decode for Result<T, E> {
}
}

impl<T: DecodeWithMemTracking, E: DecodeWithMemTracking> DecodeWithMemTracking for Result<T, E> {}

/// Shim type because we can't do a specialised implementation for `Option<bool>` directly.
#[derive(Eq, PartialEq, Clone, Copy)]
pub struct OptionBool(pub Option<bool>);
Expand Down Expand Up @@ -732,6 +748,8 @@ impl Decode for OptionBool {
}
}

impl DecodeWithMemTracking for OptionBool {}

impl<T: EncodeLike<U>, U: Encode> EncodeLike<Option<U>> for Option<T> {}

impl<T: Encode> Encode for Option<T> {
Expand Down Expand Up @@ -768,6 +786,8 @@ impl<T: Decode> Decode for Option<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Option<T> {}

macro_rules! impl_for_non_zero {
( $( $name:ty ),* $(,)? ) => {
$(
Expand Down Expand Up @@ -797,6 +817,8 @@ macro_rules! impl_for_non_zero {
.ok_or_else(|| Error::from("cannot create non-zero number from 0"))
}
}

impl DecodeWithMemTracking for $name {}
)*
}
}
Expand Down Expand Up @@ -1000,6 +1022,8 @@ impl<T: Decode, const N: usize> Decode for [T; N] {
}
}

impl<T: DecodeWithMemTracking, const N: usize> DecodeWithMemTracking for [T; N] {}

impl<T: EncodeLike<U>, U: Encode, const N: usize> EncodeLike<[U; N]> for [T; N] {}

impl Encode for str {
Expand Down Expand Up @@ -1029,6 +1053,11 @@ where
}
}

impl<'a, T: ToOwned + DecodeWithMemTracking> DecodeWithMemTracking for Cow<'a, T> where
Cow<'a, T>: Decode
{
}

impl<T> EncodeLike for PhantomData<T> {}

impl<T> Encode for PhantomData<T> {
Expand All @@ -1041,12 +1070,16 @@ impl<T> Decode for PhantomData<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for PhantomData<T> where PhantomData<T>: Decode {}

impl Decode for String {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
Self::from_utf8(Vec::decode(input)?).map_err(|_| "Invalid utf8 sequence".into())
}
}

impl DecodeWithMemTracking for String {}

/// Writes the compact encoding of `len` do `dest`.
pub(crate) fn compact_encode_len_to<W: Output + ?Sized>(
dest: &mut W,
Expand All @@ -1072,9 +1105,13 @@ impl<T: Encode> Encode for [T] {
}
}

fn decode_vec_chunked<T, F>(len: usize, mut decode_chunk: F) -> Result<Vec<T>, Error>
fn decode_vec_chunked<T, I: Input, F>(
input: &mut I,
len: usize,
mut decode_chunk: F,
) -> Result<Vec<T>, Error>
where
F: FnMut(&mut Vec<T>, usize) -> Result<(), Error>,
F: FnMut(&mut I, &mut Vec<T>, usize) -> Result<(), Error>,
{
const { assert!(MAX_PREALLOCATION >= mem::size_of::<T>()) }
// we have to account for the fact that `mem::size_of::<T>` can be 0 for types like `()`
Expand All @@ -1085,9 +1122,10 @@ where
let mut num_undecoded_items = len;
while num_undecoded_items > 0 {
let chunk_len = chunk_len.min(num_undecoded_items);
input.on_before_alloc_mem(chunk_len.saturating_mul(mem::size_of::<T>()))?;
decoded_vec.reserve_exact(chunk_len);

decode_chunk(&mut decoded_vec, chunk_len)?;
decode_chunk(input, &mut decoded_vec, chunk_len)?;

num_undecoded_items -= chunk_len;
}
Expand Down Expand Up @@ -1115,7 +1153,7 @@ where
}
}

decode_vec_chunked(len, |decoded_vec, chunk_len| {
decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| {
let decoded_vec_len = decoded_vec.len();
let decoded_vec_size = decoded_vec_len * mem::size_of::<T>();
unsafe {
Expand All @@ -1133,7 +1171,7 @@ where
I: Input,
{
input.descend_ref()?;
let vec = decode_vec_chunked(len, |decoded_vec, chunk_len| {
let vec = decode_vec_chunked(input, len, |input, decoded_vec, chunk_len| {
for _ in 0..chunk_len {
decoded_vec.push(T::decode(input)?);
}
Expand Down Expand Up @@ -1185,6 +1223,8 @@ impl<T: Decode> Decode for Vec<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Vec<T> {}

macro_rules! impl_codec_through_iterator {
($(
$type:ident
Expand Down Expand Up @@ -1212,13 +1252,20 @@ macro_rules! impl_codec_through_iterator {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
<Compact<u32>>::decode(input).and_then(move |Compact(len)| {
input.descend_ref()?;
let result = Result::from_iter((0..len).map(|_| Decode::decode(input)));
let result = Result::from_iter((0..len).map(|_| {
input.on_before_alloc_mem(0 $( + mem::size_of::<$generics>() )*)?;
Decode::decode(input)
}));
input.ascend_ref();
result
})
}
}

impl<$( $generics: DecodeWithMemTracking ),*> DecodeWithMemTracking
for $type<$( $generics, )*>
where $type<$( $generics, )*>: Decode {}

impl<$( $impl_like_generics )*> EncodeLike<$type<$( $type_like_generics ),*>>
for $type<$( $generics ),*> {}
impl<$( $impl_like_generics )*> EncodeLike<&[( $( $type_like_generics, )* )]>
Expand Down Expand Up @@ -1265,6 +1312,8 @@ impl<T: Decode> Decode for VecDeque<T> {
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for VecDeque<T> {}

impl EncodeLike for () {}

impl Encode for () {
Expand Down Expand Up @@ -1445,6 +1494,8 @@ macro_rules! impl_endians {
Some(mem::size_of::<$t>())
}
}

impl DecodeWithMemTracking for $t {}
)* }
}
macro_rules! impl_one_byte {
Expand All @@ -1470,6 +1521,8 @@ macro_rules! impl_one_byte {
Ok(input.read_byte()? as $t)
}
}

impl DecodeWithMemTracking for $t {}
)* }
}

Expand Down Expand Up @@ -1505,6 +1558,8 @@ impl Decode for bool {
}
}

impl DecodeWithMemTracking for bool {}

impl Encode for Duration {
fn size_hint(&self) -> usize {
mem::size_of::<u64>() + mem::size_of::<u32>()
Expand All @@ -1529,6 +1584,8 @@ impl Decode for Duration {
}
}

impl DecodeWithMemTracking for Duration {}

impl EncodeLike for Duration {}

impl<T> Encode for Range<T>
Expand All @@ -1555,6 +1612,8 @@ where
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for Range<T> {}

impl<T> Encode for RangeInclusive<T>
where
T: Encode,
Expand All @@ -1579,6 +1638,8 @@ where
}
}

impl<T: DecodeWithMemTracking> DecodeWithMemTracking for RangeInclusive<T> {}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 4 additions & 0 deletions src/mem_tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
// limitations under the License.

use crate::Decode;
use impl_trait_for_tuples::impl_for_tuples;

/// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input`
/// while decoding.
pub trait DecodeWithMemTracking: Decode {}

#[impl_for_tuples(18)]
impl DecodeWithMemTracking for Tuple {}

0 comments on commit 7051378

Please sign in to comment.