Skip to content

Commit

Permalink
Automated module= field creation in declarative modules
Browse files Browse the repository at this point in the history
Sets automatically the "module" field of all contained classes and submodules in a declarative module

Adds the "module" field to pymodule attributes in order to set the name of the parent modules. By default, the module is assumed to be a root module
  • Loading branch information
Tpt committed Jun 5, 2024
1 parent 93ef056 commit 9c9a0a4
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 14 deletions.
34 changes: 32 additions & 2 deletions guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,38 @@ mod my_extension {
# }
```

The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pyclass]` macros declared inside of it with its name.
For nested modules, the name of the parent module is automatically added.
In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested
but the `Ext` class will have for `module` the default `builtins` because it not nested.
```rust
# #[cfg(feature = "experimental-declarative-modules")]
# mod declarative_module_module_attr_test {
use pyo3::prelude::*;

#[pyclass]
struct Ext;

#[pymodule]
mod my_extension {
use super::*;

#[pymodule_export]
use super::Ext;

#[pymodule]
mod submodule {
// This is a submodule

#[pyclass] // This will be part of the module
struct Unit;
}
}
# }
```
It is possible to customize the `module` value for a `#[pymodule]` with the `#[pyo3(module = "MY_MODULE")]` option.

Some changes are planned to this feature before stabilization, like automatically
filling submodules into `sys.modules` to allow easier imports (see [issue #759](https://github.com/PyO3/pyo3/issues/759))
and filling the `module` argument of inlined `#[pyclass]` automatically with the proper module name.
filling submodules into `sys.modules` to allow easier imports (see [issue #759](https://github.com/PyO3/pyo3/issues/759)).
Macro names might also change.
See [issue #3900](https://github.com/PyO3/pyo3/issues/3900) to track this feature progress.
1 change: 1 addition & 0 deletions newsfragments/4213.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Properly fills the `module=` attribute of declarative modules child `#[pymodule]` and `#[pyclass]`.
84 changes: 76 additions & 8 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
//! Code generation for the function that initializes a python module and adds classes and function.
use crate::utils::Ctx;
use crate::{
attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute},
attributes::{
self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute,
},
get_doc,
pyclass::PyClassPyO3Option,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
utils::Ctx,
};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
parse_quote, parse_quote_spanned,
punctuated::Punctuated,
spanned::Spanned,
token::Comma,
Item, Path, Result,
Item, Meta, Path, Result,
};

#[derive(Default)]
pub struct PyModuleOptions {
krate: Option<CrateAttribute>,
name: Option<syn::Ident>,
module: Option<ModuleAttribute>,
}

impl PyModuleOptions {
Expand All @@ -31,6 +36,7 @@ impl PyModuleOptions {
match option {
PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?,
PyModulePyO3Option::Crate(path) => options.set_crate(path)?,
PyModulePyO3Option::Module(module) => options.set_module(module)?,
}
}

Expand All @@ -56,6 +62,16 @@ impl PyModuleOptions {
self.krate = Some(path);
Ok(())
}

fn set_module(&mut self, name: ModuleAttribute) -> Result<()> {
ensure_spanned!(
self.module.is_none(),
name.span() => "`module` may only be specified once"
);

self.module = Some(name);
Ok(())
}
}

pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
Expand All @@ -77,6 +93,12 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
let ctx = &Ctx::new(&options.krate);
let Ctx { pyo3_path } = ctx;
let doc = get_doc(attrs, None);
let name = options.name.unwrap_or_else(|| ident.unraw());
let full_name = if let Some(module) = &options.module {
format!("{}.{}", module.value.value(), name)
} else {
name.to_string()
};

let mut module_items = Vec::new();
let mut module_items_cfg_attrs = Vec::new();
Expand Down Expand Up @@ -156,6 +178,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
if has_attribute(&item_struct.attrs, "pyclass") {
module_items.push(item_struct.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
if !has_pyo3_module_declared::<PyClassPyO3Option>(
&item_struct.attrs,
"pyclass",
|option| matches!(option, PyClassPyO3Option::Module(_)),
)? {
set_module_attribute(&mut item_struct.attrs, &full_name);
}
}
}
Item::Enum(item_enum) => {
Expand All @@ -166,6 +195,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
if has_attribute(&item_enum.attrs, "pyclass") {
module_items.push(item_enum.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
if !has_pyo3_module_declared::<PyClassPyO3Option>(
&item_enum.attrs,
"pyclass",
|option| matches!(option, PyClassPyO3Option::Module(_)),
)? {
set_module_attribute(&mut item_enum.attrs, &full_name);
}
}
}
Item::Mod(item_mod) => {
Expand All @@ -176,6 +212,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
if has_attribute(&item_mod.attrs, "pymodule") {
module_items.push(item_mod.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
if !has_pyo3_module_declared::<PyModulePyO3Option>(
&item_mod.attrs,
"pymodule",
|option| matches!(option, PyModulePyO3Option::Module(_)),
)? {
set_module_attribute(&mut item_mod.attrs, &full_name);
}
}
}
Item::ForeignMod(item) => {
Expand Down Expand Up @@ -242,7 +285,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
}

let initialization = module_initialization(options, ident);
let initialization = module_initialization(&name, ctx);
Ok(quote!(
#vis mod #ident {
#(#items)*
Expand Down Expand Up @@ -286,10 +329,11 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
let stmts = std::mem::take(&mut function.block.stmts);
let Ctx { pyo3_path } = ctx;
let ident = &function.sig.ident;
let name = options.name.unwrap_or_else(|| ident.unraw());
let vis = &function.vis;
let doc = get_doc(&function.attrs, None);

let initialization = module_initialization(options, ident);
let initialization = module_initialization(&name, ctx);

// Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new();
Expand Down Expand Up @@ -354,9 +398,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
})
}

fn module_initialization(options: PyModuleOptions, ident: &syn::Ident) -> TokenStream {
let name = options.name.unwrap_or_else(|| ident.unraw());
let ctx = &Ctx::new(&options.krate);
fn module_initialization(name: &syn::Ident, ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let pyinit_symbol = format!("PyInit_{}", name);

Expand Down Expand Up @@ -491,9 +533,33 @@ fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(ident))
}

fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
}

fn has_pyo3_module_declared<T: Parse>(
attrs: &[syn::Attribute],
root_attribute_name: &str,
is_module_option: impl Fn(&T) -> bool + Copy,
) -> Result<bool> {
for attr in attrs {
if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
&& matches!(attr.meta, Meta::List(_))
{
for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
if is_module_option(option) {
return Ok(true);
}
}
}
}
Ok(false)
}

enum PyModulePyO3Option {
Crate(CrateAttribute),
Name(NameAttribute),
Module(ModuleAttribute),
}

impl Parse for PyModulePyO3Option {
Expand All @@ -503,6 +569,8 @@ impl Parse for PyModulePyO3Option {
input.parse().map(PyModulePyO3Option::Name)
} else if lookahead.peek(syn::Token![crate]) {
input.parse().map(PyModulePyO3Option::Crate)
} else if lookahead.peek(attributes::kw::module) {
input.parse().map(PyModulePyO3Option::Module)
} else {
Err(lookahead.error())
}
Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub struct PyClassPyO3Options {
pub weakref: Option<kw::weakref>,
}

enum PyClassPyO3Option {
pub enum PyClassPyO3Option {
Crate(CrateAttribute),
Dict(kw::dict),
Eq(kw::eq),
Expand Down
57 changes: 54 additions & 3 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
#[cfg(not(Py_LIMITED_API))]
use pyo3::types::PyBool;

Expand Down Expand Up @@ -78,7 +79,7 @@ mod declarative_module {
x * 3
}

#[pyclass]
#[pyclass(name = "Struct")]
struct Struct;

#[pymethods]
Expand All @@ -89,12 +90,31 @@ mod declarative_module {
}
}

#[pyclass(eq, eq_int)]
#[pyclass(module = "foo")]
struct StructInCustomModule;

#[pyclass(eq, eq_int, name = "Enum")]
#[derive(PartialEq)]
enum Enum {
A,
B,
}

#[pyclass(eq, eq_int, module = "foo")]
#[derive(PartialEq)]
enum EnumInCustomModule {
A,
B,
}
}

#[pymodule]
#[pyo3(module = "custom_root")]
mod inner_custom_root {
use super::*;

#[pyclass]
struct Struct;
}

#[pymodule_init]
Expand All @@ -121,10 +141,17 @@ mod declarative_module2 {
use super::double;
}

fn declarative_module(py: Python<'_>) -> &Bound<'_, PyModule> {
static MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
MODULE
.get_or_init(py, || pyo3::wrap_pymodule!(declarative_module)(py))
.bind(py)
}

#[test]
fn test_declarative_module() {
Python::with_gil(|py| {
let m = pyo3::wrap_pymodule!(declarative_module)(py).into_bound(py);
let m = declarative_module(py);
py_assert!(
py,
m,
Expand Down Expand Up @@ -188,3 +215,27 @@ fn test_raw_ident_module() {
py_assert!(py, m, "m.double(2) == 4");
})
}

#[test]
fn test_module_names() {
Python::with_gil(|py| {
let m = declarative_module(py);
py_assert!(
py,
m,
"m.inner.Struct.__module__ == 'declarative_module.inner'"
);
py_assert!(py, m, "m.inner.StructInCustomModule.__module__ == 'foo'");
py_assert!(
py,
m,
"m.inner.Enum.__module__ == 'declarative_module.inner'"
);
py_assert!(py, m, "m.inner.EnumInCustomModule.__module__ == 'foo'");
py_assert!(
py,
m,
"m.inner_custom_root.Struct.__module__ == 'custom_root.inner_custom_root'"
);
})
}

0 comments on commit 9c9a0a4

Please sign in to comment.