From f67242b280e77f5df3a98bae0e0bd70efcc82c6c Mon Sep 17 00:00:00 2001 From: Mads Marquart Date: Wed, 25 Dec 2024 16:19:48 +0100 Subject: [PATCH] Refactor how superclasses and superprotocols are stored We need to know their "spatial" layout to do further analysis, i.e. need to know the entire tree of which protocol needs which superprotocol. --- crates/header-translator/src/config.rs | 4 +- crates/header-translator/src/lib.rs | 1 + crates/header-translator/src/protocol.rs | 70 +++++ crates/header-translator/src/rust_type.rs | 254 ++++++++++-------- crates/header-translator/src/stmt.rs | 190 ++++++------- crates/header-translator/src/thread_safety.rs | 5 +- .../objc2-ui-kit/translation-config.toml | 2 +- 7 files changed, 302 insertions(+), 224 deletions(-) create mode 100644 crates/header-translator/src/protocol.rs diff --git a/crates/header-translator/src/config.rs b/crates/header-translator/src/config.rs index 5b8e04182..d14150deb 100644 --- a/crates/header-translator/src/config.rs +++ b/crates/header-translator/src/config.rs @@ -149,9 +149,9 @@ pub struct ExternalData { #[serde(rename = "thread-safety")] #[serde(default)] pub thread_safety: Option, - #[serde(rename = "required-items")] + #[serde(rename = "super-items")] #[serde(default)] - pub required_items: Vec, + pub super_items: Vec, } #[derive(Deserialize, Debug, Default, Clone, PartialEq, Eq)] diff --git a/crates/header-translator/src/lib.rs b/crates/header-translator/src/lib.rs index a59a79485..104133c46 100644 --- a/crates/header-translator/src/lib.rs +++ b/crates/header-translator/src/lib.rs @@ -25,6 +25,7 @@ mod method; mod module; mod name_translation; mod objc2_utils; +mod protocol; mod rust_type; mod stmt; mod thread_safety; diff --git a/crates/header-translator/src/protocol.rs b/crates/header-translator/src/protocol.rs new file mode 100644 index 000000000..15acdc7c3 --- /dev/null +++ b/crates/header-translator/src/protocol.rs @@ -0,0 +1,70 @@ +use std::iter; + +use clang::{Entity, EntityKind}; + +use crate::{immediate_children, Context, ItemIdentifier}; + +/// Parse the directly referenced protocols of a declaration. +pub(crate) fn parse_direct_protocols<'clang>( + entity: &Entity<'clang>, + _context: &Context<'_>, +) -> Vec> { + let mut protocols = Vec::new(); + + #[allow(clippy::single_match)] + immediate_children(entity, |child, _span| match child.get_kind() { + EntityKind::ObjCProtocolRef => { + let child = child + .get_reference() + .expect("ObjCProtocolRef to reference entity"); + if child == *entity { + error!(?entity, "recursive protocol"); + } else { + protocols.push(child); + } + } + _ => {} + }); + + protocols +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ProtocolRef { + pub(crate) id: ItemIdentifier, + pub(crate) super_protocols: Vec, +} + +impl ProtocolRef { + pub(crate) fn super_protocols(entity: &Entity<'_>, context: &Context<'_>) -> Vec { + let mut super_protocols = Vec::new(); + + for entity in parse_direct_protocols(entity, context) { + super_protocols.push(ProtocolRef::from_entity(&entity, context)); + } + + super_protocols + } + + pub(crate) fn from_entity(entity: &Entity<'_>, context: &Context<'_>) -> Self { + let mut super_protocols = Vec::new(); + + for entity in parse_direct_protocols(entity, context) { + super_protocols.push(ProtocolRef::from_entity(&entity, context)); + } + + Self { + id: context.replace_protocol_name(ItemIdentifier::new(&entity, context)), + super_protocols, + } + } + + pub(crate) fn required_items(&self) -> Vec { + self.super_protocols + .iter() + .flat_map(|super_protocol| super_protocol.required_items()) + .chain(iter::once(self.id.clone())) + .chain(iter::once(ItemIdentifier::objc("__macros__"))) + .collect() + } +} diff --git a/crates/header-translator/src/rust_type.rs b/crates/header-translator/src/rust_type.rs index 4fc07c377..ea7977b29 100644 --- a/crates/header-translator/src/rust_type.rs +++ b/crates/header-translator/src/rust_type.rs @@ -1,6 +1,6 @@ +use std::fmt; use std::str::FromStr; use std::sync::LazyLock; -use std::{fmt, iter}; use clang::{CallingConvention, Entity, EntityKind, Nullability, Type, TypeKind}; use proc_macro2::{TokenStream, TokenTree}; @@ -8,7 +8,8 @@ use proc_macro2::{TokenStream, TokenTree}; use crate::context::Context; use crate::display_helper::FormatterFn; use crate::id::ItemIdentifier; -use crate::stmt::items_required_by_decl; +use crate::protocol::ProtocolRef; +use crate::stmt::parse_superclasses; use crate::stmt::{anonymous_record_name, is_bridged}; use crate::thread_safety::ThreadSafety; use crate::unexposed_attr::UnexposedAttr; @@ -352,75 +353,107 @@ impl fmt::Display for Primitive { } } -/// A reference to a class or a protocol declaration. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ItemRef { - id: ItemIdentifier, - thread_safety: ThreadSafety, - required_items: Vec, +fn get_class_data( + entity_ref: &Entity<'_>, + context: &Context<'_>, +) -> (ItemIdentifier, ThreadSafety, Vec) { + // @class produces a ObjCInterfaceDecl if we didn't load the actual + // declaration, but we don't actually want that, since it'll point to the + // wrong place. + let entity = entity_ref + .get_location() + .expect("class location") + .get_entity() + .expect("class entity"); + + let mut id = ItemIdentifier::new(&entity, context); + + if let Some(external) = context.library(id.library_name()).external.get(&id.name) { + let id = ItemIdentifier::from_raw(id.name, external.module.clone()); + let thread_safety = external + .thread_safety + .as_deref() + .map(ThreadSafety::from_string) + .unwrap_or(ThreadSafety::dummy()); + return (id, thread_safety, external.super_items.clone()); + } + + match entity.get_kind() { + EntityKind::ObjCInterfaceDecl => { + let thread_safety = ThreadSafety::from_decl(&entity, context); + let superclasses = parse_superclasses(&entity, context) + .into_iter() + .map(|(id, _, _)| id) + .collect(); + + (id, thread_safety, superclasses) + } + EntityKind::MacroExpansion => { + id.name = entity_ref.get_name().unwrap_or_else(|| { + error!(?entity_ref, ?entity, "macro ref did not have name"); + id.name + }); + // We cannot get thread safety from macro expansions + let thread_safety = ThreadSafety::dummy(); + // Similarly, we cannot get for required items + let superclasses = vec![]; + (id, thread_safety, superclasses) + } + _ => { + error!(?entity, "could not get declaration. Add appropriate external.{}.module = \"...\" to translation-config.toml", id.name); + (id, ThreadSafety::dummy(), vec![]) + } + } } -impl ItemRef { - fn required_items(&self) -> Vec { - self.required_items.clone() - } - - fn new(entity_ref: &Entity<'_>, context: &Context<'_>) -> Self { - let entity = entity_ref - .get_location() - .expect("itemref location") - .get_entity() - .expect("itemref entity"); - - let mut id = ItemIdentifier::new(&entity, context); - - if let Some(external) = context.library(id.library_name()).external.get(&id.name) { - let id = ItemIdentifier::from_raw(id.name, external.module.clone()); - let thread_safety = external - .thread_safety - .as_deref() - .map(ThreadSafety::from_string) - .unwrap_or(ThreadSafety::dummy()); - let required_items = external - .required_items +fn parse_protocol(entity: Entity<'_>, context: &Context<'_>) -> (ProtocolRef, ThreadSafety) { + let entity = entity.get_definition().unwrap_or(entity); + // @protocol produces a ObjCProtocolDecl if we didn't + // load the actual declaration, but we don't actually + // want that, since it'll point to the wrong place. + let entity = entity + .get_location() + .expect("itemref location") + .get_entity() + .expect("itemref entity"); + + let id = ItemIdentifier::new(&entity, context); + + if let Some(external) = context.library(id.library_name()).external.get(&id.name) { + let id = ItemIdentifier::from_raw(id.name, external.module.clone()); + let thread_safety = external + .thread_safety + .as_deref() + .map(ThreadSafety::from_string) + .unwrap_or(ThreadSafety::dummy()); + let protocol = ProtocolRef { + id, + super_protocols: external + .super_items .iter() - .cloned() - .chain(iter::once(id.clone())) - .collect(); - return Self { - id, - thread_safety, - required_items, - }; - } + .map(|item| ProtocolRef { + id: item.clone(), + // TODO: Populate this somehow? + super_protocols: vec![], + }) + .collect(), + }; + return (protocol, thread_safety); + } - match entity.get_kind() { - EntityKind::ObjCInterfaceDecl | EntityKind::ObjCProtocolDecl => Self { + match entity.get_kind() { + EntityKind::ObjCProtocolDecl => { + let protocol = ProtocolRef::from_entity(&entity, context); + let thread_safety = ThreadSafety::from_decl(&entity, context); + (protocol, thread_safety) + } + _ => { + error!(?entity, "could not get declaration. Add appropriate external.{}.module = \"...\" to translation-config.toml", id.name); + let protocol = ProtocolRef { id, - thread_safety: ThreadSafety::from_decl(&entity, context), - required_items: items_required_by_decl(&entity, context), - }, - EntityKind::MacroExpansion => { - id.name = entity_ref.get_name().unwrap_or_else(|| { - error!(?entity_ref, ?entity, "macro ref did not have name"); - id.name - }); - Self { - id: id.clone(), - // We cannot get thread safety from macro expansions - thread_safety: ThreadSafety::dummy(), - // Similarly, we cannot get for required items - required_items: vec![id], - } - } - _ => { - error!(?entity, "could not get declaration. Add appropriate external.{}.module = \"...\" to translation-config.toml", id.name); - Self { - id: id.clone(), - thread_safety: ThreadSafety::dummy(), - required_items: vec![id], - } - } + super_protocols: vec![], + }; + (protocol, ThreadSafety::dummy()) } } } @@ -434,19 +467,21 @@ pub enum Ty { size: u8, }, Class { - decl: ItemRef, + id: ItemIdentifier, + thread_safety: ThreadSafety, + superclasses: Vec, generics: Vec, - protocols: Vec, + protocols: Vec<(ProtocolRef, ThreadSafety)>, }, GenericParam { name: String, }, AnyObject { - protocols: Vec, + protocols: Vec<(ProtocolRef, ThreadSafety)>, }, AnyProtocol, AnyClass { - protocols: Vec, + protocols: Vec<(ProtocolRef, ThreadSafety)>, }, Self_, Sel { @@ -765,12 +800,14 @@ impl Ty { if name == "Protocol" { Self::AnyProtocol } else { - let decl = ItemRef::new(&declaration, context); - if decl.id.name != name.strip_prefix("const ").unwrap_or(&name) { + let (id, thread_safety, superclasses) = get_class_data(&declaration, context); + if id.name != name.strip_prefix("const ").unwrap_or(&name) { error!(?name, "invalid interface name"); } Self::Class { - decl, + id, + thread_safety, + superclasses, protocols: vec![], generics: vec![], } @@ -791,13 +828,7 @@ impl Ty { let protocols: Vec<_> = ty .get_objc_protocol_declarations() .into_iter() - .map(|entity| { - // ItemRef::new will fall back if we can't find it here. - let maybe_definition = entity.get_definition().unwrap_or(entity); - let mut decl = ItemRef::new(&maybe_definition, context); - decl.id = context.replace_protocol_name(decl.id); - decl - }) + .map(|entity| parse_protocol(entity, context)) .collect(); match base_ty.get_kind() { @@ -814,8 +845,9 @@ impl Ty { let declaration = base_ty .get_declaration() .expect("ObjCObject -> ObjCInterface declaration"); - let decl = ItemRef::new(&declaration, context); - if decl.id.name != name { + let (id, thread_safety, superclasses) = + get_class_data(&declaration, context); + if id.name != name { error!(?name, "ObjCObject -> ObjCInterface invalid name"); } @@ -828,7 +860,9 @@ impl Ty { } Self::Class { - decl, + id, + thread_safety, + superclasses, generics, protocols, } @@ -1268,15 +1302,18 @@ impl Ty { items } Self::Class { - decl, + id, + thread_safety: _, + superclasses, generics, protocols, } => { - let mut items = decl.required_items(); + let mut items = vec![id.clone()]; + items.extend(superclasses.iter().cloned()); for generic in generics { items.extend(generic.required_items()); } - for protocol in protocols { + for (protocol, _) in protocols { items.extend(protocol.required_items()); } items @@ -1284,7 +1321,7 @@ impl Ty { Self::GenericParam { .. } => Vec::new(), Self::AnyObject { protocols } => { let mut items = vec![ItemIdentifier::objc("AnyObject")]; - for protocol in protocols { + for (protocol, _) in protocols { items.extend(protocol.required_items()); } items @@ -1292,7 +1329,7 @@ impl Ty { Self::AnyProtocol => vec![ItemIdentifier::objc("AnyProtocol")], Self::AnyClass { protocols } => { let mut items = vec![ItemIdentifier::objc("AnyClass")]; - for protocol in protocols { + for (protocol, _) in protocols { items.extend(protocol.required_items()); } items @@ -1380,26 +1417,28 @@ impl Ty { Self::Primitive(_) => false, Self::Simd { .. } => false, Self::Class { - decl, + id: _, + thread_safety, + superclasses: _, generics, protocols, } => { - decl.thread_safety.inferred_mainthreadonly() + thread_safety.inferred_mainthreadonly() || generics .iter() .any(|generic| generic.requires_mainthreadmarker(self_requires)) || protocols .iter() - .any(|protocol| protocol.thread_safety.inferred_mainthreadonly()) + .any(|(_, thread_safety)| thread_safety.inferred_mainthreadonly()) } Self::GenericParam { .. } => false, Self::AnyObject { protocols } => protocols .iter() - .any(|protocol| protocol.thread_safety.inferred_mainthreadonly()), + .any(|(_, thread_safety)| thread_safety.inferred_mainthreadonly()), Self::AnyProtocol => false, Self::AnyClass { protocols } => protocols .iter() - .any(|protocol| protocol.thread_safety.inferred_mainthreadonly()), + .any(|(_, thread_safety)| thread_safety.inferred_mainthreadonly()), Self::Self_ => self_requires, Self::Sel { .. } => false, Self::Pointer { pointee, .. } => pointee.requires_mainthreadmarker(self_requires), @@ -1441,11 +1480,11 @@ impl Ty { // Important: We mostly visit the top-level types, to not include // optional things like `Option<&NSView>` or `&NSArray`. match self { - Self::Class { decl, .. } => decl.thread_safety.inferred_mainthreadonly(), + Self::Class { thread_safety, .. } => thread_safety.inferred_mainthreadonly(), Self::AnyObject { protocols } => { match &**protocols { [] => false, - [decl] => decl.thread_safety.inferred_mainthreadonly(), + [(_, thread_safety)] => thread_safety.inferred_mainthreadonly(), // TODO: Handle this better _ => false, } @@ -1703,11 +1742,13 @@ impl Ty { fn behind_pointer(&self) -> impl fmt::Display + '_ { FormatterFn(move |f| match self { Self::Class { - decl, + id, + thread_safety: _, + superclasses: _, generics, protocols: _, } => { - write!(f, "{}", decl.id.path())?; + write!(f, "{}", id.path())?; if !generics.is_empty() { write!(f, "<")?; for generic in generics { @@ -1733,11 +1774,11 @@ impl Ty { Self::GenericParam { name } => write!(f, "{name}"), Self::AnyObject { protocols } => match &**protocols { [] => write!(f, "AnyObject"), - [decl] => write!(f, "ProtocolObject", decl.id.path()), + [(protocol, _)] => write!(f, "ProtocolObject", protocol.id.path()), // TODO: Handle this better - [first, rest @ ..] => { + [(first, _), rest @ ..] => { write!(f, "AnyObject /* {}", first.id.path())?; - for protocol in rest { + for (protocol, _) in rest { write!(f, "+ {}", protocol.id.path())?; } write!(f, " */")?; @@ -2462,12 +2503,13 @@ impl Ty { } = &**pointee { if let Self::Class { - decl, + id, generics, protocols, + .. } = &**pointee { - if !decl.id.is_nserror() { + if !id.is_nserror() { return false; } assert!(!is_const, "expected error not const {self:?}"); @@ -2695,11 +2737,9 @@ mod tests { is_const: false, lifetime: Lifetime::Unspecified, pointee: Box::new(Ty::Class { - decl: ItemRef { - id: ItemIdentifier::dummy(), - thread_safety: ThreadSafety::dummy(), - required_items: vec![], - }, + id: ItemIdentifier::dummy(), + thread_safety: ThreadSafety::dummy(), + superclasses: vec![], generics: vec![], protocols: vec![], }), diff --git a/crates/header-translator/src/stmt.rs b/crates/header-translator/src/stmt.rs index f51c88857..ba04d19f1 100644 --- a/crates/header-translator/src/stmt.rs +++ b/crates/header-translator/src/stmt.rs @@ -23,6 +23,8 @@ use crate::immediate_children; use crate::method::{handle_reserved, Method}; use crate::name_translation::enum_prefix; use crate::name_translation::split_words; +use crate::protocol::parse_direct_protocols; +use crate::protocol::ProtocolRef; use crate::rust_type::Ty; use crate::thread_safety::ThreadSafety; use crate::unexposed_attr::UnexposedAttr; @@ -74,31 +76,6 @@ fn parse_protocols<'tu>( }); } -/// Parse the directly referenced protocols of a declaration. -pub(crate) fn parse_direct_protocols<'clang>( - entity: &Entity<'clang>, - _context: &Context<'_>, -) -> Vec> { - let mut protocols = Vec::new(); - - #[allow(clippy::single_match)] - immediate_children(entity, |child, _span| match child.get_kind() { - EntityKind::ObjCProtocolRef => { - let child = child - .get_reference() - .expect("ObjCProtocolRef to reference entity"); - if child == *entity { - error!(?entity, "recursive protocol"); - } else { - protocols.push(child); - } - } - _ => {} - }); - - protocols -} - pub(crate) fn parse_superclasses<'ty>( entity: &Entity<'ty>, context: &Context<'_>, @@ -282,35 +259,6 @@ fn parse_methods( (methods, designated_initializers) } -/// Get the items required for a given interface or protocol declaration to be -/// enabled. -pub(crate) fn items_required_by_decl( - entity: &Entity<'_>, - context: &Context<'_>, -) -> Vec { - let id = ItemIdentifier::new(entity, context); - - let mut items = vec![ItemIdentifier::objc("__macros__")]; - - match entity.get_kind() { - EntityKind::ObjCInterfaceDecl => { - for (superclass, _, _) in parse_superclasses(entity, context) { - items.push(superclass); - } - } - EntityKind::ObjCProtocolDecl => { - for entity in parse_direct_protocols(entity, context) { - items.extend(items_required_by_decl(&entity, context)); - } - } - _ => panic!("invalid required_by_decl kind {entity:?}"), - } - - items.push(id); - - items -} - /// Takes one of: /// - `EntityKind::ObjCInterfaceDecl` /// - `EntityKind::ObjCProtocolDecl` @@ -454,7 +402,6 @@ pub enum Stmt { /// extern_class! ClassDecl { id: ItemIdentifier, - required_items: Vec, generics: Vec, availability: Availability, superclasses: Vec<(ItemIdentifier, Vec)>, @@ -472,7 +419,7 @@ pub enum Stmt { location: Location, availability: Availability, cls: ItemIdentifier, - cls_required_items: Vec, + cls_superclasses: Vec, source_superclass: Option, cls_generics: Vec, category_name: Option, @@ -487,7 +434,7 @@ pub enum Stmt { actual_name: Option, availability: Availability, cls: ItemIdentifier, - cls_required_items: Vec, + cls_superclasses: Vec, methods: Vec, documentation: Documentation, }, @@ -496,10 +443,9 @@ pub enum Stmt { /// extern_protocol! ProtocolDecl { id: ItemIdentifier, - required_items: Vec, actual_name: Option, availability: Availability, - protocols: Vec, + super_protocols: Vec, methods: Vec, required_sendable: bool, required_mainthreadonly: bool, @@ -510,10 +456,10 @@ pub enum Stmt { ProtocolImpl { location: Location, cls: ItemIdentifier, - cls_required_items: Vec, + cls_superclasses: Vec, cls_counterpart: Counterpart, protocol: ItemIdentifier, - protocol_required_items: Vec, + protocol_super_protocols: Vec, generics: Vec, availability: Availability, }, @@ -690,7 +636,6 @@ impl Stmt { let availability = Availability::parse(entity, context); let thread_safety = ThreadSafety::from_decl(entity, context); - let required_items = items_required_by_decl(entity, context); let counterpart = data .map(|data| data.counterpart.clone()) .unwrap_or_default(); @@ -718,6 +663,10 @@ impl Stmt { .iter() .map(|(id, generics, _)| (id.clone(), generics.clone())) .collect(); + let cls_superclasses: Vec<_> = superclasses_full + .iter() + .map(|(id, _, _)| id.clone()) + .collect(); // Used for duplicate checking (sometimes the subclass // defines the same method that the superclass did). @@ -758,7 +707,7 @@ impl Stmt { location: id.location().clone(), availability: availability.clone(), cls: id.clone(), - cls_required_items: required_items.clone(), + cls_superclasses: cls_superclasses.clone(), source_superclass: Some(superclass_id.clone()), cls_generics: generics.clone(), category_name: None, @@ -773,7 +722,7 @@ impl Stmt { location: id.location().clone(), availability: availability.clone(), cls: id.clone(), - cls_required_items: required_items.clone(), + cls_superclasses: cls_superclasses.clone(), source_superclass: None, cls_generics: generics.clone(), category_name: None, @@ -783,7 +732,6 @@ impl Stmt { iter::once(Self::ClassDecl { id: id.clone(), - required_items: required_items.clone(), generics: generics.clone(), availability: availability.clone(), superclasses, @@ -799,10 +747,10 @@ impl Stmt { .chain(protocols.into_iter().map(|(p, entity)| Self::ProtocolImpl { location: id.location().clone(), cls: id.clone(), - cls_required_items: required_items.clone(), + cls_superclasses: cls_superclasses.clone(), cls_counterpart: counterpart.clone(), protocol: context.replace_protocol_name(p), - protocol_required_items: items_required_by_decl(&entity, context), + protocol_super_protocols: ProtocolRef::super_protocols(&entity, context), generics: generics.clone(), availability: availability.clone(), })) @@ -854,7 +802,10 @@ impl Stmt { } let cls_thread_safety = ThreadSafety::from_decl(&cls_entity, context); - let cls_required_items = items_required_by_decl(&cls_entity, context); + let cls_superclasses: Vec<_> = parse_superclasses(&cls_entity, context) + .into_iter() + .map(|(id, _, _)| id) + .collect(); let cls_counterpart = data .map(|data| data.counterpart.clone()) .unwrap_or_default(); @@ -875,12 +826,12 @@ impl Stmt { let protocol_impls = protocols.into_iter().map(|(p, entity)| Self::ProtocolImpl { location: category.location().clone(), cls: cls.clone(), - cls_required_items: cls_required_items.clone(), + cls_superclasses: cls_superclasses.clone(), cls_counterpart: cls_counterpart.clone(), generics: generics.clone(), availability: availability.clone(), protocol: context.replace_protocol_name(p), - protocol_required_items: items_required_by_decl(&entity, context), + protocol_super_protocols: ProtocolRef::super_protocols(&entity, context), }); // For ease-of-use, if the category is defined in the same @@ -941,7 +892,7 @@ impl Stmt { availability: availability.clone(), cls: subclass, // ... the same required items ... - cls_required_items: cls_required_items.clone(), + cls_superclasses: cls_superclasses.clone(), // ... and that they have the same amount of generics. cls_generics: generics.clone(), category_name: category.name.clone(), @@ -957,7 +908,7 @@ impl Stmt { location: category.location().clone(), availability: availability.clone(), cls: cls.clone(), - cls_required_items: cls_required_items.clone(), + cls_superclasses: cls_superclasses.clone(), source_superclass: None, cls_generics: generics.clone(), category_name: category.name.clone(), @@ -1030,7 +981,7 @@ impl Stmt { actual_name: category.name.clone(), availability: availability.clone(), cls: cls.clone(), - cls_required_items: cls_required_items.clone(), + cls_superclasses: cls_superclasses.clone(), methods, documentation: Documentation::from_entity(entity), }) @@ -1061,12 +1012,6 @@ impl Stmt { let thread_safety = ThreadSafety::from_decl(entity, context); verify_objc_decl(entity, context); - let protocols = parse_direct_protocols(entity, context); - let protocols: Vec<_> = protocols - .into_iter() - .map(|protocol| ItemIdentifier::new(&protocol, context)) - .map(|protocol| context.replace_protocol_name(protocol)) - .collect(); let (methods, designated_initializers) = parse_methods( entity, |name| { @@ -1088,10 +1033,9 @@ impl Stmt { vec![Self::ProtocolDecl { id, - required_items: items_required_by_decl(entity, context), actual_name, availability, - protocols, + super_protocols: ProtocolRef::super_protocols(entity, context), methods, required_sendable: thread_safety.explicit_sendable(), required_mainthreadonly: thread_safety.explicit_mainthreadonly(), @@ -1802,22 +1746,48 @@ impl Stmt { /// Items required by the statement at the top-level. pub(crate) fn required_items(&self) -> Vec { match self { - Self::ClassDecl { required_items, .. } => required_items.clone(), + Self::ClassDecl { superclasses, .. } => { + let mut items = vec![ItemIdentifier::objc("__macros__")]; + items.extend( + superclasses + .iter() + .map(|(superclass, _)| superclass.clone()), + ); + items + } Self::ExternMethods { - cls_required_items, .. - } => cls_required_items.clone(), + cls, + cls_superclasses, + .. + } => { + let mut items = vec![cls.clone(), ItemIdentifier::objc("__macros__")]; + items.extend(cls_superclasses.clone()); + items + } // Intentionally doesn't require anything, the impl itself is // cfg-gated Self::ExternCategory { .. } => vec![ItemIdentifier::objc("__macros__")], - Self::ProtocolDecl { required_items, .. } => required_items.clone(), + Self::ProtocolDecl { + super_protocols, .. + } => { + let mut items = vec![ItemIdentifier::objc("__macros__")]; + for super_protocol in super_protocols { + items.extend(super_protocol.required_items()); + } + items + } Self::ProtocolImpl { - cls_required_items, - protocol_required_items, + cls, + cls_superclasses, + protocol, + protocol_super_protocols, .. } => { - let mut items = Vec::new(); - items.extend(cls_required_items.clone()); - items.extend(protocol_required_items.clone()); + let mut items = vec![cls.clone(), protocol.clone()]; + items.extend(cls_superclasses.clone()); + for super_protocol in protocol_super_protocols { + items.extend(super_protocol.required_items()); + } items } Self::RecordDecl { fields, .. } => { @@ -1878,13 +1848,15 @@ impl Stmt { pub(crate) fn required_items_inner(&self) -> Vec { let required_by_inner: Vec<_> = match self { Self::ExternCategory { - cls_required_items, + cls, + cls_superclasses, methods, .. } => methods .iter() .flat_map(|method| method.required_items()) - .chain(cls_required_items.clone()) + .chain(iter::once(cls.clone())) + .chain(cls_superclasses.clone()) .collect(), Self::ExternMethods { methods, .. } | Self::ProtocolDecl { methods, .. } => methods .iter() @@ -2005,7 +1977,6 @@ impl Stmt { match self { Self::ClassDecl { id, - required_items: _, generics, availability, superclasses, @@ -2073,7 +2044,7 @@ impl Stmt { location: _, availability: _, cls, - cls_required_items: _, + cls_superclasses: _, source_superclass, cls_generics, category_name, @@ -2161,7 +2132,7 @@ impl Stmt { actual_name, availability, cls, - cls_required_items, + cls_superclasses, methods, documentation, } => { @@ -2215,7 +2186,7 @@ impl Stmt { f, " {}", cfg_gate_ln( - cls_required_items, + iter::once(cls).chain(cls_superclasses), [self.location()], config, self.location() @@ -2233,11 +2204,11 @@ impl Stmt { Self::ProtocolImpl { location: id, cls, - cls_required_items: _, + cls_superclasses: _, cls_counterpart, generics, protocol, - protocol_required_items: _, + protocol_super_protocols: _, availability: _, } => { let (generic_bound, where_bound) = if !generics.is_empty() { @@ -2352,10 +2323,9 @@ impl Stmt { } Self::ProtocolDecl { id, - required_items: _, actual_name, availability, - protocols, + super_protocols, methods, required_sendable: _, required_mainthreadonly, @@ -2377,14 +2347,14 @@ impl Stmt { writeln!(f, " #[name = {actual_name:?}]")?; } write!(f, " pub unsafe trait {}", id.name)?; - if !protocols.is_empty() { - for (i, protocol) in protocols.iter().enumerate() { + if !super_protocols.is_empty() { + for (i, protocol) in super_protocols.iter().enumerate() { if i == 0 { write!(f, ": ")?; } else { write!(f, "+ ")?; } - write!(f, "{}", protocol.path())?; + write!(f, "{}", protocol.id.path())?; } } // TODO @@ -2397,7 +2367,7 @@ impl Stmt { // write!(f, "Send + Sync")?; // } if *required_mainthreadonly { - if protocols.is_empty() { + if super_protocols.is_empty() { write!(f, ": ")?; } else { write!(f, "+ ")?; @@ -3012,12 +2982,11 @@ impl Stmt { } pub(crate) fn encoding_test<'a>(&'a self, config: &'a Config) -> Option { - let (data, availability, cls, cls_required_items, cls_generics, methods) = match self { + let (data, availability, cls, cls_generics, methods) = match self { Stmt::ExternMethods { location, availability, cls, - cls_required_items, cls_generics, methods, .. @@ -3025,7 +2994,6 @@ impl Stmt { config.library(location.library_name()), availability, cls, - cls_required_items, &**cls_generics, methods, ), @@ -3033,14 +3001,12 @@ impl Stmt { id, availability, cls, - cls_required_items, methods, .. } => ( config.library(id.library_name()), availability, cls, - cls_required_items, &[] as &[_], methods, ), @@ -3052,7 +3018,7 @@ impl Stmt { write!( f, "{}", - simple_platform_gate(data, cls_required_items, &[], config) + simple_platform_gate(data, &self.required_items(), &[], config) )?; if let Some(check) = availability.check_is_available() { writeln!(f, " if {check} ")?; @@ -3076,7 +3042,7 @@ impl Stmt { simple_platform_gate( data, &method.required_items(), - cls_required_items, + &self.required_items(), config ) )?; @@ -3107,7 +3073,7 @@ impl Stmt { "{}", simple_platform_gate( config.library(id.library_name()), - &ty.required_items(), + &self.required_items(), &[], config, ) diff --git a/crates/header-translator/src/thread_safety.rs b/crates/header-translator/src/thread_safety.rs index 3456758a8..24ebbefa5 100644 --- a/crates/header-translator/src/thread_safety.rs +++ b/crates/header-translator/src/thread_safety.rs @@ -4,7 +4,8 @@ use serde::Deserialize; use crate::{ immediate_children, method::MethodModifiers, - stmt::{method_or_property_entities, parse_direct_protocols, parse_superclasses}, + protocol::parse_direct_protocols, + stmt::{method_or_property_entities, parse_superclasses}, unexposed_attr::UnexposedAttr, Context, ItemIdentifier, }; @@ -182,7 +183,7 @@ impl ThreadSafetyAttr { /// Information about thread-safety properties of a type. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub(crate) struct ThreadSafety { +pub struct ThreadSafety { /// What the attribute was explicitly declared as. explicit: Option, /// What the attribute was inferred to be. diff --git a/framework-crates/objc2-ui-kit/translation-config.toml b/framework-crates/objc2-ui-kit/translation-config.toml index d3e34489c..f456e4432 100644 --- a/framework-crates/objc2-ui-kit/translation-config.toml +++ b/framework-crates/objc2-ui-kit/translation-config.toml @@ -11,7 +11,7 @@ visionos = "1.0" external.UTType.module = "UniformTypeIdentifiers.UTType" external.CKContainer.module = "CloudKit.CKContainer" external.CKShare.module = "CloudKit.CKShare" -external.CKShare.required-items = ["CloudKit.CKRecord.CKRecord"] +external.CKShare.super-items = ["CloudKit.CKRecord.CKRecord"] external.CKShareMetadata.module = "CloudKit.CKShareMetadata" external.NSManagedObjectContext.module = "CoreData.NSManagedObjectContext" external.NSManagedObjectModel.module = "CoreData.NSManagedObjectModel"