From 1c4df967a0f072926c5f5bc83402fbf5a9a218e4 Mon Sep 17 00:00:00 2001 From: Ben Dean-Kawamura Date: Sat, 23 Sep 2023 15:37:59 -0400 Subject: [PATCH] Refactoring the trait interface code This is prep work for allowing the foreign bindings to implement traits. Moved trait interface code in `uniffi_macros` to its own module. Updated the test trait in the coverall fixture. I think this one will work better to test foreign trait impls. The `ancestor_names()` method is a good way to test trait implementations that bounce between the Rust and foreign side of the FFI. Added Getters test like we have with callback interfaces. We want to replace callback interfaces with trait interfaces, so we should have good test coverage. This actually revealed that the trait code didn't support exceptions yet. --- fixtures/coverall/src/coverall.udl | 46 +++-- fixtures/coverall/src/lib.rs | 10 +- fixtures/coverall/src/traits.rs | 184 ++++++++++++++---- .../coverall/tests/bindings/test_coverall.kts | 75 ++++++- .../coverall/tests/bindings/test_coverall.py | 57 +++++- .../tests/bindings/test_coverall.swift | 79 ++++++++ .../scaffolding/templates/ObjectTemplate.rs | 8 +- uniffi_macros/src/export.rs | 117 +---------- uniffi_macros/src/export/trait_interface.rs | 124 ++++++++++++ uniffi_macros/src/lib.rs | 8 - 10 files changed, 518 insertions(+), 190 deletions(-) create mode 100644 uniffi_macros/src/export/trait_interface.rs diff --git a/fixtures/coverall/src/coverall.udl b/fixtures/coverall/src/coverall.udl index 2c0d9550d6..6b1c2d304b 100644 --- a/fixtures/coverall/src/coverall.udl +++ b/fixtures/coverall/src/coverall.udl @@ -4,7 +4,7 @@ namespace coverall { u64 get_num_alive(); - sequence get_traits(); + sequence get_traits(); MaybeSimpleDict get_maybe_simple_dict(i8 index); @@ -17,6 +17,11 @@ namespace coverall { [Throws=CoverallRichErrorNoVariantData] void throw_rich_error_no_variant_data(); + + Getters make_rust_getters(); + void test_getters(Getters g); + + sequence ancestor_names(NodeTrait node); }; dictionary SimpleDict { @@ -41,7 +46,7 @@ dictionary SimpleDict { double float64; double? maybe_float64; Coveralls? coveralls; - TestTrait? test_trait; + NodeTrait? test_trait; }; dictionary DictWithDefaults { @@ -190,24 +195,37 @@ interface ThreadsafeCounter { i32 increment_if_busy(); }; -// This is a trait implemented on the Rust side. +// Test trait #1 +// +// The goal here is to test all possible arg, return, and error types. [Trait] -interface TestTrait { - string name(); // The name of the implementation +interface Getters { + boolean get_bool(boolean v, boolean arg2); + [Throws=CoverallError] + string get_string(string v, boolean arg2); + [Throws=ComplexError] + string? get_option(string v, boolean arg2); + sequence get_list(sequence v, boolean arg2); + void get_nothing(string v); +}; - [Self=ByArc] - u64 number(); +// Test trait #2 +// +// The goal here is test passing objects back and forth between Rust and the foreign side +[Trait] +interface NodeTrait { + string name(); // The name of the this node + + /// Takes an `Arc` and stores it our parent node, dropping any existing / reference. Note + //you can create circular references with this. + void set_parent(NodeTrait? parent); + + /// Returns what was previously set via `set_parent()`, or null. + NodeTrait? get_parent(); /// Calls `Arc::strong_count()` on the `Arc` containing `self`. [Self=ByArc] u64 strong_count(); - - /// Takes an `Arc` and stores it in `self`, dropping the existing - /// reference. Note you can create circular references by passing `self`. - void take_other(TestTrait? other); - - /// Returns what was previously set via `take_other()`, or null. - TestTrait? get_other(); }; // Forward/backward declarations are fine in UDL. diff --git a/fixtures/coverall/src/lib.rs b/fixtures/coverall/src/lib.rs index e591c0db79..82b7236cc9 100644 --- a/fixtures/coverall/src/lib.rs +++ b/fixtures/coverall/src/lib.rs @@ -10,11 +10,11 @@ use std::time::SystemTime; use once_cell::sync::Lazy; mod traits; -pub use traits::{get_traits, TestTrait}; +pub use traits::{ancestor_names, get_traits, make_rust_getters, test_getters, Getters, NodeTrait}; static NUM_ALIVE: Lazy> = Lazy::new(|| RwLock::new(0)); -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum CoverallError { #[error("The coverall has too many holes")] TooManyHoles, @@ -80,7 +80,7 @@ impl From for CoverallError { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum ComplexError { #[error("OsError: {code} ({extended_code})")] OsError { code: i16, extended_code: i16 }, @@ -131,7 +131,7 @@ pub struct SimpleDict { float64: f64, maybe_float64: Option, coveralls: Option>, - test_trait: Option>, + test_trait: Option>, } #[derive(Debug, Clone)] @@ -204,7 +204,7 @@ fn create_some_dict() -> SimpleDict { float64: 0.0, maybe_float64: Some(1.0), coveralls: Some(Arc::new(Coveralls::new("some_dict".to_string()))), - test_trait: Some(Arc::new(traits::Trait2 {})), + test_trait: Some(Arc::new(traits::Trait2::default())), } } diff --git a/fixtures/coverall/src/traits.rs b/fixtures/coverall/src/traits.rs index fe9694ef0e..15785ef0c6 100644 --- a/fixtures/coverall/src/traits.rs +++ b/fixtures/coverall/src/traits.rs @@ -2,73 +2,187 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +use super::{ComplexError, CoverallError}; use std::sync::{Arc, Mutex}; // namespace functions. -pub fn get_traits() -> Vec> { - vec![ - Arc::new(Trait1 { - ..Default::default() - }), - Arc::new(Trait2 {}), - ] +pub fn get_traits() -> Vec> { + vec![Arc::new(Trait1::default()), Arc::new(Trait2::default())] } -pub trait TestTrait: Send + Sync + std::fmt::Debug { +pub trait NodeTrait: Send + Sync + std::fmt::Debug { fn name(&self) -> String; - fn number(self: Arc) -> u64; + fn set_parent(&self, parent: Option>); + + fn get_parent(&self) -> Option>; fn strong_count(self: Arc) -> u64 { Arc::strong_count(&self) as u64 } +} - fn take_other(&self, other: Option>); +pub fn ancestor_names(node: Arc) -> Vec { + let mut names = vec![]; + let mut parent = node.get_parent(); + while let Some(node) = parent { + names.push(node.name()); + parent = node.get_parent(); + } + names +} - fn get_other(&self) -> Option>; +/// Test trait +/// +/// The goal here is to test all possible arg, return, and error types. +pub trait Getters: Send + Sync { + fn get_bool(&self, v: bool, arg2: bool) -> bool; + fn get_string(&self, v: String, arg2: bool) -> Result; + fn get_option(&self, v: String, arg2: bool) -> Result, ComplexError>; + fn get_list(&self, v: Vec, arg2: bool) -> Vec; + fn get_nothing(&self, v: String); +} + +struct RustGetters; + +impl Getters for RustGetters { + fn get_bool(&self, v: bool, arg2: bool) -> bool { + v ^ arg2 + } + + fn get_string(&self, v: String, arg2: bool) -> Result { + if v == "too-many-holes" { + Err(CoverallError::TooManyHoles) + } else if v == "unexpected-error" { + panic!("unexpected error") + } else if arg2 { + Ok(v.to_uppercase()) + } else { + Ok(v) + } + } + + fn get_option(&self, v: String, arg2: bool) -> Result, ComplexError> { + if v == "os-error" { + Err(ComplexError::OsError { + code: 100, + extended_code: 200, + }) + } else if v == "unknown-error" { + Err(ComplexError::UnknownError) + } else if arg2 { + if !v.is_empty() { + Ok(Some(v.to_uppercase())) + } else { + Ok(None) + } + } else { + Ok(Some(v)) + } + } + + fn get_list(&self, v: Vec, arg2: bool) -> Vec { + if arg2 { + v + } else { + vec![] + } + } + + fn get_nothing(&self, _v: String) {} +} + +pub fn make_rust_getters() -> Arc { + Arc::new(RustGetters) +} + +pub fn test_getters(getters: Arc) { + assert!(!getters.get_bool(true, true)); + assert!(getters.get_bool(true, false)); + assert!(getters.get_bool(false, true)); + assert!(!getters.get_bool(false, false)); + + assert_eq!( + getters.get_string("hello".to_owned(), false).unwrap(), + "hello" + ); + assert_eq!( + getters.get_string("hello".to_owned(), true).unwrap(), + "HELLO" + ); + + assert_eq!( + getters.get_option("hello".to_owned(), true).unwrap(), + Some("HELLO".to_owned()) + ); + assert_eq!( + getters.get_option("hello".to_owned(), false).unwrap(), + Some("hello".to_owned()) + ); + assert_eq!(getters.get_option("".to_owned(), true).unwrap(), None); + + assert_eq!(getters.get_list(vec![1, 2, 3], true), vec![1, 2, 3]); + assert_eq!(getters.get_list(vec![1, 2, 3], false), Vec::::new()); + + // Call get_nothing to make sure it doesn't panic. There's no point in checking the output + // though + getters.get_nothing("hello".to_owned()); + + assert_eq!( + getters.get_string("too-many-holes".to_owned(), true), + Err(CoverallError::TooManyHoles) + ); + assert_eq!( + getters.get_option("os-error".to_owned(), true), + Err(ComplexError::OsError { + code: 100, + extended_code: 200 + }) + ); + assert_eq!( + getters.get_option("unknown-error".to_owned(), true), + Err(ComplexError::UnknownError) + ); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + getters.get_string("unexpected-error".to_owned(), true) + })); + assert!(result.is_err()); } #[derive(Debug, Default)] pub(crate) struct Trait1 { // A reference to another trait. - other: Mutex>>, + parent: Mutex>>, } -impl TestTrait for Trait1 { +impl NodeTrait for Trait1 { fn name(&self) -> String { - "trait 1".to_string() + "node-1".to_string() } - fn number(self: Arc) -> u64 { - 1_u64 + fn set_parent(&self, parent: Option>) { + *self.parent.lock().unwrap() = parent.map(|arc| Arc::clone(&arc)) } - fn take_other(&self, other: Option>) { - *self.other.lock().unwrap() = other.map(|arc| Arc::clone(&arc)) - } - - fn get_other(&self) -> Option> { - (*self.other.lock().unwrap()).as_ref().map(Arc::clone) + fn get_parent(&self) -> Option> { + (*self.parent.lock().unwrap()).as_ref().map(Arc::clone) } } -#[derive(Debug)] -pub(crate) struct Trait2 {} -impl TestTrait for Trait2 { +#[derive(Debug, Default)] +pub(crate) struct Trait2 { + parent: Mutex>>, +} +impl NodeTrait for Trait2 { fn name(&self) -> String { - "trait 2".to_string() - } - - fn number(self: Arc) -> u64 { - 2_u64 + "node-2".to_string() } - // Don't bother implementing these here - the test on the struct above is ok. - fn take_other(&self, _other: Option>) { - unimplemented!(); + fn set_parent(&self, parent: Option>) { + *self.parent.lock().unwrap() = parent.map(|arc| Arc::clone(&arc)) } - fn get_other(&self) -> Option> { - unimplemented!() + fn get_parent(&self) -> Option> { + (*self.parent.lock().unwrap()).as_ref().map(Arc::clone) } } diff --git a/fixtures/coverall/tests/bindings/test_coverall.kts b/fixtures/coverall/tests/bindings/test_coverall.kts index 62fbbb5a70..3cefa65783 100644 --- a/fixtures/coverall/tests/bindings/test_coverall.kts +++ b/fixtures/coverall/tests/bindings/test_coverall.kts @@ -210,8 +210,79 @@ Coveralls("test_interfaces_in_dicts").use { coveralls -> assert(coveralls.getRepairs().size == 2) } -Coveralls("test_regressions").use { coveralls -> - assert(coveralls.getStatus("success") == "status: success") +// Test traits implemented in Rust +makeRustGetters().let { rustGetters -> + testGetters(rustGetters) + testGettersFromKotlin(rustGetters) +} + +fun testGettersFromKotlin(getters: Getters) { + assert(getters.getBool(true, true) == false); + assert(getters.getBool(true, false) == true); + assert(getters.getBool(false, true) == true); + assert(getters.getBool(false, false) == false); + + assert(getters.getString("hello", false) == "hello"); + assert(getters.getString("hello", true) == "HELLO"); + + assert(getters.getOption("hello", true) == "HELLO"); + assert(getters.getOption("hello", false) == "hello"); + assert(getters.getOption("", true) == null); + + assert(getters.getList(listOf(1, 2, 3), true) == listOf(1, 2, 3)) + assert(getters.getList(listOf(1, 2, 3), false) == listOf()) + + assert(getters.getNothing("hello") == Unit); + + try { + getters.getString("too-many-holes", true) + throw RuntimeException("Expected method to throw exception") + } catch(e: CoverallException.TooManyHoles) { + // Expected + } + + try { + getters.getOption("os-error", true) + throw RuntimeException("Expected method to throw exception") + } catch(e: ComplexException.OsException) { + assert(e.code.toInt() == 100) + assert(e.extendedCode.toInt() == 200) + } + + try { + getters.getOption("unknown-error", true) + throw RuntimeException("Expected method to throw exception") + } catch(e: ComplexException.UnknownException) { + // Expected + } + + try { + getters.getString("unexpected-error", true) + } catch(e: InternalException) { + // Expected + } +} + +// Test NodeTrait +getTraits().let { traits -> + assert(traits[0].name() == "node-1") + // Note: strong counts are 1 more than you might expect, because the strongCount() method + // holds a strong ref. + assert(traits[0].strongCount() == 2UL) + + assert(traits[1].name() == "node-2") + assert(traits[1].strongCount() == 2UL) + + traits[0].setParent(traits[1]) + assert(ancestorNames(traits[0]) == listOf("node-2")) + assert(ancestorNames(traits[1]).isEmpty()) + assert(traits[1].strongCount() == 3UL) + assert(traits[0].getParent()!!.name() == "node-2") + traits[0].setParent(null) + + Coveralls("test_regressions").use { coveralls -> + assert(coveralls.getStatus("success") == "status: success") + } } // This tests that the UniFFI-generated scaffolding doesn't introduce any unexpected locking. diff --git a/fixtures/coverall/tests/bindings/test_coverall.py b/fixtures/coverall/tests/bindings/test_coverall.py index a63a2e80b9..4f9d1e2df0 100644 --- a/fixtures/coverall/tests/bindings/test_coverall.py +++ b/fixtures/coverall/tests/bindings/test_coverall.py @@ -31,7 +31,7 @@ def test_some_dict(self): self.assertEqual(d.signed64, 9223372036854775807) self.assertEqual(d.maybe_signed64, 0) self.assertEqual(d.coveralls.get_name(), "some_dict") - self.assertEqual(d.test_trait.name(), "trait 2") + self.assertEqual(d.test_trait.name(), "node-2") # floats should be "close enough" - although it's mildly surprising that # we need to specify `places=6` whereas the default is 7. @@ -279,20 +279,59 @@ def test_bytes(self): self.assertEqual(coveralls.reverse(b"123"), b"321") class TraitsTest(unittest.TestCase): - def test_simple(self): + # Test traits implemented in Rust + def test_rust_getters(self): + test_getters(make_rust_getters()) + self.check_getters_from_python(make_rust_getters()) + + def check_getters_from_python(self, getters): + self.assertEqual(getters.get_bool(True, True), False); + self.assertEqual(getters.get_bool(True, False), True); + self.assertEqual(getters.get_bool(False, True), True); + self.assertEqual(getters.get_bool(False, False), False); + + self.assertEqual(getters.get_string("hello", False), "hello"); + self.assertEqual(getters.get_string("hello", True), "HELLO"); + + self.assertEqual(getters.get_option("hello", True), "HELLO"); + self.assertEqual(getters.get_option("hello", False), "hello"); + self.assertEqual(getters.get_option("", True), None); + + self.assertEqual(getters.get_list([1, 2, 3], True), [1, 2, 3]); + self.assertEqual(getters.get_list([1, 2, 3], False), []) + + self.assertEqual(getters.get_nothing("hello"), None); + + with self.assertRaises(CoverallError.TooManyHoles): + getters.get_string("too-many-holes", True) + + with self.assertRaises(ComplexError.OsError) as cm: + getters.get_option("os-error", True) + self.assertEqual(cm.exception.code, 100) + self.assertEqual(cm.exception.extended_code, 200) + + with self.assertRaises(ComplexError.UnknownError): + getters.get_option("unknown-error", True) + + with self.assertRaises(InternalError): + getters.get_string("unexpected-error", True) + + def test_node(self): traits = get_traits() - self.assertEqual(traits[0].name(), "trait 1") - self.assertEqual(traits[0].number(), 1) + self.assertEqual(traits[0].name(), "node-1") + # Note: strong counts are 1 more than you might expect, because the strong_count() method + # holds a strong ref. self.assertEqual(traits[0].strong_count(), 2) - self.assertEqual(traits[1].name(), "trait 2") - self.assertEqual(traits[1].number(), 2) + self.assertEqual(traits[1].name(), "node-2") self.assertEqual(traits[1].strong_count(), 2) - traits[0].take_other(traits[1]) + traits[0].set_parent(traits[1]) + self.assertEqual(ancestor_names(traits[0]), ["node-2"]) + self.assertEqual(ancestor_names(traits[1]), []) self.assertEqual(traits[1].strong_count(), 3) - self.assertEqual(traits[0].get_other().name(), "trait 2") - traits[0].take_other(None) + self.assertEqual(traits[0].get_parent().name(), "node-2") + traits[0].set_parent(None) if __name__=='__main__': unittest.main() diff --git a/fixtures/coverall/tests/bindings/test_coverall.swift b/fixtures/coverall/tests/bindings/test_coverall.swift index d380881184..8db21ad78a 100644 --- a/fixtures/coverall/tests/bindings/test_coverall.swift +++ b/fixtures/coverall/tests/bindings/test_coverall.swift @@ -247,3 +247,82 @@ do { let coveralls = Coveralls(name: "test_bytes") assert(coveralls.reverse(value: Data("123".utf8)) == Data("321".utf8)) } + +// Test traits implemented in Rust +do { + let rustGetters = makeRustGetters() + testGetters(g: rustGetters) + testGettersFromSwift(getters: rustGetters) +} + +func testGettersFromSwift(getters: Getters) { + assert(getters.getBool(v: true, arg2: true) == false); + assert(getters.getBool(v: true, arg2: false) == true); + assert(getters.getBool(v: false, arg2: true) == true); + assert(getters.getBool(v: false, arg2: false) == false); + + assert(try! getters.getString(v: "hello", arg2: false) == "hello"); + assert(try! getters.getString(v: "hello", arg2: true) == "HELLO"); + + assert(try! getters.getOption(v: "hello", arg2: true) == "HELLO"); + assert(try! getters.getOption(v: "hello", arg2: false) == "hello"); + assert(try! getters.getOption(v: "", arg2: true) == nil); + + assert(getters.getList(v: [1, 2, 3], arg2: true) == [1, 2, 3]) + assert(getters.getList(v: [1, 2, 3], arg2: false) == []) + + assert(getters.getNothing(v: "hello") == ()); + + do { + let _ = try getters.getString(v: "too-many-holes", arg2: true) + fatalError("should have thrown") + } catch CoverallError.TooManyHoles { + // Expected + } catch { + fatalError("Unexpected error: \(error)") + } + + do { + try getters.getOption(v: "os-error", arg2: true) + fatalError("should have thrown") + } catch ComplexError.OsError(let code, let extendedCode) { + assert(code == 100) + assert(extendedCode == 200) + } catch { + fatalError("Unexpected error: \(error)") + } + + do { + try getters.getOption(v: "unknown-error", arg2: true) + fatalError("should have thrown") + } catch ComplexError.UnknownError { + // Expected + } catch { + fatalError("Unexpected error: \(error)") + } + + do { + try getters.getString(v: "unexpected-error", arg2: true) + } catch { + // Expected + } +} + +// Test Node trait +do { + let traits = getTraits() + assert(traits[0].name() == "node-1") + // Note: strong counts are 1 more than you might expect, because the strongCount() method + // holds a strong ref. + assert(traits[0].strongCount() == 2) + + assert(traits[1].name() == "node-2") + assert(traits[1].strongCount() == 2) + + traits[0].setParent(parent: traits[1]) + assert(ancestorNames(node: traits[0]) == ["node-2"]) + assert(ancestorNames(node: traits[1]) == []) + assert(traits[1].strongCount() == 3) + assert(traits[0].getParent()!.name() == "node-2") + traits[0].setParent(parent: nil) +} diff --git a/uniffi_bindgen/src/scaffolding/templates/ObjectTemplate.rs b/uniffi_bindgen/src/scaffolding/templates/ObjectTemplate.rs index 3346d4fbfb..e2445c670d 100644 --- a/uniffi_bindgen/src/scaffolding/templates/ObjectTemplate.rs +++ b/uniffi_bindgen/src/scaffolding/templates/ObjectTemplate.rs @@ -21,9 +21,11 @@ pub trait r#{{ obj.name() }} { {{ arg.name() }}: {% if arg.by_ref() %}&{% endif %}{{ arg.as_type().borrow()|type_rs }}, {%- endfor %} ) - {%- match meth.return_type() %} - {%- when Some(return_type) %} -> {{ return_type|type_rs }}; - {%- when None %}; + {%- match (meth.return_type(), meth.throws_type()) %} + {%- when (Some(return_type), None) %} -> {{ return_type|type_rs }}; + {%- when (Some(return_type), Some(error_type)) %} -> ::std::result::Result::<{{ return_type|type_rs }}, {{ error_type|type_rs }}>; + {%- when (None, Some(error_type)) %} -> ::std::result::Result::<(), {{ error_type|type_rs }}>; + {%- when (None, None) %}; {%- endmatch %} {% endfor %} } diff --git a/uniffi_macros/src/export.rs b/uniffi_macros/src/export.rs index bbb16acf90..bc7f8b40ea 100644 --- a/uniffi_macros/src/export.rs +++ b/uniffi_macros/src/export.rs @@ -10,6 +10,7 @@ mod attributes; mod callback_interface; mod item; mod scaffolding; +mod trait_interface; mod utrait; use self::{ @@ -18,13 +19,9 @@ use self::{ gen_constructor_scaffolding, gen_ffi_function, gen_fn_scaffolding, gen_method_scaffolding, }, }; -use crate::{ - object::interface_meta_static_var, - util::{ident_to_string, mod_path, tagged_impl_header}, -}; +use crate::util::{ident_to_string, mod_path}; pub use attributes::ExportAttributeArguments; pub use callback_interface::ffi_converter_callback_interface_impl; -use uniffi_meta::free_fn_symbol_name; // TODO(jplatte): Ensure no generics, … // TODO(jplatte): Aggregate errors instead of short-circuiting, wherever possible @@ -71,59 +68,7 @@ pub(crate) fn expand_export( items, self_ident, callback_interface: false, - } => { - if let Some(rt) = args.async_runtime { - return Err(syn::Error::new_spanned(rt, "not supported for traits")); - } - - let name = ident_to_string(&self_ident); - let free_fn_ident = - Ident::new(&free_fn_symbol_name(&mod_path, &name), Span::call_site()); - - let free_tokens = quote! { - #[doc(hidden)] - #[no_mangle] - pub extern "C" fn #free_fn_ident( - ptr: *const ::std::ffi::c_void, - call_status: &mut ::uniffi::RustCallStatus - ) { - uniffi::rust_call(call_status, || { - assert!(!ptr.is_null()); - drop(unsafe { ::std::boxed::Box::from_raw(ptr as *mut std::sync::Arc) }); - Ok(()) - }); - } - }; - - let impl_tokens: TokenStream = items - .into_iter() - .map(|item| match item { - ImplItem::Method(sig) => { - if sig.is_async { - return Err(syn::Error::new( - sig.span, - "async trait methods are not supported", - )); - } - gen_method_scaffolding(sig, &args, udl_mode) - } - _ => unreachable!("traits have no constructors"), - }) - .collect::>()?; - - let meta_static_var = (!udl_mode).then(|| { - interface_meta_static_var(&self_ident, true, &mod_path) - .unwrap_or_else(syn::Error::into_compile_error) - }); - let ffi_converter_tokens = ffi_converter_trait_impl(&self_ident, false); - - Ok(quote_spanned! { self_ident.span() => - #meta_static_var - #free_tokens - #ffi_converter_tokens - #impl_tokens - }) - } + } => trait_interface::gen_trait_scaffolding(&mod_path, args, self_ident, items, udl_mode), ExportItem::Trait { items, self_ident, @@ -182,62 +127,6 @@ pub(crate) fn expand_export( } } -pub(crate) fn ffi_converter_trait_impl(trait_ident: &Ident, udl_mode: bool) -> TokenStream { - let impl_spec = tagged_impl_header("FfiConverterArc", "e! { dyn #trait_ident }, udl_mode); - let lift_ref_impl_spec = tagged_impl_header("LiftRef", "e! { dyn #trait_ident }, udl_mode); - let name = ident_to_string(trait_ident); - let mod_path = match mod_path() { - Ok(p) => p, - Err(e) => return e.into_compile_error(), - }; - - quote! { - // All traits must be `Sync + Send`. The generated scaffolding will fail to compile - // if they are not, but unfortunately it fails with an unactionably obscure error message. - // By asserting the requirement explicitly, we help Rust produce a more scrutable error message - // and thus help the user debug why the requirement isn't being met. - uniffi::deps::static_assertions::assert_impl_all!(dyn #trait_ident: Sync, Send); - - unsafe #impl_spec { - type FfiType = *const ::std::os::raw::c_void; - - fn lower(obj: ::std::sync::Arc) -> Self::FfiType { - ::std::boxed::Box::into_raw(::std::boxed::Box::new(obj)) as *const ::std::os::raw::c_void - } - - fn try_lift(v: Self::FfiType) -> ::uniffi::Result<::std::sync::Arc> { - let foreign_arc = ::std::boxed::Box::leak(unsafe { Box::from_raw(v as *mut ::std::sync::Arc) }); - // Take a clone for our own use. - Ok(::std::sync::Arc::clone(foreign_arc)) - } - - fn write(obj: ::std::sync::Arc, buf: &mut Vec) { - ::uniffi::deps::static_assertions::const_assert!(::std::mem::size_of::<*const ::std::ffi::c_void>() <= 8); - ::uniffi::deps::bytes::BufMut::put_u64( - buf, - >::lower(obj) as u64, - ); - } - - fn try_read(buf: &mut &[u8]) -> ::uniffi::Result<::std::sync::Arc> { - ::uniffi::deps::static_assertions::const_assert!(::std::mem::size_of::<*const ::std::ffi::c_void>() <= 8); - ::uniffi::check_remaining(buf, 8)?; - >::try_lift( - ::uniffi::deps::bytes::Buf::get_u64(buf) as Self::FfiType) - } - - const TYPE_ID_META: ::uniffi::MetadataBuffer = ::uniffi::MetadataBuffer::from_code(::uniffi::metadata::codes::TYPE_INTERFACE) - .concat_str(#mod_path) - .concat_str(#name) - .concat_bool(true); - } - - unsafe #lift_ref_impl_spec { - type LiftType = ::std::sync::Arc; - } - } -} - /// Rewrite Self type alias usage in an impl block to the type itself. /// /// For example, diff --git a/uniffi_macros/src/export/trait_interface.rs b/uniffi_macros/src/export/trait_interface.rs new file mode 100644 index 0000000000..e6cdaff5c9 --- /dev/null +++ b/uniffi_macros/src/export/trait_interface.rs @@ -0,0 +1,124 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{quote, quote_spanned}; + +use crate::{ + export::{attributes::ExportAttributeArguments, gen_method_scaffolding, item::ImplItem}, + object::interface_meta_static_var, + util::{ident_to_string, tagged_impl_header}, +}; +use uniffi_meta::free_fn_symbol_name; + +pub(super) fn gen_trait_scaffolding( + mod_path: &str, + args: ExportAttributeArguments, + self_ident: Ident, + items: Vec, + udl_mode: bool, +) -> syn::Result { + if let Some(rt) = args.async_runtime { + return Err(syn::Error::new_spanned(rt, "not supported for traits")); + } + + let name = ident_to_string(&self_ident); + let free_fn_ident = Ident::new(&free_fn_symbol_name(mod_path, &name), Span::call_site()); + + let free_tokens = quote! { + #[doc(hidden)] + #[no_mangle] + pub extern "C" fn #free_fn_ident( + ptr: *const ::std::ffi::c_void, + call_status: &mut ::uniffi::RustCallStatus + ) { + uniffi::rust_call(call_status, || { + assert!(!ptr.is_null()); + drop(unsafe { ::std::boxed::Box::from_raw(ptr as *mut std::sync::Arc) }); + Ok(()) + }); + } + }; + + let impl_tokens: TokenStream = items + .into_iter() + .map(|item| match item { + ImplItem::Method(sig) => { + if sig.is_async { + return Err(syn::Error::new( + sig.span, + "async trait methods are not supported", + )); + } + gen_method_scaffolding(sig, &args, udl_mode) + } + _ => unreachable!("traits have no constructors"), + }) + .collect::>()?; + + let meta_static_var = (!udl_mode).then(|| { + interface_meta_static_var(&self_ident, true, mod_path) + .unwrap_or_else(syn::Error::into_compile_error) + }); + let ffi_converter_tokens = ffi_converter(mod_path, &self_ident, false); + + Ok(quote_spanned! { self_ident.span() => + #meta_static_var + #free_tokens + #ffi_converter_tokens + #impl_tokens + }) +} + +pub(crate) fn ffi_converter(mod_path: &str, trait_ident: &Ident, udl_mode: bool) -> TokenStream { + let impl_spec = tagged_impl_header("FfiConverterArc", "e! { dyn #trait_ident }, udl_mode); + let lift_ref_impl_spec = tagged_impl_header("LiftRef", "e! { dyn #trait_ident }, udl_mode); + let name = ident_to_string(trait_ident); + + quote! { + // All traits must be `Sync + Send`. The generated scaffolding will fail to compile + // if they are not, but unfortunately it fails with an unactionably obscure error message. + // By asserting the requirement explicitly, we help Rust produce a more scrutable error message + // and thus help the user debug why the requirement isn't being met. + uniffi::deps::static_assertions::assert_impl_all!(dyn #trait_ident: Sync, Send); + + unsafe #impl_spec { + type FfiType = *const ::std::os::raw::c_void; + + fn lower(obj: ::std::sync::Arc) -> Self::FfiType { + ::std::boxed::Box::into_raw(::std::boxed::Box::new(obj)) as *const ::std::os::raw::c_void + } + + fn try_lift(v: Self::FfiType) -> ::uniffi::Result<::std::sync::Arc> { + let foreign_arc = ::std::boxed::Box::leak(unsafe { Box::from_raw(v as *mut ::std::sync::Arc) }); + // Take a clone for our own use. + Ok(::std::sync::Arc::clone(foreign_arc)) + } + + fn write(obj: ::std::sync::Arc, buf: &mut Vec) { + ::uniffi::deps::static_assertions::const_assert!(::std::mem::size_of::<*const ::std::ffi::c_void>() <= 8); + ::uniffi::deps::bytes::BufMut::put_u64( + buf, + >::lower(obj) as u64, + ); + } + + fn try_read(buf: &mut &[u8]) -> ::uniffi::Result<::std::sync::Arc> { + ::uniffi::deps::static_assertions::const_assert!(::std::mem::size_of::<*const ::std::ffi::c_void>() <= 8); + ::uniffi::check_remaining(buf, 8)?; + >::try_lift( + ::uniffi::deps::bytes::Buf::get_u64(buf) as Self::FfiType) + } + + const TYPE_ID_META: ::uniffi::MetadataBuffer = ::uniffi::MetadataBuffer::from_code(::uniffi::metadata::codes::TYPE_INTERFACE) + .concat_str(#mod_path) + .concat_str(#name) + .concat_bool(true); + } + + unsafe #lift_ref_impl_spec { + type LiftType = ::std::sync::Arc; + } + } +} diff --git a/uniffi_macros/src/lib.rs b/uniffi_macros/src/lib.rs index 4cffddfa0e..b7ba86ddc1 100644 --- a/uniffi_macros/src/lib.rs +++ b/uniffi_macros/src/lib.rs @@ -257,14 +257,6 @@ pub fn export_for_udl(attrs: TokenStream, input: TokenStream) -> TokenStream { do_export(attrs, input, true) } -/// Generate various support elements, including the FfiConverter implementation, -/// for a trait interface for the scaffolding code -#[doc(hidden)] -#[proc_macro] -pub fn expand_trait_interface_support(tokens: TokenStream) -> TokenStream { - export::ffi_converter_trait_impl(&syn::parse_macro_input!(tokens), true).into() -} - /// Generate the FfiConverter implementation for an trait interface for the scaffolding code #[doc(hidden)] #[proc_macro]