Skip to content

Commit

Permalink
Merge pull request fedimint#5304 from elsirion/2024-05-derive-encode-…
Browse files Browse the repository at this point in the history
…custom-idx

feat: variant index annotations for encodable enums
  • Loading branch information
dpc authored May 16, 2024
2 parents 81da7c6 + 065a45a commit 5bd3555
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 43 deletions.
39 changes: 39 additions & 0 deletions fedimint-core/src/encoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,45 @@ mod tests {
);
}

#[test]
fn test_custom_index_enum() {
#[derive(Debug, PartialEq, Eq, Encodable, Decodable)]
enum Old {
Foo,
Bar,
Baz,
}

#[derive(Debug, PartialEq, Eq, Encodable, Decodable)]
enum New {
#[encodable(index = 0)]
Foo,
#[encodable(index = 2)]
Baz,
#[encodable_default]
Default { variant: u64, bytes: Vec<u8> },
}

let test_vector = vec![
(Old::Foo, New::Foo),
(
Old::Bar,
New::Default {
variant: 1,
bytes: vec![],
},
),
(Old::Baz, New::Baz),
];

for (old, new) in test_vector {
let old_bytes = old.consensus_encode_to_vec();
let decoded_new =
New::consensus_decode_vec(old_bytes, &Default::default()).expect("Decoding failed");
assert_eq!(decoded_new, new);
}
}

fn encode_value<T: Encodable>(value: &T) -> Vec<u8> {
let mut writer = Vec::new();
value.consensus_encode(&mut writer).unwrap();
Expand Down
152 changes: 109 additions & 43 deletions fedimint-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, Index, Variant,
parse_macro_input, Attribute, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, Index,
Lit, Token, Variant,
};

fn do_not_ignore(field: &Field) -> bool {
Expand Down Expand Up @@ -51,7 +52,9 @@ fn is_default_variant_enforce_valid(variant: &Variant) -> bool {
is_default
}

#[proc_macro_derive(Encodable, attributes(encodable_ignore, encodable_default))]
// TODO: use encodable attr for everything: #[encodable(ignore)],
// #[encodable(index = 42)], …
#[proc_macro_derive(Encodable, attributes(encodable_ignore, encodable_default, encodable))]
pub fn derive_encodable(input: TokenStream) -> TokenStream {
let DeriveInput {
ident,
Expand Down Expand Up @@ -107,51 +110,117 @@ fn derive_struct_encode(fields: &Fields) -> TokenStream2 {
}
}

/// Extracts the u64 index from an attribute if it matches `#[encodable(index =
/// <u64>)]`.
fn parse_index_attribute(attributes: &[Attribute]) -> Option<u64> {
attributes
.iter()
.filter_map(|attr| {
if attr.path().is_ident("encodable") {
attr.parse_args_with(|input: syn::parse::ParseStream| {
input.parse::<syn::Ident>()?.span(); // consume the ident 'index'
input.parse::<Token![=]>()?; // consume the '='
if let Lit::Int(lit_int) = input.parse::<Lit>()? {
lit_int.base10_parse()
} else {
Err(input.error("Expected an integer for 'index'"))
}
})
.ok()
} else {
None
}
})
.next()
}

/// Processes all variants in a `Punctuated` list extracting any specified
/// index.
fn extract_variants_with_indices(input_variants: Vec<Variant>) -> Vec<(Option<u64>, Variant)> {
input_variants
.into_iter()
.map(|variant| {
let index = parse_index_attribute(&variant.attrs);
(index, variant)
})
.collect()
}

fn non_default_variant_indices(variants: &Punctuated<Variant, Comma>) -> Vec<(u64, Variant)> {
let non_default_variants = variants
.into_iter()
.filter(|variant| !is_default_variant_enforce_valid(variant))
.cloned()
.collect::<Vec<_>>();

let attr_indices = extract_variants_with_indices(non_default_variants.clone());

let all_have_index = attr_indices.iter().all(|(idx, _)| idx.is_some());
let none_have_index = attr_indices.iter().all(|(idx, _)| idx.is_none());

assert!(
all_have_index || none_have_index,
"Either all or none of the variants should have an index annotation"
);

if all_have_index {
attr_indices
.into_iter()
.map(|(idx, variant)| (idx.expect("We made sure everything has an index"), variant))
.collect()
} else {
non_default_variants
.into_iter()
.enumerate()
.map(|(idx, variant)| (idx as u64, variant))
.collect()
}
}

fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
if variants.is_empty() {
return quote! {
match *self {}
};
}

let non_default_match_arms = variants
.iter()
.filter(|variant| !is_default_variant_enforce_valid(variant))
.enumerate()
.map(|(variant_idx, variant)| {
let variant_ident = variant.ident.clone();

if is_tuple_struct(&variant.fields) {
let variant_fields = variant
.fields
.iter()
.enumerate()
.filter(|(_, f)| do_not_ignore(f))
.map(|(idx, _)| format_ident!("bound_{}", idx))
.collect::<Vec<_>>();
let variant_encode_block =
derive_enum_variant_encode_block(variant_idx, &variant_fields);
quote! {
#ident::#variant_ident(#(#variant_fields,)*) => {
#variant_encode_block
let non_default_match_arms =
non_default_variant_indices(variants)
.into_iter()
.map(|(variant_idx, variant)| {
let variant_ident = variant.ident.clone();

if is_tuple_struct(&variant.fields) {
let variant_fields = variant
.fields
.iter()
.enumerate()
.filter(|(_, f)| do_not_ignore(f))
.map(|(idx, _)| format_ident!("bound_{}", idx))
.collect::<Vec<_>>();
let variant_encode_block =
derive_enum_variant_encode_block(variant_idx, &variant_fields);
quote! {
#ident::#variant_ident(#(#variant_fields,)*) => {
#variant_encode_block
}
}
}
} else {
let variant_fields = variant
.fields
.iter()
.filter(|f| do_not_ignore(f))
.map(|field| field.ident.clone().unwrap())
.collect::<Vec<_>>();
let variant_encode_block =
derive_enum_variant_encode_block(variant_idx, &variant_fields);
quote! {
#ident::#variant_ident { #(#variant_fields,)*} => {
#variant_encode_block
} else {
let variant_fields = variant
.fields
.iter()
.filter(|f| do_not_ignore(f))
.map(|field| field.ident.clone().unwrap())
.collect::<Vec<_>>();
let variant_encode_block =
derive_enum_variant_encode_block(variant_idx, &variant_fields);
quote! {
#ident::#variant_ident { #(#variant_fields,)*} => {
#variant_encode_block
}
}
}
}
});
});

let default_match_arm = variants
.iter()
Expand All @@ -176,9 +245,9 @@ fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> T
}
}

fn derive_enum_variant_encode_block(idx: usize, fields: &[Ident]) -> TokenStream2 {
fn derive_enum_variant_encode_block(idx: u64, fields: &[Ident]) -> TokenStream2 {
quote! {
len += ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx as u64), writer)?;
len += ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx), writer)?;

let mut bytes = Vec::<u8>::new();
#(::fedimint_core::encoding::Encodable::consensus_encode(#fields, &mut bytes)?;)*
Expand Down Expand Up @@ -234,11 +303,8 @@ fn derive_enum_decode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> T
};
}

let non_default_match_arms = variants.iter()
.filter(|variant| !is_default_variant_enforce_valid(variant))
.enumerate()
let non_default_match_arms = non_default_variant_indices(variants).into_iter()
.map(|(variant_idx, variant)| {
let variant_idx = variant_idx as u64;
let variant_ident = variant.ident.clone();
let decode_block = derive_tuple_or_named_decode_block(
quote! { #ident::#variant_ident },
Expand Down

0 comments on commit 5bd3555

Please sign in to comment.