diff --git a/core/src/hir/attrs.rs b/core/src/hir/attrs.rs index 2d3ac256e..ec8586d3f 100644 --- a/core/src/hir/attrs.rs +++ b/core/src/hir/attrs.rs @@ -5,7 +5,7 @@ use crate::ast::attrs::{AttrInheritContext, DiplomatBackendAttrCfg, StandardAttr use crate::hir::lowering::ErrorStore; use crate::hir::{ EnumVariant, LoweringError, Method, Mutability, OpaqueId, ReturnType, SelfType, SuccessType, - Type, TypeDef, TypeId, + TraitDef, Type, TypeDef, TypeId, }; use syn::Meta; @@ -137,6 +137,7 @@ pub struct SpecialMethodPresence { #[derive(Debug)] pub enum AttributeContext<'a, 'b> { Type(TypeDef<'a>), + Trait(&'a TraitDef), EnumVariant(&'a EnumVariant), Method(&'a Method, TypeId, &'b mut SpecialMethodPresence), Module, diff --git a/core/src/hir/lowering.rs b/core/src/hir/lowering.rs index de2c65883..4456faf85 100644 --- a/core/src/hir/lowering.rs +++ b/core/src/hir/lowering.rs @@ -417,12 +417,8 @@ impl<'ast> LoweringContext<'ast> { let lifetimes = self.lower_type_lifetime_env(&ast_trait.lifetimes); let def = TraitDef::new(ast_trait.docs.clone(), trait_name, fcts, attrs, lifetimes?); - // TODO fix this so it works for traits - // self.attr_validator.validate( - // &def.attrs, - // AttributeContext::Type(TypeDef::from(&def)), - // &mut self.errors, - // ); + self.attr_validator + .validate(&def.attrs, AttributeContext::Trait(&def), &mut self.errors); Ok(def) } diff --git a/core/src/hir/type_context.rs b/core/src/hir/type_context.rs index 1659ca9c5..747261783 100644 --- a/core/src/hir/type_context.rs +++ b/core/src/hir/type_context.rs @@ -174,12 +174,12 @@ impl TypeContext { /// Resolve and format a named type for use in diagnostics /// (don't apply rename rules and such) pub fn fmt_type_name_diagnostics(&self, id: TypeId) -> Cow { - self.fmt_symbol_name_diagnostics(id.into()) + self.resolve_type(id).name().as_str().into() } pub fn fmt_symbol_name_diagnostics(&self, id: SymbolId) -> Cow { match id { - SymbolId::TypeId(id) => self.resolve_type(id).name().as_str().into(), + SymbolId::TypeId(id) => self.fmt_type_name_diagnostics(id), SymbolId::TraitId(id) => self.resolve_trait(id).name.as_str().into(), } } diff --git a/macro/src/lib.rs b/macro/src/lib.rs index d02808ed8..81c47c5b5 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -100,8 +100,8 @@ fn gen_custom_vtable(custom_trait: &ast::Trait, custom_trait_vtable_type: &Ident let mut method_sigs: Vec = vec![]; method_sigs.push(quote!( pub destructor: Option, - pub SIZE: usize, - pub ALIGNMENT: usize, + pub size: usize, + pub alignment: usize, )); for m in &custom_trait.methods { // TODO check that this is the right conversion, it might be the wrong direction @@ -538,7 +538,7 @@ fn gen_bridge(mut input: ItemMod) -> ItemMod { new_contents.push(syn::parse_quote! { #[repr(C)] pub struct #custom_trait_name { - pub data: *const c_void, + data: *const c_void, pub vtable: #custom_trait_vtable_type, } }); diff --git a/macro/src/snapshots/diplomat__tests__traits.snap b/macro/src/snapshots/diplomat__tests__traits.snap index c636c5390..6b5648cc4 100644 --- a/macro/src/snapshots/diplomat__tests__traits.snap +++ b/macro/src/snapshots/diplomat__tests__traits.snap @@ -42,8 +42,8 @@ mod ffi { #[repr(C)] pub struct TesterTrait_VTable { pub destructor: Option, - pub SIZE: usize, - pub ALIGNMENT: usize, + pub size: usize, + pub alignment: usize, pub run_test_trait_fn_callback: unsafe extern "C" fn(*const c_void, i32) -> i32, pub run_test_void_trait_fn_callback: unsafe extern "C" fn(*const c_void), pub run_test_struct_trait_fn_callback: @@ -53,7 +53,7 @@ mod ffi { } #[repr(C)] pub struct DiplomatTraitStruct_TesterTrait { - pub data: *const c_void, + data: *const c_void, pub vtable: TesterTrait_VTable, } impl TesterTrait for DiplomatTraitStruct_TesterTrait { diff --git a/tool/src/c/formatter.rs b/tool/src/c/formatter.rs index 8c5a3cad9..4d2f8411e 100644 --- a/tool/src/c/formatter.rs +++ b/tool/src/c/formatter.rs @@ -1,6 +1,8 @@ //! This module contains functions for formatting types -use diplomat_core::hir::{self, StringEncoding, SymbolId, TyPosition, TypeContext}; +use diplomat_core::hir::{ + self, StringEncoding, SymbolId, TraitId, TyPosition, TypeContext, TypeId, +}; use std::borrow::Cow; /// This type mediates all formatting @@ -28,22 +30,26 @@ impl<'tcx> CFormatter<'tcx> { } /// Resolve and format a named type for use in code (without the namespace) - pub fn fmt_type_name(&self, id: SymbolId) -> Cow<'tcx, str> { - let (name, attrs) = match id { - SymbolId::TypeId(id) => { - let resolved = self.tcx.resolve_type(id); - let name: Cow<_> = resolved.name().as_str().into(); - let attrs = resolved.attrs(); - (name, attrs) - } - SymbolId::TraitId(id) => { - let resolved = self.tcx.resolve_trait(id); - let name: Cow<_> = resolved.name.as_str().into(); - let attrs = &resolved.attrs; - (name, attrs) - } - _ => panic!("Unexpected symbol ID type"), - }; + pub fn fmt_type_name(&self, id: TypeId) -> Cow<'tcx, str> { + let resolved = self.tcx.resolve_type(id); + let name: Cow<_> = resolved.name().as_str().into(); + let attrs = resolved.attrs(); + + // Only apply renames in cpp mode, in pure C mode you'd want the + // method names to match the type names. + // Potential future improvement: Use alias attributes in pure C mode. + if self.is_for_cpp { + attrs.rename.apply(name) + } else { + name + } + } + + pub fn fmt_trait_name(&self, id: TraitId) -> Cow<'tcx, str> { + let resolved = self.tcx.resolve_trait(id); + let name: Cow<_> = resolved.name.as_str().into(); + let attrs = &resolved.attrs; + // Only apply renames in cpp mode, in pure C mode you'd want the // method names to match the type names. // Potential future improvement: Use alias attributes in pure C mode. @@ -111,12 +117,20 @@ impl<'tcx> CFormatter<'tcx> { /// *just* the enum. It is included from Foo.h, and external users should not be importing /// it directly. (We can potentially add a #define guard that makes this actually private, if needed) pub fn fmt_decl_header_path(&self, id: SymbolId) -> String { - let type_name = self.fmt_type_name(id); + let type_name = match id { + SymbolId::TypeId(id) => self.fmt_type_name(id), + SymbolId::TraitId(id) => self.fmt_trait_name(id), + _ => panic!("Unexpected symbol ID type"), + }; format!("{type_name}.d.h") } /// Resolve and format the name of a type for use in header names: impl version pub fn fmt_impl_header_path(&self, id: SymbolId) -> String { - let type_name = self.fmt_type_name(id); + let type_name = match id { + SymbolId::TypeId(id) => self.fmt_type_name(id), + SymbolId::TraitId(id) => self.fmt_trait_name(id), + _ => panic!("Unexpected symbol ID type"), + }; format!("{type_name}.h") } diff --git a/tool/src/c/ty.rs b/tool/src/c/ty.rs index 610560f0b..8f44ad103 100644 --- a/tool/src/c/ty.rs +++ b/tool/src/c/ty.rs @@ -81,7 +81,7 @@ pub(crate) struct TyGenContext<'cx, 'tcx> { impl<'cx, 'tcx> TyGenContext<'cx, 'tcx> { pub(crate) fn gen_enum_def(&self, def: &'tcx hir::EnumDef) -> Header { let mut decl_header = Header::new(self.decl_header_path.clone(), self.is_for_cpp); - let ty_name = self.formatter.fmt_type_name(self.id); + let ty_name = self.formatter.fmt_type_name(self.id.try_into().unwrap()); EnumTemplate { ty: def, fmt: self.formatter, @@ -96,7 +96,7 @@ impl<'cx, 'tcx> TyGenContext<'cx, 'tcx> { pub(crate) fn gen_opaque_def(&self, _def: &'tcx hir::OpaqueDef) -> Header { let mut decl_header = Header::new(self.decl_header_path.clone(), self.is_for_cpp); - let ty_name = self.formatter.fmt_type_name(self.id); + let ty_name = self.formatter.fmt_type_name(self.id.try_into().unwrap()); OpaqueTemplate { ty_name, is_for_cpp: self.is_for_cpp, @@ -109,7 +109,7 @@ impl<'cx, 'tcx> TyGenContext<'cx, 'tcx> { pub(crate) fn gen_struct_def(&self, def: &'tcx hir::StructDef

) -> Header { let mut decl_header = Header::new(self.decl_header_path.clone(), self.is_for_cpp); - let ty_name = self.formatter.fmt_type_name(self.id); + let ty_name = self.formatter.fmt_type_name(self.id.try_into().unwrap()); let mut fields = vec![]; let mut cb_structs_and_defs = vec![]; for field in def.fields.iter() { @@ -135,7 +135,7 @@ impl<'cx, 'tcx> TyGenContext<'cx, 'tcx> { pub(crate) fn gen_trait_def(&self, def: &'tcx hir::TraitDef) -> Header { let mut decl_header = Header::new(self.decl_header_path.clone(), self.is_for_cpp); - let trt_name = self.formatter.fmt_type_name(self.id); + let trt_name = self.formatter.fmt_trait_name(self.id.try_into().unwrap()); let mut method_sigs = vec![]; for m in &def.methods { let mut param_types: Vec> = m @@ -186,7 +186,7 @@ impl<'cx, 'tcx> TyGenContext<'cx, 'tcx> { cb_structs_and_defs.extend_from_slice(&callback_defs); } - let ty_name = self.formatter.fmt_type_name(self.id); + let ty_name = self.formatter.fmt_type_name(self.id.try_into().unwrap()); let dtor_name = if let TypeDef::Opaque(opaque) = ty { Some(opaque.dtor_abi_name.as_str())