diff --git a/guide/src/module.md b/guide/src/module.md index c9c7f78aaf5..057e27348a4 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -152,8 +152,38 @@ mod my_extension { # } ``` +The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pyclass]` macros declared inside of it with its name. +For nested modules, the name of the parent module is automatically added. +In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested +but the `Ext` class will have for `module` the default `builtins` because it not nested. +```rust +# #[cfg(feature = "experimental-declarative-modules")] +# mod declarative_module_module_attr_test { +use pyo3::prelude::*; + +#[pyclass] +struct Ext; + + #[pymodule] +mod my_extension { + use super::*; + + #[pymodule_export] + use super::Ext; + + #[pymodule] + mod submodule { + // This is a submodule + + #[pyclass] // This will be part of the module + struct Unit; + } +} +# } +``` +It is possible to customize the `module` value for a `#[pymodule]` with the `#[pyo3(module = "MY_MODULE")]` option. + Some changes are planned to this feature before stabilization, like automatically -filling submodules into `sys.modules` to allow easier imports (see [issue #759](https://github.com/PyO3/pyo3/issues/759)) -and filling the `module` argument of inlined `#[pyclass]` automatically with the proper module name. +filling submodules into `sys.modules` to allow easier imports (see [issue #759](https://github.com/PyO3/pyo3/issues/759)). Macro names might also change. See [issue #3900](https://github.com/PyO3/pyo3/issues/3900) to track this feature progress. diff --git a/newsfragments/4213.added.md b/newsfragments/4213.added.md new file mode 100644 index 00000000000..6f553dc93ab --- /dev/null +++ b/newsfragments/4213.added.md @@ -0,0 +1 @@ +Properly fills the `module=` attribute of declarative modules child `#[pymodule]` and `#[pyclass]`. \ No newline at end of file diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 756037263e3..71d776bf350 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -1,10 +1,13 @@ //! Code generation for the function that initializes a python module and adds classes and function. -use crate::utils::Ctx; use crate::{ - attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute}, + attributes::{ + self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute, + }, get_doc, + pyclass::PyClassPyO3Option, pyfunction::{impl_wrap_pyfunction, PyFunctionOptions}, + utils::Ctx, }; use proc_macro2::TokenStream; use quote::quote; @@ -12,15 +15,17 @@ use syn::{ ext::IdentExt, parse::{Parse, ParseStream}, parse_quote, parse_quote_spanned, + punctuated::Punctuated, spanned::Spanned, token::Comma, - Item, Path, Result, + Item, Meta, Path, Result, }; #[derive(Default)] pub struct PyModuleOptions { krate: Option, name: Option, + module: Option, } impl PyModuleOptions { @@ -31,6 +36,7 @@ impl PyModuleOptions { match option { PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?, PyModulePyO3Option::Crate(path) => options.set_crate(path)?, + PyModulePyO3Option::Module(module) => options.set_module(module)?, } } @@ -56,6 +62,16 @@ impl PyModuleOptions { self.krate = Some(path); Ok(()) } + + fn set_module(&mut self, name: ModuleAttribute) -> Result<()> { + ensure_spanned!( + self.module.is_none(), + name.span() => "`module` may only be specified once" + ); + + self.module = Some(name); + Ok(()) + } } pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { @@ -77,6 +93,12 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { let ctx = &Ctx::new(&options.krate); let Ctx { pyo3_path } = ctx; let doc = get_doc(attrs, None); + let name = options.name.unwrap_or_else(|| ident.unraw()); + let full_name = if let Some(module) = &options.module { + format!("{}.{}", module.value.value(), name) + } else { + name.to_string() + }; let mut module_items = Vec::new(); let mut module_items_cfg_attrs = Vec::new(); @@ -156,6 +178,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { if has_attribute(&item_struct.attrs, "pyclass") { module_items.push(item_struct.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs)); + if !has_pyo3_module_declared::( + &item_struct.attrs, + "pyclass", + |option| matches!(option, PyClassPyO3Option::Module(_)), + )? { + set_module_attribute(&mut item_struct.attrs, &full_name); + } } } Item::Enum(item_enum) => { @@ -166,6 +195,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { if has_attribute(&item_enum.attrs, "pyclass") { module_items.push(item_enum.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs)); + if !has_pyo3_module_declared::( + &item_enum.attrs, + "pyclass", + |option| matches!(option, PyClassPyO3Option::Module(_)), + )? { + set_module_attribute(&mut item_enum.attrs, &full_name); + } } } Item::Mod(item_mod) => { @@ -176,6 +212,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { if has_attribute(&item_mod.attrs, "pymodule") { module_items.push(item_mod.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs)); + if !has_pyo3_module_declared::( + &item_mod.attrs, + "pymodule", + |option| matches!(option, PyModulePyO3Option::Module(_)), + )? { + set_module_attribute(&mut item_mod.attrs, &full_name); + } } } Item::ForeignMod(item) => { @@ -242,7 +285,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } } - let initialization = module_initialization(options, ident); + let initialization = module_initialization(&name, ctx); Ok(quote!( #vis mod #ident { #(#items)* @@ -286,10 +329,11 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result let stmts = std::mem::take(&mut function.block.stmts); let Ctx { pyo3_path } = ctx; let ident = &function.sig.ident; + let name = options.name.unwrap_or_else(|| ident.unraw()); let vis = &function.vis; let doc = get_doc(&function.attrs, None); - let initialization = module_initialization(options, ident); + let initialization = module_initialization(&name, ctx); // Module function called with optional Python<'_> marker as first arg, followed by the module. let mut module_args = Vec::new(); @@ -354,9 +398,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result }) } -fn module_initialization(options: PyModuleOptions, ident: &syn::Ident) -> TokenStream { - let name = options.name.unwrap_or_else(|| ident.unraw()); - let ctx = &Ctx::new(&options.krate); +fn module_initialization(name: &syn::Ident, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path } = ctx; let pyinit_symbol = format!("PyInit_{}", name); @@ -491,9 +533,33 @@ fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { attrs.iter().any(|attr| attr.path().is_ident(ident)) } +fn set_module_attribute(attrs: &mut Vec, module_name: &str) { + attrs.push(parse_quote!(#[pyo3(module = #module_name)])); +} + +fn has_pyo3_module_declared( + attrs: &[syn::Attribute], + root_attribute_name: &str, + is_module_option: impl Fn(&T) -> bool + Copy, +) -> Result { + for attr in attrs { + if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name)) + && matches!(attr.meta, Meta::List(_)) + { + for option in &attr.parse_args_with(Punctuated::::parse_terminated)? { + if is_module_option(option) { + return Ok(true); + } + } + } + } + Ok(false) +} + enum PyModulePyO3Option { Crate(CrateAttribute), Name(NameAttribute), + Module(ModuleAttribute), } impl Parse for PyModulePyO3Option { @@ -503,6 +569,8 @@ impl Parse for PyModulePyO3Option { input.parse().map(PyModulePyO3Option::Name) } else if lookahead.peek(syn::Token![crate]) { input.parse().map(PyModulePyO3Option::Crate) + } else if lookahead.peek(attributes::kw::module) { + input.parse().map(PyModulePyO3Option::Module) } else { Err(lookahead.error()) } diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 1e7f29d84c1..717fdfb3dea 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -78,7 +78,7 @@ pub struct PyClassPyO3Options { pub weakref: Option, } -enum PyClassPyO3Option { +pub enum PyClassPyO3Option { Crate(CrateAttribute), Dict(kw::dict), Eq(kw::eq), diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 820cf63806d..0858f84e04a 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -3,6 +3,7 @@ use pyo3::create_exception; use pyo3::exceptions::PyException; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; #[cfg(not(Py_LIMITED_API))] use pyo3::types::PyBool; @@ -78,7 +79,7 @@ mod declarative_module { x * 3 } - #[pyclass] + #[pyclass(name = "Struct")] struct Struct; #[pymethods] @@ -89,12 +90,31 @@ mod declarative_module { } } - #[pyclass(eq, eq_int)] + #[pyclass(module = "foo")] + struct StructInCustomModule; + + #[pyclass(eq, eq_int, name = "Enum")] #[derive(PartialEq)] enum Enum { A, B, } + + #[pyclass(eq, eq_int, module = "foo")] + #[derive(PartialEq)] + enum EnumInCustomModule { + A, + B, + } + } + + #[pymodule] + #[pyo3(module = "custom_root")] + mod inner_custom_root { + use super::*; + + #[pyclass] + struct Struct; } #[pymodule_init] @@ -121,10 +141,17 @@ mod declarative_module2 { use super::double; } +fn declarative_module(py: Python<'_>) -> &Bound<'_, PyModule> { + static MODULE: GILOnceCell> = GILOnceCell::new(); + MODULE + .get_or_init(py, || pyo3::wrap_pymodule!(declarative_module)(py)) + .bind(py) +} + #[test] fn test_declarative_module() { Python::with_gil(|py| { - let m = pyo3::wrap_pymodule!(declarative_module)(py).into_bound(py); + let m = declarative_module(py); py_assert!( py, m, @@ -188,3 +215,27 @@ fn test_raw_ident_module() { py_assert!(py, m, "m.double(2) == 4"); }) } + +#[test] +fn test_module_names() { + Python::with_gil(|py| { + let m = declarative_module(py); + py_assert!( + py, + m, + "m.inner.Struct.__module__ == 'declarative_module.inner'" + ); + py_assert!(py, m, "m.inner.StructInCustomModule.__module__ == 'foo'"); + py_assert!( + py, + m, + "m.inner.Enum.__module__ == 'declarative_module.inner'" + ); + py_assert!(py, m, "m.inner.EnumInCustomModule.__module__ == 'foo'"); + py_assert!( + py, + m, + "m.inner_custom_root.Struct.__module__ == 'custom_root.inner_custom_root'" + ); + }) +}