Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pkhry committed Nov 4, 2024
1 parent f2255ee commit add9eae
Show file tree
Hide file tree
Showing 15 changed files with 231 additions and 118 deletions.
15 changes: 8 additions & 7 deletions derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ pub fn quote(
}
}
},
Data::Union(_) =>
Error::new(Span::call_site(), "Union types are not supported.").to_compile_error(),
Data::Union(_) => {
Error::new(Span::call_site(), "Union types are not supported.").to_compile_error()
},
}
}

Expand All @@ -120,8 +121,8 @@ pub fn quote_decode_into(
let fields = match data {
Data::Struct(syn::DataStruct {
fields:
Fields::Named(syn::FieldsNamed { named: fields, .. }) |
Fields::Unnamed(syn::FieldsUnnamed { unnamed: fields, .. }),
Fields::Named(syn::FieldsNamed { named: fields, .. })
| Fields::Unnamed(syn::FieldsUnnamed { unnamed: fields, .. }),
..
}) => fields,
_ => return None,
Expand All @@ -133,9 +134,9 @@ pub fn quote_decode_into(

// Bail if there are any extra attributes which could influence how the type is decoded.
if fields.iter().any(|field| {
utils::get_encoded_as_type(field).is_some() ||
utils::is_compact(field) ||
utils::should_skip(&field.attrs)
utils::get_encoded_as_type(field).is_some()
|| utils::is_compact(field)
|| utils::should_skip(&field.attrs)
}) {
return None;
}
Expand Down
29 changes: 23 additions & 6 deletions derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
}
};

[hinting, encoding]
[hinting, encoding, quote! { #index }]
},
Fields::Unnamed(ref fields) => {
let fields = &fields.unnamed;
Expand Down Expand Up @@ -378,7 +378,7 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
}
};

[hinting, encoding]
[hinting, encoding, quote! { #index }]
},
Fields::Unit => {
let hinting = quote_spanned! { f.span() =>
Expand All @@ -394,15 +394,15 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
}
};

[hinting, encoding]
[hinting, encoding, quote! { #index }]
},
};
items.push(item)
}

let recurse_hinting = items.iter().map(|[hinting, _]| hinting);
let recurse_encoding = items.iter().map(|[_, encoding]| encoding);

let recurse_hinting = items.iter().map(|[hinting, _, _]| hinting);
let recurse_encoding = items.iter().map(|[_, encoding, _]| encoding);
let recurse_indices = items.iter().map(|[_, _, index]| index);
let hinting = quote! {
// The variant index uses 1 byte.
1_usize + match *#self_ {
Expand All @@ -412,6 +412,23 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
};

let encoding = quote! {
const _: () = {
let indices = [#( #recurse_indices ,)*];
let len = indices.len();

// Check each pair for uniqueness
let mut index = 0;
while index < len {
let mut next_index = index + 1;
while next_index < len {
if indices[index] == indices[next_index] {
panic!("TODO: good error message with variant names and indices");
}
next_index += 1;
}
index += 1;
}
};
match *#self_ {
#( #recurse_encoding )*,
_ => (),
Expand Down
12 changes: 7 additions & 5 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,19 @@ pub fn compact_as_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStr
let constructor = quote!( #name(#( #recurse, )*));
(&field.ty, quote!(&self.#id), constructor)
},
_ =>
_ => {
return Error::new(
data.fields.span(),
"Only structs with a single non-skipped field can derive CompactAs",
)
.to_compile_error()
.into(),
.into()
},
},
Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. })
| Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) => {
return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into()
},
Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) |
Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) =>
return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into(),
};

let impl_block = quote! {
Expand Down
10 changes: 6 additions & 4 deletions derive/src/max_encoded_len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
/// generate an expression to sum up the max encoded length from several fields
fn fields_length_expr(fields: &Fields, crate_path: &syn::Path) -> proc_macro2::TokenStream {
let fields_iter: Box<dyn Iterator<Item = &Field>> = match fields {
Fields::Named(ref fields) =>
Box::new(fields.named.iter().filter(|field| !should_skip(&field.attrs))),
Fields::Unnamed(ref fields) =>
Box::new(fields.unnamed.iter().filter(|field| !should_skip(&field.attrs))),
Fields::Named(ref fields) => {
Box::new(fields.named.iter().filter(|field| !should_skip(&field.attrs)))
},
Fields::Unnamed(ref fields) => {
Box::new(fields.unnamed.iter().filter(|field| !should_skip(&field.attrs)))
},
Fields::Unit => Box::new(std::iter::empty()),
};
// expands to an expression like
Expand Down
30 changes: 17 additions & 13 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ pub fn add<N>(
generics.make_where_clause().predicates.extend(bounds);
return Ok(());
},
Some(CustomTraitBound::SkipTypeParams { type_names, .. }) =>
type_names.into_iter().collect::<Vec<_>>(),
Some(CustomTraitBound::SkipTypeParams { type_names, .. }) => {
type_names.into_iter().collect::<Vec<_>>()
},
None => Vec::new(),
};

Expand Down Expand Up @@ -189,9 +190,9 @@ fn get_types_to_add_trait_bound(
Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect())
} else {
let needs_codec_bound = |f: &syn::Field| {
!utils::is_compact(f) &&
utils::get_encoded_as_type(f).is_none() &&
!utils::should_skip(&f.attrs)
!utils::is_compact(f)
&& utils::get_encoded_as_type(f).is_none()
&& !utils::should_skip(&f.attrs)
};
let res = collect_types(data, needs_codec_bound)?
.into_iter()
Expand Down Expand Up @@ -222,9 +223,10 @@ fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Resu

let types = match *data {
Data::Struct(ref data) => match &data.fields {
| Fields::Named(FieldsNamed { named: fields, .. }) |
Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
| Fields::Named(FieldsNamed { named: fields, .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect()
},

Fields::Unit => Vec::new(),
},
Expand All @@ -234,16 +236,18 @@ fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Resu
.iter()
.filter(|variant| !utils::should_skip(&variant.attrs))
.flat_map(|variant| match &variant.fields {
| Fields::Named(FieldsNamed { named: fields, .. }) |
Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
| Fields::Named(FieldsNamed { named: fields, .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect()
},

Fields::Unit => Vec::new(),
})
.collect(),

Data::Union(ref data) =>
return Err(Error::new(data.union_token.span(), "Union types are not supported.")),
Data::Union(ref data) => {
return Err(Error::new(data.union_token.span(), "Union types are not supported."))
},
};

Ok(types)
Expand Down
115 changes: 51 additions & 64 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,53 +43,22 @@ where
pub fn check_indexes<'a, I: Iterator<Item = &'a &'a Variant>>(values: I) -> syn::Result<()> {
let mut map: HashMap<u8, Span> = HashMap::new();
for (i, v) in values.enumerate() {
if let Some(index) = find_meta_item(v.attrs.iter(), |meta| {
if let Meta::NameValue(ref nv) = meta {
if nv.path.is_ident("index") {
if let Expr::Lit(ExprLit { lit: Lit::Int(ref v), .. }) = nv.value {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte);
}
}
}
None
}) {
if let Some(span) = map.insert(index, v.span()) {
let mut error = syn::Error::new(v.span(), "Duplicate variant index. qed");
error.combine(syn::Error::new(span, "Variant index already defined here."));
return Err(error)
}
} else {
match v.discriminant.as_ref() {
Some((_, syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(lit_int), .. }))) => {
let index = lit_int
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
if let Some(span) = map.insert(index, v.span()) {
let mut error = syn::Error::new(v.span(), "Duplicate variant index. qed");
error.combine(syn::Error::new(span, "Variant index already defined here."));
return Err(error)
}
},
Some((_, _)) => return Err(syn::Error::new(v.span(), "Invalid discriminant. qed")),
None =>
if let Some(span) = map.insert(i.try_into().unwrap(), v.span()) {
let mut error =
syn::Error::new(span, "Custom variant index is duplicated later. qed");
error.combine(syn::Error::new(v.span(), "Variant index derived here."));
return Err(error)
},
}
let index = variant_index(v, i)?;
if let Some(span) = map.insert(index, v.span()) {
let mut error = syn::Error::new(
v.span(),
"scale codec error: Invalid variant index, the variant index is duplicated.",
);
error.combine(syn::Error::new(span, "Variant index used here."));
return Err(error);
}
}
Ok(())
}

/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, index: usize) -> syn::Result<TokenStream> {
pub fn variant_index(v: &Variant, index: usize) -> syn::Result<u8> {
// first look for an attribute
let codec_index = find_meta_item(v.attrs.iter(), |meta| {
if let Meta::NameValue(ref nv) = meta {
Expand All @@ -106,13 +75,27 @@ pub fn variant_index(v: &Variant, index: usize) -> syn::Result<TokenStream> {
None
});
if let Some(index) = codec_index {
Ok(quote! { #index })
Ok(index)
} else {
match v.discriminant.as_ref() {
Some((_, expr @ syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(_), .. }))) =>
Ok(quote! { #expr }),
Some((_, expr)) => Err(syn::Error::new(expr.span(), "Invalid discriminant. qed")),
None => Ok(quote! { #index }),
Some((_, syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(v), .. }))) => {
let byte = v.base10_parse::<u8>().expect(
"scale codec error: Invalid variant index, discriminant doesn't fit u8.",
);
Ok(byte)
},
Some((_, expr)) => Err(syn::Error::new(
expr.span(),
"scale codec error: Invalid discriminant, only int literal are accepted, e.g. \
`= 32`.",
)),
None => index.try_into().map_err(|_| {
syn::Error::new(
v.span(),
"scale codec error: Variant index is too large, only 256 variants are \
supported.",
)
}),
}
}
}
Expand Down Expand Up @@ -363,16 +346,17 @@ pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> {

match input.data {
Data::Struct(ref data) => match &data.fields {
| Fields::Named(FieldsNamed { named: fields, .. }) |
Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
| Fields::Named(FieldsNamed { named: fields, .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => {
for field in fields {
for attr in &field.attrs {
check_field_attribute(attr)?;
}
},
}
},
Fields::Unit => (),
},
Data::Enum(ref data) =>
Data::Enum(ref data) => {
for variant in data.variants.iter() {
for attr in &variant.attrs {
check_variant_attribute(attr)?;
Expand All @@ -382,18 +366,19 @@ pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> {
check_field_attribute(attr)?;
}
}
},
}
},
Data::Union(_) => (),
}
Ok(())
}

// Check if the attribute is `#[allow(..)]`, `#[deny(..)]`, `#[forbid(..)]` or `#[warn(..)]`.
pub fn is_lint_attribute(attr: &Attribute) -> bool {
attr.path().is_ident("allow") ||
attr.path().is_ident("deny") ||
attr.path().is_ident("forbid") ||
attr.path().is_ident("warn")
attr.path().is_ident("allow")
|| attr.path().is_ident("deny")
|| attr.path().is_ident("forbid")
|| attr.path().is_ident("warn")
}

// Ensure a field is decorated only with the following attributes:
Expand All @@ -418,10 +403,11 @@ fn check_field_attribute(attr: &Attribute) -> syn::Result<()> {
path,
value: Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. }),
..
}) if path.get_ident().map_or(false, |i| i == "encoded_as") =>
}) if path.get_ident().map_or(false, |i| i == "encoded_as") => {
TokenStream::from_str(&lit_str.value())
.map(|_| ())
.map_err(|_e| syn::Error::new(lit_str.span(), "Invalid token stream")),
.map_err(|_e| syn::Error::new(lit_str.span(), "Invalid token stream"))
},

elt => Err(syn::Error::new(elt.span(), field_error)),
}
Expand Down Expand Up @@ -468,20 +454,21 @@ fn check_top_attribute(attr: &Attribute) -> syn::Result<()> {
`#[codec(decode_bound(T: Decode))]`, \
`#[codec(decode_bound_with_mem_tracking_bound(T: DecodeWithMemTracking))]` or \
`#[codec(mel_bound(T: MaxEncodedLen))]` are accepted as top attribute";
if attr.path().is_ident("codec") &&
attr.parse_args::<CustomTraitBound<encode_bound>>().is_err() &&
attr.parse_args::<CustomTraitBound<decode_bound>>().is_err() &&
attr.parse_args::<CustomTraitBound<decode_with_mem_tracking_bound>>().is_err() &&
attr.parse_args::<CustomTraitBound<mel_bound>>().is_err() &&
codec_crate_path_inner(attr).is_none()
if attr.path().is_ident("codec")
&& attr.parse_args::<CustomTraitBound<encode_bound>>().is_err()
&& attr.parse_args::<CustomTraitBound<decode_bound>>().is_err()
&& attr.parse_args::<CustomTraitBound<decode_with_mem_tracking_bound>>().is_err()
&& attr.parse_args::<CustomTraitBound<mel_bound>>().is_err()
&& codec_crate_path_inner(attr).is_none()
{
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
if nested.len() != 1 {
return Err(syn::Error::new(attr.meta.span(), top_error));
}
match nested.first().expect("Just checked that there is one item; qed") {
Meta::Path(path) if path.get_ident().map_or(false, |i| i == "dumb_trait_bound") =>
Ok(()),
Meta::Path(path) if path.get_ident().map_or(false, |i| i == "dumb_trait_bound") => {
Ok(())
},

elt => Err(syn::Error::new(elt.span(), top_error)),
}
Expand Down
Loading

0 comments on commit add9eae

Please sign in to comment.