Skip to content

Commit

Permalink
Define MemTrackingInput
Browse files Browse the repository at this point in the history
  • Loading branch information
serban300 committed Jul 24, 2024
1 parent 7051378 commit 0313d3b
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub use self::{
error::Error,
joiner::Joiner,
keyedvec::KeyedVec,
mem_tracking::{DecodeWithMemTracking, MemTrackingInput},
};
#[cfg(feature = "max-encoded-len")]
pub use const_encoded_len::ConstEncodedLen;
Expand Down
56 changes: 55 additions & 1 deletion src/mem_tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::Decode;
use crate::{Decode, Error, Input};
use impl_trait_for_tuples::impl_for_tuples;

/// Marker trait used for identifying types that call the mem tracking hooks exposed by `Input`
/// while decoding.
pub trait DecodeWithMemTracking: Decode {}

const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding";

#[impl_for_tuples(18)]
impl DecodeWithMemTracking for Tuple {}

/// `Input` implementation that can be used for limiting the heap memory usage while decoding.
pub struct MemTrackingInput<'a, I> {
input: &'a mut I,
used_mem: usize,
mem_limit: usize,
}

impl<'a, I: Input> MemTrackingInput<'a, I> {
/// Create a new instance of `MemTrackingInput`.
pub fn new(input: &'a mut I, mem_limit: usize) -> Self {
Self { input, used_mem: 0, mem_limit }
}

/// Get the `used_mem` field.
pub fn used_mem(&self) -> usize {
self.used_mem
}
}

impl<'a, I: Input> Input for MemTrackingInput<'a, I> {
fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
self.input.remaining_len()
}

fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
self.input.read(into)
}

fn read_byte(&mut self) -> Result<u8, Error> {
self.input.read_byte()
}

fn descend_ref(&mut self) -> Result<(), Error> {
self.input.descend_ref()
}

fn ascend_ref(&mut self) {
self.input.ascend_ref()
}

fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> {
self.input.on_before_alloc_mem(size)?;

self.used_mem = self.used_mem.saturating_add(size);
if self.used_mem >= self.mem_limit {
return Err(DECODE_OOM_MSG.into())
}

Ok(())
}
}
118 changes: 118 additions & 0 deletions tests/mem_tracking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::fmt::Debug;
use parity_scale_codec::{
alloc::{
collections::{BTreeMap, BTreeSet, LinkedList, VecDeque},
rc::Rc,
},
DecodeWithMemTracking, Encode, Error, MemTrackingInput,
};
use parity_scale_codec_derive::{Decode as DeriveDecode, Encode as DeriveEncode};

fn decode_object<T>(obj: T, mem_limit: usize, expected_used_mem: usize) -> Result<T, Error>
where
T: Encode + DecodeWithMemTracking + PartialEq + Debug,
{
let encoded_bytes = obj.encode();
let raw_input = &mut &encoded_bytes[..];
let mut input = MemTrackingInput::new(raw_input, mem_limit);
let decoded_obj = T::decode(&mut input)?;
assert_eq!(&decoded_obj, &obj);
assert_eq!(input.used_mem(), expected_used_mem);
Ok(decoded_obj)
}

#[test]
fn decode_simple_objects_works() {
const ARRAY: [u8; 1000] = [11; 1000];

// Test simple objects
assert!(decode_object(ARRAY, usize::MAX, 0).is_ok());
assert!(decode_object(Some(ARRAY), usize::MAX, 0).is_ok());
assert!(decode_object((ARRAY, ARRAY), usize::MAX, 0).is_ok());
assert!(decode_object(1u8, usize::MAX, 0).is_ok());
assert!(decode_object(1u32, usize::MAX, 0).is_ok());
assert!(decode_object(1f64, usize::MAX, 0).is_ok());

// Test heap objects
assert!(decode_object(Box::new(ARRAY), usize::MAX, 1000).is_ok());
#[cfg(target_has_atomic = "ptr")]
{
use parity_scale_codec::alloc::sync::Arc;
assert!(decode_object(Arc::new(ARRAY), usize::MAX, 1000).is_ok());
}
assert!(decode_object(Rc::new(ARRAY), usize::MAX, 1000).is_ok());
// Simple collections
assert!(decode_object(vec![ARRAY; 3], usize::MAX, 3000).is_ok());
assert!(decode_object(VecDeque::from(vec![ARRAY; 5]), usize::MAX, 5000).is_ok());
assert!(decode_object(String::from("test"), usize::MAX, 4).is_ok());
#[cfg(feature = "bytes")]
assert!(decode_object(bytes::Bytes::from(&ARRAY[..]), usize::MAX, 1000).is_ok());
// Complex Collections
assert!(decode_object(BTreeMap::<u8, u8>::from([(1, 2), (2, 3)]), usize::MAX, 4).is_ok());
assert!(decode_object(
BTreeMap::from([
("key1".to_string(), "value1".to_string()),
("key2".to_string(), "value2".to_string()),
]),
usize::MAX,
116,
)
.is_ok());
assert!(decode_object(BTreeSet::<u8>::from([1, 2, 3, 4, 5]), usize::MAX, 5).is_ok());
assert!(decode_object(LinkedList::<u8>::from([1, 2, 3, 4, 5]), usize::MAX, 5).is_ok());
}

#[test]
fn decode_complex_objects_works() {
assert!(decode_object(vec![vec![vec![vec![vec![1u8]]]]], usize::MAX, 97).is_ok());
assert!(decode_object(Box::new(Rc::new(vec![String::from("test")])), usize::MAX, 60).is_ok());
}

#[test]
fn decode_complex_derived_struct_works() {
#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)]
#[allow(clippy::large_enum_variant)]
enum TestEnum {
Empty,
Array([u8; 1000]),
}

impl DecodeWithMemTracking for TestEnum {}

#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)]
struct ComplexStruct {
test_enum: TestEnum,
boxed_test_enum: Box<TestEnum>,
box_field: Box<u32>,
vec: Vec<u8>,
}

impl DecodeWithMemTracking for ComplexStruct {}

assert!(decode_object(
ComplexStruct {
test_enum: TestEnum::Array([0; 1000]),
boxed_test_enum: Box::new(TestEnum::Empty),
box_field: Box::new(1),
vec: vec![1; 10],
},
usize::MAX,
1015
)
.is_ok())
}

0 comments on commit 0313d3b

Please sign in to comment.