From e09ffa80556a9a113a10f8800103ef8e2542f3dd Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 12 Aug 2024 13:14:45 -0700 Subject: [PATCH] Improve codegen for enums with many cases This commit improves the compile time of generating bindings for enums with many cases in them (e.g. 1000+). This is done by optimizing for enums specifically rather than handling them generically like other variants which can reduce the amount of code going into rustc to O(1) instead of O(N) with the number of cases. This in turn can greatly reduce compile time. The tradeoff made in this commit is that enums are now required to have `#[repr(...)]` annotations along with no Rust-level discriminants specified. This enables the use of a `transmute` to lift a discriminant into Rust with a simple bounds check. Previously this was one large `match` statement. Closes #9081 --- crates/component-macro/src/component.rs | 369 ++++++++++++++---- .../tests/expanded/simple-wasi.rs | 1 + .../tests/expanded/simple-wasi_async.rs | 1 + .../tests/expanded/small-anonymous.rs | 2 + .../tests/expanded/small-anonymous_async.rs | 2 + .../tests/expanded/variants.rs | 4 + .../tests/expanded/variants_async.rs | 4 + crates/environ/src/component/types.rs | 18 + crates/wasmtime/src/runtime/component/mod.rs | 1 + crates/wit-bindgen/src/lib.rs | 7 + tests/all/component_model/macros.rs | 3 + 11 files changed, 334 insertions(+), 78 deletions(-) diff --git a/crates/component-macro/src/component.rs b/crates/component-macro/src/component.rs index 89a5314ad4fb..930e84165c25 100644 --- a/crates/component-macro/src/component.rs +++ b/crates/component-macro/src/component.rs @@ -16,26 +16,22 @@ mod kw { } #[derive(Debug, Copy, Clone)] -pub enum VariantStyle { - Variant, +enum Style { + Record, Enum, + Variant, } -impl fmt::Display for VariantStyle { +impl fmt::Display for Style { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Self::Variant => "variant", - Self::Enum => "enum", - }) + match self { + Style::Record => f.write_str("record"), + Style::Enum => f.write_str("enum"), + Style::Variant => f.write_str("variant"), + } } } -#[derive(Debug, Copy, Clone)] -enum Style { - Record, - Variant(VariantStyle), -} - #[derive(Debug, Clone)] enum ComponentAttr { Style(Style), @@ -50,10 +46,10 @@ impl Parse for ComponentAttr { Ok(ComponentAttr::Style(Style::Record)) } else if lookahead.peek(kw::variant) { input.parse::()?; - Ok(ComponentAttr::Style(Style::Variant(VariantStyle::Variant))) + Ok(ComponentAttr::Style(Style::Variant)) } else if lookahead.peek(Token![enum]) { input.parse::()?; - Ok(ComponentAttr::Style(Style::Variant(VariantStyle::Enum))) + Ok(ComponentAttr::Style(Style::Enum)) } else if lookahead.peek(kw::wasmtime_crate) { input.parse::()?; input.parse::()?; @@ -126,7 +122,14 @@ pub trait Expander { generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], - style: VariantStyle, + wasmtime_crate: &syn::Path, + ) -> Result; + + fn expand_enum( + &self, + name: &syn::Ident, + discriminant_size: DiscriminantSize, + cases: &[VariantCase], wasmtime_crate: &syn::Path, ) -> Result; } @@ -157,7 +160,7 @@ pub fn expand(expander: &dyn Expander, input: &DeriveInput) -> Result expand_record(expander, input, &wasmtime_crate), - Style::Variant(style) => expand_variant(expander, input, style, &wasmtime_crate), + Style::Enum | Style::Variant => expand_variant(expander, input, style, &wasmtime_crate), } } @@ -199,7 +202,7 @@ fn expand_record( fn expand_variant( expander: &dyn Expander, input: &DeriveInput, - style: VariantStyle, + style: Style, wasmtime_crate: &syn::Path, ) -> Result { let name = &input.ident; @@ -253,8 +256,9 @@ fn expand_variant( containing variants with {}", style, match style { - VariantStyle::Variant => "at most one unnamed field each", - VariantStyle::Enum => "no fields", + Style::Variant => "at most one unnamed field each", + Style::Enum => "no fields", + Style::Record => unreachable!(), } ), )) @@ -265,14 +269,77 @@ fn expand_variant( ) .collect::>>()?; - expander.expand_variant( - &input.ident, - &input.generics, - discriminant_size, - &cases, - style, - wasmtime_crate, - ) + match style { + Style::Variant => expander.expand_variant( + &input.ident, + &input.generics, + discriminant_size, + &cases, + wasmtime_crate, + ), + Style::Enum => { + validate_enum(input, &body, discriminant_size)?; + expander.expand_enum(&input.ident, discriminant_size, &cases, wasmtime_crate) + } + Style::Record => unreachable!(), + } +} + +/// Validates component model `enum` definitions are accompanied with +/// appropriate `#[repr]` tags. Additionally requires that no discriminants are +/// listed to ensure that unsafe transmutes in lift are valid. +fn validate_enum(input: &DeriveInput, body: &syn::DataEnum, size: DiscriminantSize) -> Result<()> { + if !input.generics.params.is_empty() { + return Err(Error::new_spanned( + &input.generics.params, + "cannot have generics on an `enum`", + )); + } + if let Some(clause) = &input.generics.where_clause { + return Err(Error::new_spanned( + clause, + "cannot have a where clause on an `enum`", + )); + } + let expected_discr = match size { + DiscriminantSize::Size1 => "u8", + DiscriminantSize::Size2 => "u16", + DiscriminantSize::Size4 => "u32", + }; + let mut found_repr = false; + for attr in input.attrs.iter() { + if !attr.meta.path().is_ident("repr") { + continue; + } + let list = attr.meta.require_list()?; + found_repr = true; + if list.tokens.to_string() != expected_discr { + return Err(Error::new_spanned( + &list.tokens, + format!( + "expected `repr({expected_discr})`, found `repr({})`", + list.tokens + ), + )); + } + } + if !found_repr { + return Err(Error::new_spanned( + &body.enum_token, + format!("missing required `#[repr({expected_discr})]`"), + )); + } + + for case in body.variants.iter() { + if let Some((_, expr)) = &case.discriminant { + return Err(Error::new_spanned( + expr, + "cannot have an explicit discriminant", + )); + } + } + + Ok(()) } fn expand_record_for_component_type( @@ -452,7 +519,6 @@ impl Expander for LiftExpander { generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], - style: VariantStyle, wt: &syn::Path, ) -> Result { let internal = quote!(#wt::component::__internal); @@ -460,23 +526,13 @@ impl Expander for LiftExpander { let mut lifts = TokenStream::new(); let mut loads = TokenStream::new(); - let interface_type_variant = match style { - VariantStyle::Variant => quote!(Variant), - VariantStyle::Enum => quote!(Enum), - }; - for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() { let index_u32 = u32::try_from(index).unwrap(); let index_quoted = quote(discriminant_size, index); if let Some(ty) = ty { - let payload_ty = match style { - VariantStyle::Variant => { - quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info)) - } - VariantStyle::Enum => unreachable!(), - }; + let payload_ty = quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info)); lifts.extend( quote!(#index_u32 => Self::#ident(<#ty as #wt::component::Lift>::lift( cx, #payload_ty, unsafe { &src.payload.#ident } @@ -506,7 +562,7 @@ impl Expander for LiftExpander { let extract_ty = quote! { let ty = match ty { - #internal::InterfaceType::#interface_type_variant(i) => &cx.types[i], + #internal::InterfaceType::Variant(i) => &cx.types[i], _ => #internal::bad_type_info(), }; }; @@ -548,6 +604,75 @@ impl Expander for LiftExpander { Ok(expanded) } + + fn expand_enum( + &self, + name: &syn::Ident, + discriminant_size: DiscriminantSize, + cases: &[VariantCase], + wt: &syn::Path, + ) -> Result { + let internal = quote!(#wt::component::__internal); + + let (from_bytes, discrim_ty) = match discriminant_size { + DiscriminantSize::Size1 => (quote!(bytes[0]), quote!(u8)), + DiscriminantSize::Size2 => ( + quote!(u16::from_le_bytes(bytes[0..2].try_into()?)), + quote!(u16), + ), + DiscriminantSize::Size4 => ( + quote!(u32::from_le_bytes(bytes[0..4].try_into()?)), + quote!(u32), + ), + }; + let discrim_limit = proc_macro2::Literal::usize_unsuffixed(cases.len()); + + let extract_ty = quote! { + let ty = match ty { + #internal::InterfaceType::Variant(i) => &cx.types[i], + _ => #internal::bad_type_info(), + }; + }; + + let expanded = quote! { + unsafe impl #wt::component::Lift for #name { + #[inline] + fn lift( + cx: &mut #internal::LiftContext<'_>, + ty: #internal::InterfaceType, + src: &Self::Lower, + ) -> #internal::anyhow::Result { + #extract_ty + let discrim = src.tag.get_u32(); + if discrim >= #discrim_limit { + #internal::anyhow::bail!("unexpected discriminant: {discrim}"); + } + Ok(unsafe { + #internal::transmute::<#discrim_ty, #name>(discrim as #discrim_ty) + }) + } + + #[inline] + fn load( + cx: &mut #internal::LiftContext<'_>, + ty: #internal::InterfaceType, + bytes: &[u8], + ) -> #internal::anyhow::Result { + let align = ::ALIGN32; + debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0); + let discrim = #from_bytes; + if discrim >= #discrim_limit { + #internal::anyhow::bail!("unexpected discriminant: {discrim}"); + } + Ok(unsafe { + #internal::transmute::<#discrim_ty, #name>(discrim) + }) + } + } + }; + + Ok(expanded) + } } pub struct LowerExpander; @@ -627,7 +752,6 @@ impl Expander for LowerExpander { generics: &syn::Generics, discriminant_size: DiscriminantSize, cases: &[VariantCase], - style: VariantStyle, wt: &syn::Path, ) -> Result { let internal = quote!(#wt::component::__internal); @@ -635,11 +759,6 @@ impl Expander for LowerExpander { let mut lowers = TokenStream::new(); let mut stores = TokenStream::new(); - let interface_type_variant = match style { - VariantStyle::Variant => quote!(Variant), - VariantStyle::Enum => quote!(Enum), - }; - for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() { let index_u32 = u32::try_from(index).unwrap(); @@ -652,12 +771,7 @@ impl Expander for LowerExpander { let store; if ty.is_some() { - let ty = match style { - VariantStyle::Variant => { - quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info)) - } - VariantStyle::Enum => unreachable!(), - }; + let ty = quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info)); pattern = quote!(Self::#ident(value)); lower = quote!(value.lower(cx, #ty, dst)); store = quote!(value.store( @@ -693,7 +807,7 @@ impl Expander for LowerExpander { let extract_ty = quote! { let ty = match ty { - #internal::InterfaceType::#interface_type_variant(i) => &cx.types[i], + #internal::InterfaceType::Variant(i) => &cx.types[i], _ => #internal::bad_type_info(), }; }; @@ -731,6 +845,63 @@ impl Expander for LowerExpander { Ok(expanded) } + + fn expand_enum( + &self, + name: &syn::Ident, + discriminant_size: DiscriminantSize, + _cases: &[VariantCase], + wt: &syn::Path, + ) -> Result { + let internal = quote!(#wt::component::__internal); + + let extract_ty = quote! { + let ty = match ty { + #internal::InterfaceType::Enum(i) => &cx.types[i], + _ => #internal::bad_type_info(), + }; + }; + + let (size, ty) = match discriminant_size { + DiscriminantSize::Size1 => (1, quote!(u8)), + DiscriminantSize::Size2 => (2, quote!(u16)), + DiscriminantSize::Size4 => (4, quote!(u32)), + }; + let size = proc_macro2::Literal::usize_unsuffixed(size); + + let expanded = quote! { + unsafe impl #wt::component::Lower for #name { + #[inline] + fn lower( + &self, + cx: &mut #internal::LowerContext<'_, T>, + ty: #internal::InterfaceType, + dst: &mut core::mem::MaybeUninit, + ) -> #internal::anyhow::Result<()> { + #extract_ty + #internal::map_maybe_uninit!(dst.tag) + .write(#wt::ValRaw::u32(*self as u32)); + Ok(()) + } + + #[inline] + fn store( + &self, + cx: &mut #internal::LowerContext<'_, T>, + ty: #internal::InterfaceType, + mut offset: usize + ) -> #internal::anyhow::Result<()> { + #extract_ty + debug_assert!(offset % (::ALIGN32 as usize) == 0); + let discrim = *self as #ty; + *cx.get::<#size>(offset) = discrim.to_le_bytes(); + Ok(()) + } + } + }; + + Ok(expanded) + } } pub struct ComponentTypeExpander; @@ -773,7 +944,6 @@ impl Expander for ComponentTypeExpander { generics: &syn::Generics, _discriminant_size: DiscriminantSize, cases: &[VariantCase], - style: VariantStyle, wt: &syn::Path, ) -> Result { let internal = quote!(#wt::component::__internal); @@ -794,17 +964,9 @@ impl Expander for ComponentTypeExpander { if let Some(ty) = ty { abi_list.extend(quote!(Some(<#ty as #wt::component::ComponentType>::ABI),)); - case_names_and_checks.extend(match style { - VariantStyle::Variant => { - quote!((#name, Some(<#ty as #wt::component::ComponentType>::typecheck)),) - } - VariantStyle::Enum => { - return Err(Error::new( - ident.span(), - "payloads are not permitted for `enum` cases", - )) - } - }); + case_names_and_checks.extend( + quote!((#name, Some(<#ty as #wt::component::ComponentType>::typecheck)),), + ); let generic = format_ident!("T{}", index); @@ -816,21 +978,11 @@ impl Expander for ComponentTypeExpander { unique_types.insert(ty); } else { abi_list.extend(quote!(None,)); - case_names_and_checks.extend(match style { - VariantStyle::Variant => { - quote!((#name, None),) - } - VariantStyle::Enum => quote!(#name,), - }); + case_names_and_checks.extend(quote!((#name, None),)); lower_payload_case_declarations.extend(quote!(#ident: [#wt::ValRaw; 0],)); } } - let typecheck = match style { - VariantStyle::Variant => quote!(typecheck_variant), - VariantStyle::Enum => quote!(typecheck_enum), - }; - let generics = add_trait_bounds(generics, parse_quote!(#wt::component::ComponentType)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let lower = format_ident!("Lower{}", name); @@ -869,7 +1021,7 @@ impl Expander for ComponentTypeExpander { ty: &#internal::InterfaceType, types: &#internal::InstanceType<'_>, ) -> #internal::anyhow::Result<()> { - #internal::#typecheck(ty, types, &[#case_names_and_checks]) + #internal::typecheck_variant(ty, types, &[#case_names_and_checks]) } const ABI: #internal::CanonicalAbiInfo = @@ -883,6 +1035,67 @@ impl Expander for ComponentTypeExpander { Ok(quote!(const _: () = { #expanded };)) } + + fn expand_enum( + &self, + name: &syn::Ident, + _discriminant_size: DiscriminantSize, + cases: &[VariantCase], + wt: &syn::Path, + ) -> Result { + let internal = quote!(#wt::component::__internal); + + let mut case_names = TokenStream::new(); + let mut abi_list = TokenStream::new(); + + for VariantCase { attrs, ident, ty } in cases.iter() { + let rename = find_rename(attrs)?; + + let name = rename.unwrap_or_else(|| syn::LitStr::new(&ident.to_string(), ident.span())); + + if ty.is_some() { + return Err(Error::new( + ident.span(), + "payloads are not permitted for `enum` cases", + )); + } + abi_list.extend(quote!(None,)); + case_names.extend(quote!(#name,)); + } + + let lower = format_ident!("Lower{}", name); + + let cases_len = cases.len(); + let expanded = quote! { + #[doc(hidden)] + #[derive(Clone, Copy)] + #[repr(C)] + pub struct #lower { + tag: #wt::ValRaw, + } + + unsafe impl #wt::component::ComponentType for #name { + type Lower = #lower; + + #[inline] + fn typecheck( + ty: &#internal::InterfaceType, + types: &#internal::InstanceType<'_>, + ) -> #internal::anyhow::Result<()> { + #internal::typecheck_enum(ty, types, &[#case_names]) + } + + const ABI: #internal::CanonicalAbiInfo = + #internal::CanonicalAbiInfo::enum_(#cases_len); + } + + unsafe impl #internal::ComponentVariant for #name { + const CASES: &'static [Option<#internal::CanonicalAbiInfo>] = &[#abi_list]; + } + }; + + Ok(quote!(const _: () = { #expanded };)) + } } #[derive(Debug)] diff --git a/crates/component-macro/tests/expanded/simple-wasi.rs b/crates/component-macro/tests/expanded/simple-wasi.rs index b35cd47d4c64..247346794bea 100644 --- a/crates/component-macro/tests/expanded/simple-wasi.rs +++ b/crates/component-macro/tests/expanded/simple-wasi.rs @@ -114,6 +114,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum Errno { #[component(name = "e")] E, diff --git a/crates/component-macro/tests/expanded/simple-wasi_async.rs b/crates/component-macro/tests/expanded/simple-wasi_async.rs index 0d8706e7ffc1..21887ae6be21 100644 --- a/crates/component-macro/tests/expanded/simple-wasi_async.rs +++ b/crates/component-macro/tests/expanded/simple-wasi_async.rs @@ -121,6 +121,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum Errno { #[component(name = "e")] E, diff --git a/crates/component-macro/tests/expanded/small-anonymous.rs b/crates/component-macro/tests/expanded/small-anonymous.rs index 309b6b999325..be5d84153b66 100644 --- a/crates/component-macro/tests/expanded/small-anonymous.rs +++ b/crates/component-macro/tests/expanded/small-anonymous.rs @@ -105,6 +105,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum Error { #[component(name = "success")] Success, @@ -207,6 +208,7 @@ pub mod exports { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum Error { #[component(name = "success")] Success, diff --git a/crates/component-macro/tests/expanded/small-anonymous_async.rs b/crates/component-macro/tests/expanded/small-anonymous_async.rs index 2e4d950ed208..a53f57452821 100644 --- a/crates/component-macro/tests/expanded/small-anonymous_async.rs +++ b/crates/component-macro/tests/expanded/small-anonymous_async.rs @@ -112,6 +112,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum Error { #[component(name = "success")] Success, @@ -220,6 +221,7 @@ pub mod exports { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum Error { #[component(name = "success")] Success, diff --git a/crates/component-macro/tests/expanded/variants.rs b/crates/component-macro/tests/expanded/variants.rs index 6ef1217a7943..0b5a684aeeda 100644 --- a/crates/component-macro/tests/expanded/variants.rs +++ b/crates/component-macro/tests/expanded/variants.rs @@ -105,6 +105,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum E1 { #[component(name = "a")] A, @@ -313,6 +314,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum MyErrno { #[component(name = "bad1")] Bad1, @@ -846,6 +848,7 @@ pub mod exports { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum E1 { #[component(name = "a")] A, @@ -1109,6 +1112,7 @@ pub mod exports { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum MyErrno { #[component(name = "bad1")] Bad1, diff --git a/crates/component-macro/tests/expanded/variants_async.rs b/crates/component-macro/tests/expanded/variants_async.rs index a5b16f1884c7..eada61a58e23 100644 --- a/crates/component-macro/tests/expanded/variants_async.rs +++ b/crates/component-macro/tests/expanded/variants_async.rs @@ -112,6 +112,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum E1 { #[component(name = "a")] A, @@ -320,6 +321,7 @@ pub mod foo { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum MyErrno { #[component(name = "bad1")] Bad1, @@ -862,6 +864,7 @@ pub mod exports { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum E1 { #[component(name = "a")] A, @@ -1125,6 +1128,7 @@ pub mod exports { #[derive(wasmtime::component::Lower)] #[component(enum)] #[derive(Clone, Copy, Eq, PartialEq)] + #[repr(u8)] pub enum MyErrno { #[component(name = "bad1")] Bad1, diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index 998075840e0e..5e70bd6bec48 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -735,6 +735,24 @@ impl CanonicalAbiInfo { } } + /// Calculates ABI information for an enum with `cases` cases. + pub const fn enum_(cases: usize) -> CanonicalAbiInfo { + // NB: this is basically a duplicate definition of + // `CanonicalAbiInfo::variant`, these should be kept in sync. + + let discrim_size = match DiscriminantSize::from_count(cases) { + Some(size) => size.byte_size(), + None => unreachable!(), + }; + CanonicalAbiInfo { + size32: discrim_size, + align32: discrim_size, + size64: discrim_size, + align64: discrim_size, + flat_count: Some(1), + } + } + /// Returns the flat count of this ABI information so long as the count /// doesn't exceed the `max` specified. pub fn flat_count(&self, max: usize) -> Option { diff --git a/crates/wasmtime/src/runtime/component/mod.rs b/crates/wasmtime/src/runtime/component/mod.rs index fa8bb2fd5cd2..bc466ea58912 100644 --- a/crates/wasmtime/src/runtime/component/mod.rs +++ b/crates/wasmtime/src/runtime/component/mod.rs @@ -143,6 +143,7 @@ pub mod __internal { pub use anyhow; #[cfg(feature = "async")] pub use async_trait::async_trait; + pub use core::mem::transmute; pub use wasmtime_environ; pub use wasmtime_environ::component::{CanonicalAbiInfo, ComponentTypes, InterfaceType}; } diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 887817a6b877..fe91fc17e13d 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -1930,6 +1930,13 @@ impl<'a> InterfaceGenerator<'a> { self.push_str(&derives.into_iter().collect::>().join(", ")); self.push_str(")]\n"); + let repr = match enum_.cases.len().ilog2() { + 0..8 => "u8", + 8..16 => "u16", + _ => "u32", + }; + uwriteln!(self.src, "#[repr({repr})]"); + self.push_str(&format!("pub enum {name} {{\n")); for case in enum_.cases.iter() { self.rustdoc(&case.docs); diff --git a/tests/all/component_model/macros.rs b/tests/all/component_model/macros.rs index a7b9726ca11c..9e7dc50386f9 100644 --- a/tests/all/component_model/macros.rs +++ b/tests/all/component_model/macros.rs @@ -236,6 +236,7 @@ fn variant_derive() -> Result<()> { fn enum_derive() -> Result<()> { #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] #[component(enum)] + #[repr(u8)] enum Foo { #[component(name = "foo-bar-baz")] A, @@ -299,6 +300,7 @@ fn enum_derive() -> Result<()> { #[add_variants(257)] #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] #[component(enum)] + #[repr(u16)] enum Many {} let component = Component::new( @@ -330,6 +332,7 @@ fn enum_derive() -> Result<()> { // #[add_variants(65537)] // #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] // #[component(enum)] + // #[repr(u32)] // enum ManyMore {} Ok(())