Follow-up on #605 (#615)
* Address #605 code review comments

* Check MAX_PREALLOCATION >= mem::size_of::<T> statically

* Update CI image to paritytech/ci-unified:bullseye-1.79.0

This reverts commit c54689d.
serban300 authored Jul 23, 2024
1 parent 36baa4f commit a388fa9
Showing 20 changed files with 205 additions and 130 deletions.
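The substantive change is in `src/codec.rs`: the runtime `debug_assert!` guarding the preallocation limit in `decode_vec_chunked` becomes a compile-time check in an inline `const` block, and the chunk-size division is made safe for zero-sized element types. A minimal standalone sketch of that pattern (the `MAX_PREALLOCATION` value of 4096 is an assumption for illustration; inline `const` blocks require Rust 1.79, which is why the CI image moves to `bullseye-1.79.0`):

```rust
// Sketch only: mirrors the pattern this commit adopts in `decode_vec_chunked`.
// The real MAX_PREALLOCATION constant lives in src/codec.rs; 4096 is assumed here.
const MAX_PREALLOCATION: usize = 4 * 1024;

fn chunk_len<T>() -> usize {
    // Evaluated per monomorphization: a violation is now a compile error
    // instead of a debug-only runtime assertion.
    const { assert!(MAX_PREALLOCATION >= core::mem::size_of::<T>()) }

    // `size_of::<T>()` is 0 for zero-sized types such as `()`, so a plain
    // division would panic; `checked_div` falls back to one big chunk.
    MAX_PREALLOCATION
        .checked_div(core::mem::size_of::<T>())
        .unwrap_or(usize::MAX)
}

fn main() {
    assert_eq!(chunk_len::<u32>(), 1024);
    assert_eq!(chunk_len::<()>(), usize::MAX);
}
```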
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -7,7 +7,7 @@ on:
pull_request:

env:
-IMAGE: paritytech/ci-unified:bullseye-1.73.0
+IMAGE: paritytech/ci-unified:bullseye-1.79.0
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -9,7 +9,7 @@ jobs:
publish-crate:
runs-on: ubuntu-latest
environment: release
-container: paritytech/ci-unified:bullseye-1.73.0
+container: paritytech/ci-unified:bullseye-1.79.0
steps:
- name: Checkout code
uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion benches/benches.rs
@@ -184,7 +184,7 @@ fn encode_decode_complex_type(c: &mut Criterion) {

let complex_types = vec![
ComplexType { _val: 3, _other_val: 345634635, _vec: vec![1, 2, 3, 5, 6, 7] },
-ComplexType { _val: 1000, _other_val: 0980345634635, _vec: vec![1, 2, 3, 5, 6, 7] },
+ComplexType { _val: 1000, _other_val: 980345634635, _vec: vec![1, 2, 3, 5, 6, 7] },
ComplexType { _val: 43564, _other_val: 342342345634635, _vec: vec![1, 2, 3, 5, 6, 7] },
];

8 changes: 4 additions & 4 deletions derive/src/encode.rs
@@ -103,7 +103,7 @@ enum FieldAttribute<'a> {
None(&'a Field),
Compact(&'a Field),
EncodedAs { field: &'a Field, encoded_as: &'a TokenStream },
-Skip(&'a Field),
+Skip,
}

fn iterate_over_fields<F, H, J>(
@@ -138,7 +138,7 @@ where
} else if let Some(ref encoded_as) = encoded_as {
field_handler(field, FieldAttribute::EncodedAs { field: f, encoded_as })
} else if skip {
-field_handler(field, FieldAttribute::Skip(f))
+field_handler(field, FieldAttribute::Skip)
} else {
field_handler(field, FieldAttribute::None(f))
}
@@ -191,7 +191,7 @@ where
}
}
},
-FieldAttribute::Skip(_) => quote! {
+FieldAttribute::Skip => quote! {
let _ = #field;
},
},
@@ -236,7 +236,7 @@ where
))
}
},
-FieldAttribute::Skip(_) => quote!(),
+FieldAttribute::Skip => quote!(),
},
|recurse| {
quote! {
2 changes: 1 addition & 1 deletion derive/src/lib.rs
@@ -309,7 +309,7 @@ pub fn compact_as_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStr
(&field.ty, quote!(&self.#field_name), constructor)
},
Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
-let recurse = fields.unnamed.iter().enumerate().map(|(_, f)| {
+let recurse = fields.unnamed.iter().map(|f| {
let val_or_default = val_or_default(f);
quote_spanned!(f.span()=> #val_or_default)
});
20 changes: 5 additions & 15 deletions derive/src/max_encoded_len.rs
@@ -66,20 +66,10 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
/// generate an expression to sum up the max encoded length from several fields
fn fields_length_expr(fields: &Fields, crate_path: &syn::Path) -> proc_macro2::TokenStream {
let fields_iter: Box<dyn Iterator<Item = &Field>> = match fields {
-Fields::Named(ref fields) => Box::new(fields.named.iter().filter_map(|field| {
-if should_skip(&field.attrs) {
-None
-} else {
-Some(field)
-}
-})),
-Fields::Unnamed(ref fields) => Box::new(fields.unnamed.iter().filter_map(|field| {
-if should_skip(&field.attrs) {
-None
-} else {
-Some(field)
-}
-})),
+Fields::Named(ref fields) =>
+Box::new(fields.named.iter().filter(|field| !should_skip(&field.attrs))),
+Fields::Unnamed(ref fields) =>
+Box::new(fields.unnamed.iter().filter(|field| !should_skip(&field.attrs))),
Fields::Unit => Box::new(std::iter::empty()),
};
// expands to an expression like
@@ -94,7 +84,7 @@ fn fields_length_expr(fields: &Fields, crate_path: &syn::Path) -> proc_macro2::T
// caused the issue.
let expansion = fields_iter.map(|field| {
let ty = &field.ty;
-if utils::is_compact(&field) {
+if utils::is_compact(field) {
quote_spanned! {
ty.span() => .saturating_add(
<<#ty as #crate_path::HasCompact>::Type as #crate_path::MaxEncodedLen>::max_encoded_len()
7 changes: 6 additions & 1 deletion fuzzer/src/main.rs
@@ -47,8 +47,13 @@ pub enum MockEnum {
Empty,
Unit(u32),
UnitVec(Vec<u8>),
-Complex { data: Vec<u32>, bitvec: BitVecWrapper<u8, Msb0>, string: String },
+Complex {
+data: Vec<u32>,
+bitvec: BitVecWrapper<u8, Msb0>,
+string: String,
+},
Mock(MockStruct),
+#[allow(clippy::type_complexity)]
NestedVec(Vec<Vec<Vec<Vec<Vec<Vec<Vec<Vec<Option<u8>>>>>>>>>),
}

33 changes: 18 additions & 15 deletions src/codec.rs
@@ -1071,8 +1071,10 @@ fn decode_vec_chunked<T, F>(len: usize, mut decode_chunk: F) -> Result<Vec<T>, E
where
F: FnMut(&mut Vec<T>, usize) -> Result<(), Error>,
{
-debug_assert!(MAX_PREALLOCATION >= mem::size_of::<T>(), "Invalid precondition");
-let chunk_len = MAX_PREALLOCATION / mem::size_of::<T>();
+const { assert!(MAX_PREALLOCATION >= mem::size_of::<T>()) }
+// we have to account for the fact that `mem::size_of::<T>` can be 0 for types like `()`
+// for example.
+let chunk_len = MAX_PREALLOCATION.checked_div(mem::size_of::<T>()).unwrap_or(usize::MAX)

let mut decoded_vec = vec![];
let mut num_undecoded_items = len;
@@ -1082,7 +1084,7 @@ where

decode_chunk(&mut decoded_vec, chunk_len)?;

-num_undecoded_items = num_undecoded_items.saturating_sub(chunk_len);
+num_undecoded_items -= chunk_len;
}

Ok(decoded_vec)
@@ -1125,13 +1127,6 @@ where
T: Decode,
I: Input,
{
-// Check if there is enough data in the input buffer.
-if let Some(input_len) = input.remaining_len()? {
-if input_len < len {
-return Err("Not enough data to decode vector".into());
-}
-}

input.descend_ref()?;
let vec = decode_vec_chunked(len, |decoded_vec, chunk_len| {
for _ in 0..chunk_len {
@@ -1668,6 +1663,14 @@ mod tests {
assert_eq!(<Vec<OptionBool>>::decode(&mut &encoded[..]).unwrap(), value);
}

+#[test]
+fn vec_of_empty_tuples_encoded_as_expected() {
+let value = vec![(), (), (), (), ()];
+let encoded = value.encode();
+assert_eq!(hexify(&encoded), "14");
+assert_eq!(<Vec<()>>::decode(&mut &encoded[..]).unwrap(), value);
+}

#[cfg(feature = "bytes")]
#[test]
fn bytes_works_as_expected() {
@@ -1699,7 +1702,7 @@ mod tests {
assert_eq!(decoded, &b"hello"[..]);

// The `slice_ref` will panic if the `decoded` is not a subslice of `encoded`.
-assert_eq!(encoded.slice_ref(&decoded), &b"hello"[..]);
+assert_eq!(encoded.slice_ref(decoded), &b"hello"[..]);
}

fn test_encode_length<T: Encode + Decode + DecodeLength>(thing: &T, len: usize) {
@@ -1890,8 +1893,8 @@ mod tests {
fn boolean() {
assert_eq!(true.encode(), vec![1]);
assert_eq!(false.encode(), vec![0]);
-assert_eq!(bool::decode(&mut &[1][..]).unwrap(), true);
-assert_eq!(bool::decode(&mut &[0][..]).unwrap(), false);
+assert!(bool::decode(&mut &[1][..]).unwrap());
+assert!(!bool::decode(&mut &[0][..]).unwrap());
}

#[test]
@@ -1908,7 +1911,7 @@ mod tests {
let encoded = data.encode();

let decoded = Vec::<u32>::decode(&mut &encoded[..]).unwrap();
-assert!(decoded.iter().all(|v| data.contains(&v)));
+assert!(decoded.iter().all(|v| data.contains(v)));
assert_eq!(data.len(), decoded.len());

let encoded = decoded.encode();
@@ -1939,7 +1942,7 @@ mod tests {
let num_nanos = 37;

let duration = Duration::new(num_secs, num_nanos);
-let expected = (num_secs, num_nanos as u32).encode();
+let expected = (num_secs, num_nanos).encode();

assert_eq!(duration.encode(), expected);
assert_eq!(Duration::decode(&mut &expected[..]).unwrap(), duration);
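For reference, the expected bytes in the new `vec_of_empty_tuples_encoded_as_expected` test added earlier in this file's diff follow directly from the SCALE format: the length 5 is compact-encoded as a single byte `5 << 2 = 0x14`, and each zero-sized `()` element contributes no payload. A standalone check, assuming `parity-scale-codec` as a dependency:

```rust
// Reproduces the Vec<()> expectation outside the test suite
// (assumes parity-scale-codec in Cargo.toml).
use parity_scale_codec::{Decode, Encode};

fn main() {
    let value = vec![(), (), (), (), ()];
    let encoded = value.encode();
    // Compact(5) is one byte: 5 << 2 = 0x14; the `()` elements add nothing.
    assert_eq!(encoded, vec![0x14u8]);
    assert_eq!(<Vec<()>>::decode(&mut &encoded[..]).unwrap(), value);
}
```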
18 changes: 9 additions & 9 deletions src/compact.rs
@@ -735,7 +735,7 @@ mod tests {
(u128::MAX, 17),
];
for &(n, l) in &tests {
-let encoded = Compact(n as u128).encode();
+let encoded = Compact(n).encode();
assert_eq!(encoded.len(), l);
assert_eq!(Compact::compact_len(&n), l);
assert_eq!(<Compact<u128>>::decode(&mut &encoded[..]).unwrap().0, n);
@@ -761,7 +761,7 @@ mod tests {
(u64::MAX, 9),
];
for &(n, l) in &tests {
-let encoded = Compact(n as u64).encode();
+let encoded = Compact(n).encode();
assert_eq!(encoded.len(), l);
assert_eq!(Compact::compact_len(&n), l);
assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n);
@@ -781,7 +781,7 @@ mod tests {
(u32::MAX, 5),
];
for &(n, l) in &tests {
-let encoded = Compact(n as u32).encode();
+let encoded = Compact(n).encode();
assert_eq!(encoded.len(), l);
assert_eq!(Compact::compact_len(&n), l);
assert_eq!(<Compact<u32>>::decode(&mut &encoded[..]).unwrap().0, n);
@@ -792,7 +792,7 @@ mod tests {
fn compact_16_encoding_works() {
let tests = [(0u16, 1usize), (63, 1), (64, 2), (16383, 2), (16384, 4), (65535, 4)];
for &(n, l) in &tests {
-let encoded = Compact(n as u16).encode();
+let encoded = Compact(n).encode();
assert_eq!(encoded.len(), l);
assert_eq!(Compact::compact_len(&n), l);
assert_eq!(<Compact<u16>>::decode(&mut &encoded[..]).unwrap().0, n);
@@ -804,7 +804,7 @@ mod tests {
fn compact_8_encoding_works() {
let tests = [(0u8, 1usize), (63, 1), (64, 2), (255, 2)];
for &(n, l) in &tests {
-let encoded = Compact(n as u8).encode();
+let encoded = Compact(n).encode();
assert_eq!(encoded.len(), l);
assert_eq!(Compact::compact_len(&n), l);
assert_eq!(<Compact<u8>>::decode(&mut &encoded[..]).unwrap().0, n);
@@ -840,7 +840,7 @@ mod tests {
];
for &(n, s) in &tests {
// Verify u64 encoding
-let encoded = Compact(n as u64).encode();
+let encoded = Compact(n).encode();
assert_eq!(hexify(&encoded), s);
assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n);

@@ -849,19 +849,19 @@ mod tests {
assert_eq!(<Compact<u32>>::decode(&mut &encoded[..]).unwrap().0, n as u32);
let encoded = Compact(n as u32).encode();
assert_eq!(hexify(&encoded), s);
-assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n as u64);
+assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n);
}
if n <= u16::MAX as u64 {
assert_eq!(<Compact<u16>>::decode(&mut &encoded[..]).unwrap().0, n as u16);
let encoded = Compact(n as u16).encode();
assert_eq!(hexify(&encoded), s);
-assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n as u64);
+assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n);
}
if n <= u8::MAX as u64 {
assert_eq!(<Compact<u8>>::decode(&mut &encoded[..]).unwrap().0, n as u8);
let encoded = Compact(n as u8).encode();
assert_eq!(hexify(&encoded), s);
-assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n as u64);
+assert_eq!(<Compact<u64>>::decode(&mut &encoded[..]).unwrap().0, n);
}
}
}
6 changes: 3 additions & 3 deletions src/encode_append.rs
@@ -160,7 +160,7 @@ mod tests {
#[test]
fn vec_encode_append_multiple_items_works() {
let encoded = (0..TEST_VALUE).fold(Vec::new(), |encoded, v| {
-<Vec<u32> as EncodeAppend>::append_or_new(encoded, &[v, v, v, v]).unwrap()
+<Vec<u32> as EncodeAppend>::append_or_new(encoded, [v, v, v, v]).unwrap()
});

let decoded = Vec::<u32>::decode(&mut &encoded[..]).unwrap();
@@ -184,7 +184,7 @@ mod tests {
#[test]
fn vecdeque_encode_append_multiple_items_works() {
let encoded = (0..TEST_VALUE).fold(Vec::new(), |encoded, v| {
-<VecDeque<u32> as EncodeAppend>::append_or_new(encoded, &[v, v, v, v]).unwrap()
+<VecDeque<u32> as EncodeAppend>::append_or_new(encoded, [v, v, v, v]).unwrap()
});

let decoded = VecDeque::<u32>::decode(&mut &encoded[..]).unwrap();
@@ -228,7 +228,7 @@ mod tests {
#[test]
fn vec_encode_like_append_works() {
let encoded = (0..TEST_VALUE).fold(Vec::new(), |encoded, v| {
-<Vec<u32> as EncodeAppend>::append_or_new(encoded, std::iter::once(Box::new(v as u32)))
+<Vec<u32> as EncodeAppend>::append_or_new(encoded, std::iter::once(Box::new(v)))
.unwrap()
});

2 changes: 1 addition & 1 deletion src/encode_like.rs
@@ -122,7 +122,7 @@ mod tests {
#[test]
fn vec_and_slice_are_working() {
let slice: &[u8] = &[1, 2, 3, 4];
-let data: Vec<u8> = slice.iter().copied().collect();
+let data: Vec<u8> = slice.to_vec();

let data_encoded = data.encode();
let slice_encoded = ComplexStuff::<Vec<u8>>::complex_method(&slice);
14 changes: 7 additions & 7 deletions tests/chain-error.rs
@@ -24,17 +24,17 @@ struct StructNamed {
}

#[derive(DeriveDecode, Debug)]
-struct StructUnnamed(u16);
+struct StructUnnamed(#[allow(dead_code)] u16);

#[derive(DeriveDecode, Debug)]
enum E {
VariantNamed { _foo: u16 },
-VariantUnnamed(u16),
+VariantUnnamed(#[allow(dead_code)] u16),
}

#[test]
fn full_error_struct_named() {
-let encoded = vec![0];
+let encoded = [0];
let err = r#"Could not decode `Wrapper.0`:
Could not decode `StructNamed::_foo`:
Not enough data to fill buffer
@@ -48,7 +48,7 @@ fn full_error_struct_named() {

#[test]
fn full_error_struct_unnamed() {
-let encoded = vec![0];
+let encoded = [0];
let err = r#"Could not decode `Wrapper.0`:
Could not decode `StructUnnamed.0`:
Not enough data to fill buffer
@@ -62,15 +62,15 @@ fn full_error_struct_unnamed() {

#[test]
fn full_error_enum_unknown_variant() {
-let encoded = vec![2];
+let encoded = [2];
let err = r#"Could not decode `E`, variant doesn't exist"#;

assert_eq!(E::decode(&mut &encoded[..]).unwrap_err().to_string(), String::from(err),);
}

#[test]
fn full_error_enum_named_field() {
-let encoded = vec![0, 0];
+let encoded = [0, 0];
let err = r#"Could not decode `E::VariantNamed::_foo`:
Not enough data to fill buffer
"#;
@@ -80,7 +80,7 @@ fn full_error_enum_named_field() {

#[test]
fn full_error_enum_unnamed_field() {
-let encoded = vec![1, 0];
+let encoded = [1, 0];
let err = r#"Could not decode `E::VariantUnnamed.0`:
Not enough data to fill buffer
"#;
(Diffs for the remaining changed files were not loaded in this view.)
