From d10fbfc6ffb9a2adf2faa001cd614bd7dfe20f02 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 6 Sep 2023 17:37:07 +0200 Subject: [PATCH] allow deriving `CheckedBitPattern` for enums with fields (#171) * simplify `ToTokens` impl for `Representation` Instead of collecting the representation and modifier into `Option`s and determining whether a comma is needed manually, we can use the `Puncutuated` struct which handles commas automatically. This will also make emitting the `align` modifier in the future easier. * emit alignment modifier This is required for correctly implementing `CheckedBitPattern` because we need the layout of the type and its `Bits` type to have the same layout. * add unit test for `#[repr]` parsing * allow multiple alignment modifiers According to RFC #1358, if multiple alignment modifiers are specified, the resulting alignment is the maximum of all alignment modifiers. * actually return the error we just created * factor out the integer Repr's into their own type This is a preparation step for adding support for `#[repr(C, int)]`. * allow parsing `#[repr(C, int)]` This can be used on enums with fields. * derive `CheckedBitPattern` for enums with fields The implementation mostly mirrors the desugaring described at https://doc.rust-lang.org/reference/type-layout.html * add comments and rename some idents * update error message * update docs for `CheckedBitPattern` derive * add new nested test case, change generated type naming scheme * fix wrong comment * small nit --------- Co-authored-by: Gray Olson --- derive/src/lib.rs | 13 +- derive/src/traits.rs | 611 ++++++++++++++++++++++++++++++++++-------- derive/tests/basic.rs | 196 +++++++++++++- 3 files changed, 706 insertions(+), 114 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 7ca9527..20f4b0f 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -218,7 +218,7 @@ pub fn derive_zeroable( /// - The struct must contain no generic parameters /// /// If applied to an enum: -/// - The enum must be explicit `#[repr(Int)]` +/// - The enum must be explicit `#[repr(Int)]`, `#[repr(C)]`, or both /// - All variants must be fieldless /// - The enum must contain no generic parameters #[proc_macro_derive(NoUninit)] @@ -237,16 +237,17 @@ pub fn derive_no_uninit( /// for the `CheckedBitPattern` trait and derives the required `Bits` type /// definition and `is_valid_bit_pattern` method for the type automatically. /// -/// The following constraints need to be satisfied for the macro to succeed -/// (the rest of the constraints are guaranteed by the `CheckedBitPattern` -/// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit` -/// trait which `CheckedBitPattern` is a subtrait of): +/// The following constraints need to be satisfied for the macro to succeed: /// /// If applied to a struct: /// - All fields must implement `CheckedBitPattern` +/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]` +/// - The struct must contain no generic parameters /// /// If applied to an enum: -/// - All requirements already checked by `NoUninit`, just impls the trait +/// - The enum must be explicit `#[repr(Int)]` +/// - All fields in variants must implement `CheckedBitPattern` +/// - The enum must contain no generic parameters #[proc_macro_derive(CheckedBitPattern)] pub fn derive_maybe_pod( input: proc_macro::TokenStream, diff --git a/derive/src/traits.rs b/derive/src/traits.rs index f513356..c0ddf61 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -1,4 +1,6 @@ #![allow(unused_imports)] +use std::{cmp, convert::TryFrom}; + use proc_macro2::{Ident, Span, TokenStream, TokenTree}; use quote::{quote, quote_spanned, ToTokens}; use syn::{ @@ -204,11 +206,19 @@ impl Derivable for CheckedBitPattern { Repr::C | Repr::Transparent => Ok(()), _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"), }, - Data::Enum(_) => if repr.repr.is_integer() { - Ok(()) - } else { - bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]") - }, + Data::Enum(DataEnum { variants,.. }) => { + if !enum_has_fields(variants.iter()){ + if repr.repr.is_integer() { + Ok(()) + } else { + bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]") + } + } else if matches!(repr.repr, Repr::Rust) { + bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout") + } else { + Ok(()) + } + } Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs") } } @@ -235,7 +245,9 @@ impl Derivable for CheckedBitPattern { Data::Struct(DataStruct { fields, .. }) => { generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs) } - Data::Enum(_) => generate_checked_bit_pattern_enum(input), + Data::Enum(DataEnum { variants, .. }) => { + generate_checked_bit_pattern_enum(input, variants) + } Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } @@ -347,13 +359,20 @@ impl Derivable for Contiguous { fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { let repr = get_repr(&input.attrs)?; - let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() { + let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() { integer_ty } else { bail!("Contiguous requires the enum to be #[repr(Int)]"); }; let variants = get_enum_variants(input)?; + if enum_has_fields(variants.clone()) { + return Err(Error::new_spanned( + &input, + "Only fieldless enums are supported", + )); + } + let mut variants_with_discriminator = VariantDiscriminantIterator::new(variants); @@ -426,7 +445,7 @@ fn get_fields(input: &DeriveInput) -> Result { fn get_enum_variants<'a>( input: &'a DeriveInput, -) -> Result + 'a> { +) -> Result + Clone + 'a> { if let Data::Enum(DataEnum { variants, .. }) = &input.data { Ok(variants.iter()) } else { @@ -486,11 +505,21 @@ fn generate_checked_bit_pattern_struct( } fn generate_checked_bit_pattern_enum( - input: &DeriveInput, + input: &DeriveInput, variants: &Punctuated, +) -> Result<(TokenStream, TokenStream)> { + if enum_has_fields(variants.iter()) { + generate_checked_bit_pattern_enum_with_fields(input, variants) + } else { + generate_checked_bit_pattern_enum_without_fields(input, variants) + } +} + +fn generate_checked_bit_pattern_enum_without_fields( + input: &DeriveInput, variants: &Punctuated, ) -> Result<(TokenStream, TokenStream)> { let span = input.span(); let mut variants_with_discriminant = - VariantDiscriminantIterator::new(get_enum_variants(input)?); + VariantDiscriminantIterator::new(variants.iter()); let (min, max, count) = variants_with_discriminant.try_fold( (i64::max_value(), i64::min_value(), 0), @@ -514,13 +543,12 @@ fn generate_checked_bit_pattern_enum( quote!(*bits >= #min_lit && *bits <= #max_lit) } else { // not contiguous range, check for each - let variant_lits = - VariantDiscriminantIterator::new(get_enum_variants(input)?) - .map(|res| { - let variant = res?; - Ok(LitInt::new(&format!("{}", variant), span)) - }) - .collect::>>()?; + let variant_lits = VariantDiscriminantIterator::new(variants.iter()) + .map(|res| { + let variant = res?; + Ok(LitInt::new(&format!("{}", variant), span)) + }) + .collect::>>()?; // count is at least 1 let first = &variant_lits[0]; @@ -530,11 +558,11 @@ fn generate_checked_bit_pattern_enum( }; let repr = get_repr(&input.attrs)?; - let integer_ty = repr.repr.as_integer_type().unwrap(); // should be checked in attr check already + let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already Ok(( quote!(), quote! { - type Bits = #integer_ty; + type Bits = #integer; #[inline] #[allow(clippy::double_comparisons)] @@ -545,6 +573,244 @@ fn generate_checked_bit_pattern_enum( )) } +fn generate_checked_bit_pattern_enum_with_fields( + input: &DeriveInput, variants: &Punctuated, +) -> Result<(TokenStream, TokenStream)> { + let representation = get_repr(&input.attrs)?; + let vis = &input.vis; + + let derive_dbg = + quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); + + match representation.repr { + Repr::Rust => unreachable!(), + repr @ (Repr::C | Repr::CWithDiscriminant(_)) => { + let integer = match repr { + Repr::C => quote!(::core::ffi::c_int), + Repr::CWithDiscriminant(integer) => quote!(#integer), + _ => unreachable!(), + }; + let input_ident = &input.ident; + + let bits_repr = Representation { repr: Repr::C, ..representation }; + + // the enum manually re-configured as the actual tagged union it represents, + // thus circumventing the requirements rust imposes on the tag even when using + // #[repr(C)] enum layout + // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields + let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span()); + + // the variants union part of the tagged union. These get put into a union which gets the + // AnyBitPattern derive applied to it, thus checking that the fields of the union obey the requriements of AnyBitPattern. + // The types that actually go in the union are one more level of indirection deep: we generate new structs for each variant + // (`variant_struct_definitions`) which themselves have the `CheckedBitPattern` derive applied, thus generating `{variant_struct_ident}Bits` + // structs, which are the ones that go into this union. + let variants_union_ident = + Ident::new(&format!("{}Variants", input.ident), input.span()); + + let variant_struct_idents = variants + .iter() + .map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())); + + let variant_struct_definitions = + variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| { + let fields = v.fields.iter().map(|v| &v.ty); + + quote! { + #[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)] + #[repr(C)] + #vis struct #variant_struct_ident(#(#fields),*); + } + }); + + let union_fields = + variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| { + let variant_struct_bits_ident = + Ident::new(&format!("{variant_struct_ident}Bits"), input.span()); + let field_ident = &v.ident; + quote! { + #field_ident: #variant_struct_bits_ident + } + }); + + let variant_checks = variant_struct_idents + .clone() + .zip(VariantDiscriminantIterator::new(variants.iter())) + .zip(variants.iter()) + .map(|((variant_struct_ident, discriminant), v)| -> Result<_> { + let discriminant = discriminant?; + let discriminant = LitInt::new(&discriminant.to_string(), v.span()); + let ident = &v.ident; + Ok(quote! { + #discriminant => { + let payload = unsafe { &bits.payload.#ident }; + <#variant_struct_ident as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(payload) + } + }) + }) + .collect::>>()?; + + Ok(( + quote! { + #[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)] + #derive_dbg + #bits_repr + #vis struct #bits_ty_ident { + tag: #integer, + payload: #variants_union_ident, + } + + #[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)] + #[repr(C)] + #[allow(non_snake_case)] + #vis union #variants_union_ident { + #(#union_fields,)* + } + + #[cfg(not(target_arch = "spirv"))] + impl ::core::fmt::Debug for #variants_union_ident { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident)); + ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct) + } + } + + #(#variant_struct_definitions)* + }, + quote! { + type Bits = #bits_ty_ident; + + #[inline] + #[allow(clippy::double_comparisons)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + match bits.tag { + #(#variant_checks)* + _ => false, + } + } + }, + )) + } + Repr::Transparent => { + if variants.len() != 1 { + bail!("enums with more than one variant cannot be transparent") + } + + let variant = &variants[0]; + + let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span()); + let fields = variant.fields.iter().map(|v| &v.ty); + + Ok(( + quote! { + #[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)] + #[repr(C)] + #vis struct #bits_ty(#(#fields),*); + }, + quote! { + type Bits = <#bits_ty as ::bytemuck::CheckedBitPattern>::Bits; + + #[inline] + #[allow(clippy::double_comparisons)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + <#bits_ty as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(bits) + } + }, + )) + } + Repr::Integer(integer) => { + let bits_repr = Representation { repr: Repr::C, ..representation }; + let input_ident = &input.ident; + + // the enum manually re-configured as the union it represents. such a union is the union of variants + // as a repr(c) struct with the discriminator type inserted at the beginning. + // in our case we union the `Bits` representation of each variant rather than the variant itself, which we generate + // via a nested `CheckedBitPattern` derive on the `variant_struct_definitions` generated below. + // + // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields + let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span()); + + let variant_struct_idents = variants + .iter() + .map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())); + + let variant_struct_definitions = + variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| { + let fields = v.fields.iter().map(|v| &v.ty); + + // adding the discriminant repr integer as first field, as described above + quote! { + #[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::CheckedBitPattern)] + #[repr(C)] + #vis struct #variant_struct_ident(#integer, #(#fields),*); + } + }); + + let union_fields = + variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| { + let variant_struct_bits_ident = + Ident::new(&format!("{variant_struct_ident}Bits"), input.span()); + let field_ident = &v.ident; + quote! { + #field_ident: #variant_struct_bits_ident + } + }); + + let variant_checks = variant_struct_idents + .clone() + .zip(VariantDiscriminantIterator::new(variants.iter())) + .zip(variants.iter()) + .map(|((variant_struct_ident, discriminant), v)| -> Result<_> { + let discriminant = discriminant?; + let discriminant = LitInt::new(&discriminant.to_string(), v.span()); + let ident = &v.ident; + Ok(quote! { + #discriminant => { + let payload = unsafe { &bits.#ident }; + <#variant_struct_ident as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(payload) + } + }) + }) + .collect::>>()?; + + Ok(( + quote! { + #[derive(::core::clone::Clone, ::core::marker::Copy, ::bytemuck::AnyBitPattern)] + #bits_repr + #[allow(non_snake_case)] + #vis union #bits_ty_ident { + __tag: #integer, + #(#union_fields,)* + } + + #[cfg(not(target_arch = "spirv"))] + impl ::core::fmt::Debug for #bits_ty_ident { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident)); + ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag }); + ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct) + } + } + + #(#variant_struct_definitions)* + }, + quote! { + type Bits = #bits_ty_ident; + + #[inline] + #[allow(clippy::double_comparisons)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + match unsafe { bits.__tag } { + #(#variant_checks)* + _ => false, + } + } + }, + )) + } + } +} + /// Check that a struct has no padding by asserting that the size of the struct /// is equal to the sum of the size of it's fields fn generate_assert_no_padding(input: &DeriveInput) -> Result { @@ -637,9 +903,9 @@ fn get_repr(attributes: &[Attribute]) -> Result { _ => bail!("conflicting representation hints"), }, align: match (a.align, b.align) { + (Some(a), Some(b)) => Some(cmp::max(a, b)), (a, None) => a, (None, b) => b, - _ => bail!("conflicting representation hints"), }, }) }) @@ -665,111 +931,168 @@ macro_rules! mk_repr {( $Xn:ident => $xn:ident ),* $(,)? ) => ( - #[derive(Clone, Copy, PartialEq)] - enum Repr { - Rust, - C, - Transparent, + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum IntegerRepr { $($Xn),* } - impl Repr { - fn is_integer(self) -> bool { - match self { - Repr::Rust | Repr::C | Repr::Transparent => false, - _ => true, + impl<'a> TryFrom<&'a str> for IntegerRepr { + type Error = &'a str; + + fn try_from(value: &'a str) -> std::result::Result { + match value { + $( + stringify!($xn) => Ok(Self::$Xn), + )* + _ => Err(value), } } + } - fn as_integer_type(self) -> Option { + impl ToTokens for IntegerRepr { + fn to_tokens(&self, tokens: &mut TokenStream) { match self { - Repr::Rust | Repr::C | Repr::Transparent => None, $( - Repr::$Xn => Some(quote! { ::core::primitive::$xn }), + Self::$Xn => tokens.extend(quote!($xn)), )* } } } +)} +use mk_repr; - #[derive(Clone, Copy)] - struct Representation { - packed: Option, - align: Option, - repr: Repr, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Repr { + Rust, + C, + Transparent, + Integer(IntegerRepr), + CWithDiscriminant(IntegerRepr), +} + +impl Repr { + fn is_integer(&self) -> bool { + matches!(self, Self::Integer(..)) } - impl Default for Representation { - fn default() -> Self { - Self { packed: None, align: None, repr: Repr::Rust } + fn as_integer(&self) -> Option { + if let Self::Integer(v) = self { + Some(*v) + } else { + None } } +} - impl Parse for Representation { - fn parse(input: ParseStream<'_>) -> Result { - let mut ret = Representation::default(); - while !input.is_empty() { - let keyword = input.parse::()?; - // preƫmptively call `.to_string()` *once* (rather than on `is_ident()`) - let keyword_str = keyword.to_string(); - let new_repr = match keyword_str.as_str() { - "C" => Repr::C, - "transparent" => Repr::Transparent, - "packed" => { - ret.packed = Some(if input.peek(token::Paren) { - let contents; parenthesized!(contents in input); - LitInt::base10_parse::(&contents.parse()?)? - } else { - 1 - }); - let _: Option = input.parse()?; - continue; - }, - "align" => { - let contents; parenthesized!(contents in input); - ret.align = Some(LitInt::base10_parse::(&contents.parse()?)?); - let _: Option = input.parse()?; - continue; - }, - $( - stringify!($xn) => Repr::$Xn, - )* - _ => return Err(input.error("unrecognized representation hint")) - }; - if ::core::mem::replace(&mut ret.repr, new_repr) != Repr::Rust { - input.error("duplicate representation hint"); - } - let _: Option = input.parse()?; - } - Ok(ret) - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Representation { + packed: Option, + align: Option, + repr: Repr, +} + +impl Default for Representation { + fn default() -> Self { + Self { packed: None, align: None, repr: Repr::Rust } } +} - impl ToTokens for Representation { - fn to_tokens(&self, tokens: &mut TokenStream) { - let repr = match self.repr { - Repr::Rust => None, - Repr::C => Some(quote!(C)), - Repr::Transparent => Some(quote!(transparent)), - $( - Repr::$Xn => Some(quote!($xn)), - )* +impl Parse for Representation { + fn parse(input: ParseStream<'_>) -> Result { + let mut ret = Representation::default(); + while !input.is_empty() { + let keyword = input.parse::()?; + // preƫmptively call `.to_string()` *once* (rather than on `is_ident()`) + let keyword_str = keyword.to_string(); + let new_repr = match keyword_str.as_str() { + "C" => Repr::C, + "transparent" => Repr::Transparent, + "packed" => { + ret.packed = Some(if input.peek(token::Paren) { + let contents; + parenthesized!(contents in input); + LitInt::base10_parse::(&contents.parse()?)? + } else { + 1 + }); + let _: Option = input.parse()?; + continue; + } + "align" => { + let contents; + parenthesized!(contents in input); + let new_align = LitInt::base10_parse::(&contents.parse()?)?; + ret.align = Some( + ret + .align + .map_or(new_align, |old_align| cmp::max(old_align, new_align)), + ); + let _: Option = input.parse()?; + continue; + } + ident => { + let primitive = IntegerRepr::try_from(ident) + .map_err(|_| input.error("unrecognized representation hint"))?; + Repr::Integer(primitive) + } }; - let packed = self.packed.map(|p| { - let lit = LitInt::new(&p.to_string(), Span::call_site()); - quote!(packed(#lit)) - }); - let comma = if packed.is_some() && repr.is_some() { - Some(quote!(,)) - } else { - None + ret.repr = match (ret.repr, new_repr) { + (Repr::Rust, new_repr) => { + // This is the first explicit repr. + new_repr + } + (Repr::C, Repr::Integer(integer)) + | (Repr::Integer(integer), Repr::C) => { + // Both the C repr and an integer repr have been specified + // -> merge into a C wit discriminant. + Repr::CWithDiscriminant(integer) + } + (_, _) => { + return Err(input.error("duplicate representation hint")); + } }; - tokens.extend(quote!( - #[repr( #repr #comma #packed )] - )); + let _: Option = input.parse()?; } + Ok(ret) } -)} -use mk_repr; +} + +impl ToTokens for Representation { + fn to_tokens(&self, tokens: &mut TokenStream) { + let mut meta = Punctuated::<_, Token![,]>::new(); + + match self.repr { + Repr::Rust => {} + Repr::C => meta.push(quote!(C)), + Repr::Transparent => meta.push(quote!(transparent)), + Repr::Integer(primitive) => meta.push(quote!(#primitive)), + Repr::CWithDiscriminant(primitive) => { + meta.push(quote!(C)); + meta.push(quote!(#primitive)); + } + } + + if let Some(packed) = self.packed.as_ref() { + let lit = LitInt::new(&packed.to_string(), Span::call_site()); + meta.push(quote!(packed(#lit))); + } + + if let Some(align) = self.align.as_ref() { + let lit = LitInt::new(&align.to_string(), Span::call_site()); + meta.push(quote!(align(#lit))); + } + + tokens.extend(quote!( + #[repr(#meta)] + )); + } +} + +fn enum_has_fields<'a>( + mut variants: impl Iterator, +) -> bool { + variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_))) +} struct VariantDiscriminantIterator<'a, I: Iterator + 'a> { inner: I, @@ -791,12 +1114,6 @@ impl<'a, I: Iterator + 'a> Iterator fn next(&mut self) -> Option { let variant = self.inner.next()?; - if !variant.fields.is_empty() { - return Some(Err(Error::new_spanned( - &variant.fields, - "Only fieldless enums are supported", - ))); - } if let Some((_, discriminant)) = &variant.discriminant { let discriminant_value = match parse_int_expr(discriminant) { @@ -822,3 +1139,83 @@ fn parse_int_expr(expr: &Expr) -> Result { _ => bail!("Not an integer expression"), } } + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::{get_repr, IntegerRepr, Repr, Representation}; + + #[test] + fn parse_basic_repr() { + let attr = parse_quote!(#[repr(C)]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!(repr, Representation { repr: Repr::C, ..Default::default() }); + + let attr = parse_quote!(#[repr(transparent)]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!( + repr, + Representation { repr: Repr::Transparent, ..Default::default() } + ); + + let attr = parse_quote!(#[repr(u8)]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!( + repr, + Representation { + repr: Repr::Integer(IntegerRepr::U8), + ..Default::default() + } + ); + + let attr = parse_quote!(#[repr(packed)]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!(repr, Representation { packed: Some(1), ..Default::default() }); + + let attr = parse_quote!(#[repr(packed(1))]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!(repr, Representation { packed: Some(1), ..Default::default() }); + + let attr = parse_quote!(#[repr(packed(2))]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!(repr, Representation { packed: Some(2), ..Default::default() }); + + let attr = parse_quote!(#[repr(align(2))]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!(repr, Representation { align: Some(2), ..Default::default() }); + } + + #[test] + fn parse_advanced_repr() { + let attr = parse_quote!(#[repr(align(4), align(2))]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!(repr, Representation { align: Some(4), ..Default::default() }); + + let attr1 = parse_quote!(#[repr(align(1))]); + let attr2 = parse_quote!(#[repr(align(4))]); + let attr3 = parse_quote!(#[repr(align(2))]); + let repr = get_repr(&[attr1, attr2, attr3]).unwrap(); + assert_eq!(repr, Representation { align: Some(4), ..Default::default() }); + + let attr = parse_quote!(#[repr(C, u8)]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!( + repr, + Representation { + repr: Repr::CWithDiscriminant(IntegerRepr::U8), + ..Default::default() + } + ); + + let attr = parse_quote!(#[repr(u8, C)]); + let repr = get_repr(&[attr]).unwrap(); + assert_eq!( + repr, + Representation { + repr: Repr::CWithDiscriminant(IntegerRepr::U8), + ..Default::default() + } + ); + } +} diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index b719c23..5e2797e 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -2,7 +2,7 @@ use bytemuck::{ AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod, - TransparentWrapper, Zeroable, + TransparentWrapper, Zeroable, checked::CheckedCastError, }; use std::marker::{PhantomData, PhantomPinned}; @@ -160,6 +160,66 @@ struct AnyBitPatternTest { b: B, } +#[derive(Clone, Copy, CheckedBitPattern)] +#[repr(C, align(8))] +struct CheckedBitPatternAlignedStruct { + a: u16, +} + +#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)] +#[repr(C)] +enum CheckedBitPatternCDefaultDiscriminantEnumWithFields { + A(u64), + B { c: u64 }, +} + +#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)] +#[repr(C, u8)] +enum CheckedBitPatternCEnumWithFields { + A(u32), + B { c: u32 }, +} + +#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)] +#[repr(u8)] +enum CheckedBitPatternIntEnumWithFields { + A(u8), + B { c: u32 }, +} + +#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)] +#[repr(transparent)] +enum CheckedBitPatternTransparentEnumWithFields { + A { b: u32 }, +} + +// size 24, align 8. +// first byte always the u8 discriminant, then 7 bytes of padding until the payload union since the align of the payload +// is the greatest of the align of all the variants, which is 8 (from CheckedBitPatternCDefaultDiscriminantEnumWithFields) +#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)] +#[repr(C, u8)] +enum CheckedBitPatternEnumNested { + A(CheckedBitPatternCEnumWithFields), + B(CheckedBitPatternCDefaultDiscriminantEnumWithFields), +} + +/// ```compile_fail +/// use bytemuck::{Pod, Zeroable}; +/// +/// #[derive(Pod, Zeroable)] +/// #[repr(transparent)] +/// struct TransparentSingle(T); +/// +/// struct NotPod(u32); +/// +/// let _: u32 = bytemuck::cast(TransparentSingle(NotPod(0u32))); +/// ``` +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper, +)] +#[repr(transparent)] +struct NewtypeWrapperTest(T); + #[test] fn fails_cast_contiguous() { let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5); @@ -246,6 +306,140 @@ fn checkedbitpattern_try_pod_read_unaligned() { assert!(res.is_err()); } +#[test] +fn checkedbitpattern_aligned_struct() { + let pod = [0u8; 8]; + bytemuck::checked::pod_read_unaligned::(&pod); +} + +#[test] +fn checkedbitpattern_c_default_discriminant_enum_with_fields() { + let pod = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0x55, + 0x55, 0x55, 0x55, 0xcc, + ]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternCDefaultDiscriminantEnumWithFields, + >(&pod); + assert_eq!( + value, + CheckedBitPatternCDefaultDiscriminantEnumWithFields::A(0xcc555555555555cc) + ); + + let pod = [ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0x55, + 0x55, 0x55, 0x55, 0xcc, + ]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternCDefaultDiscriminantEnumWithFields, + >(&pod); + assert_eq!( + value, + CheckedBitPatternCDefaultDiscriminantEnumWithFields::B { + c: 0xcc555555555555cc + } + ); +} + +#[test] +fn checkedbitpattern_c_enum_with_fields() { + let pod = [0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternCEnumWithFields, + >(&pod); + assert_eq!(value, CheckedBitPatternCEnumWithFields::A(0xcc5555cc)); + + let pod = [0x01, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternCEnumWithFields, + >(&pod); + assert_eq!(value, CheckedBitPatternCEnumWithFields::B { c: 0xcc5555cc }); +} + +#[test] +fn checkedbitpattern_int_enum_with_fields() { + let pod = [0x00, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternIntEnumWithFields, + >(&pod); + assert_eq!(value, CheckedBitPatternIntEnumWithFields::A(0x55)); + + let pod = [0x01, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternIntEnumWithFields, + >(&pod); + assert_eq!(value, CheckedBitPatternIntEnumWithFields::B { c: 0xcc5555cc }); +} + +#[test] +fn checkedbitpattern_nested_enum_with_fields() { + // total size 24 bytes. first byte always the u8 discriminant. + + #[repr(C, align(8))] + struct Align8Bytes([u8; 24]); + + // first we'll check variantA, nested variant A + let pod = Align8Bytes([ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding. + 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding + ]); + let value = bytemuck::checked::from_bytes::< + CheckedBitPatternEnumNested, + >(&pod.0); + assert_eq!(value, &CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A(0xcc5555cc))); + + // next we'll check invalid first discriminant fails + let pod = Align8Bytes([ + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding + 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding + ]); + let result = bytemuck::checked::try_from_bytes::< + CheckedBitPatternEnumNested, + >(&pod.0); + assert_eq!(result, Err(CheckedCastError::InvalidBitPattern)); + + + // next we'll check variant B, nested variant B + let pod = Align8Bytes([ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte order) = variant B + 0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B + ]); + let value = bytemuck::checked::from_bytes::< + CheckedBitPatternEnumNested, + >(&pod.0); + assert_eq!( + value, + &CheckedBitPatternEnumNested::B(CheckedBitPatternCDefaultDiscriminantEnumWithFields::B { + c: 0xcc555555555555cc + }) + ); + + // finally we'll check variant B, nested invalid discriminant + let pod = Align8Bytes([ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 discriminant = variant B, bytes 1-7 padding + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is invalid + 0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B + ]); + let result = bytemuck::checked::try_from_bytes::< + CheckedBitPatternEnumNested, + >(&pod.0); + assert_eq!(result, Err(CheckedCastError::InvalidBitPattern)); +} +#[test] +fn checkedbitpattern_transparent_enum_with_fields() { + let pod = [0xcc, 0x55, 0x55, 0xcc]; + let value = bytemuck::checked::pod_read_unaligned::< + CheckedBitPatternTransparentEnumWithFields, + >(&pod); + assert_eq!( + value, + CheckedBitPatternTransparentEnumWithFields::A { b: 0xcc5555cc } + ); +} + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] #[repr(C, align(16))] struct Issue127 {}