diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e8819e8bbc..beb3e35d02d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +### Added + + * `module` argument to `pyclass` macro. [#499](https://github.com/PyO3/pyo3/pull/499) + + ## [0.7.0] - 2018-05-26 ### Added diff --git a/examples/word-count/src/lib.rs b/examples/word-count/src/lib.rs index aa98108c6be..29f97808c0f 100644 --- a/examples/word-count/src/lib.rs +++ b/examples/word-count/src/lib.rs @@ -8,7 +8,7 @@ use std::fs; use std::path::PathBuf; /// Represents a file that can be searched -#[pyclass] +#[pyclass(module = "word_count")] struct WordCounter { path: PathBuf, } diff --git a/guide/src/class.md b/guide/src/class.md index a7aa0943ecf..69de02b57ce 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -104,6 +104,8 @@ If a custom class contains references to other Python objects that can be collec * `extends=BaseType` - Use a custom base class. The base `BaseType` must implement `PyTypeInfo`. * `subclass` - Allows Python classes to inherit from this class. * `dict` - Adds `__dict__` support, so that the instances of this type have a dictionary containing arbitrary instance variables. +* `module="XXX"` - Set the name of the module the class will be shown as defined in. If not given, the class + will be a virtual member of the `builtins` module. ## Constructor diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index 5332f49c7df..04f378682a1 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -15,6 +15,7 @@ pub struct PyClassArgs { pub name: Option, pub flags: Vec, pub base: syn::TypePath, + pub module: Option, } impl Parse for PyClassArgs { @@ -34,6 +35,7 @@ impl Default for PyClassArgs { PyClassArgs { freelist: None, name: None, + module: None, // We need the 0 as value for the constant we're later building using quote for when there // are no other flags flags: vec![parse_quote! {0}], @@ -94,6 +96,20 @@ impl PyClassArgs { )); } }, + "module" => match *assign.right { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(ref lit), + .. + }) => { + self.module = Some(lit.clone()); + } + _ => { + return Err(syn::Error::new_spanned( + *assign.right.clone(), + "Wrong format for module", + )); + } + }, _ => { return Err(syn::Error::new_spanned( *assign.left.clone(), @@ -298,6 +314,11 @@ fn impl_class( } else { quote! {0} }; + let module = if let Some(m) = &attr.module { + quote! { Some(#m) } + } else { + quote! { None } + }; let inventory_impl = impl_inventory(&cls); @@ -310,6 +331,7 @@ fn impl_class( type BaseType = #base; const NAME: &'static str = #cls_name; + const MODULE: Option<&'static str> = #module; const DESCRIPTION: &'static str = #doc; const FLAGS: usize = #(#flags)|*; diff --git a/src/type_object.rs b/src/type_object.rs index 15b525fd60c..d1e3242cd95 100644 --- a/src/type_object.rs +++ b/src/type_object.rs @@ -25,6 +25,9 @@ pub trait PyTypeInfo { /// Class name const NAME: &'static str; + /// Module name, if any + const MODULE: Option<&'static str>; + /// Class doc string const DESCRIPTION: &'static str = "\0"; @@ -256,7 +259,7 @@ where let gil = Python::acquire_gil(); let py = gil.python(); - initialize_type::(py, None).unwrap_or_else(|_| { + initialize_type::(py, ::MODULE).unwrap_or_else(|_| { panic!("An error occurred while initializing class {}", Self::NAME) }); } diff --git a/src/types/datetime.rs b/src/types/datetime.rs index 127aa1eeb36..077be5fb3a0 100644 --- a/src/types/datetime.rs +++ b/src/types/datetime.rs @@ -66,7 +66,12 @@ pub trait PyTimeAccess { /// Bindings around `datetime.date` pub struct PyDate(PyObject); -pyobject_native_type!(PyDate, *PyDateTimeAPI.DateType, PyDate_Check); +pyobject_native_type!( + PyDate, + *PyDateTimeAPI.DateType, + Some("datetime"), + PyDate_Check +); impl PyDate { pub fn new<'p>(py: Python<'p>, year: i32, month: u8, day: u8) -> PyResult<&'p PyDate> { @@ -116,7 +121,12 @@ impl PyDateAccess for PyDate { /// Bindings for `datetime.datetime` pub struct PyDateTime(PyObject); -pyobject_native_type!(PyDateTime, *PyDateTimeAPI.DateTimeType, PyDateTime_Check); +pyobject_native_type!( + PyDateTime, + *PyDateTimeAPI.DateTimeType, + Some("datetime"), + PyDateTime_Check +); impl PyDateTime { pub fn new<'p>( @@ -220,7 +230,12 @@ impl PyTimeAccess for PyDateTime { /// Bindings for `datetime.time` pub struct PyTime(PyObject); -pyobject_native_type!(PyTime, *PyDateTimeAPI.TimeType, PyTime_Check); +pyobject_native_type!( + PyTime, + *PyDateTimeAPI.TimeType, + Some("datetime"), + PyTime_Check +); impl PyTime { pub fn new<'p>( @@ -299,11 +314,21 @@ impl PyTimeAccess for PyTime { /// /// This is an abstract base class and should not be constructed directly. pub struct PyTzInfo(PyObject); -pyobject_native_type!(PyTzInfo, *PyDateTimeAPI.TZInfoType, PyTZInfo_Check); +pyobject_native_type!( + PyTzInfo, + *PyDateTimeAPI.TZInfoType, + Some("datetime"), + PyTZInfo_Check +); /// Bindings for `datetime.timedelta` pub struct PyDelta(PyObject); -pyobject_native_type!(PyDelta, *PyDateTimeAPI.DeltaType, PyDelta_Check); +pyobject_native_type!( + PyDelta, + *PyDateTimeAPI.DeltaType, + Some("datetime"), + PyDelta_Check +); impl PyDelta { pub fn new<'p>( diff --git a/src/types/mod.rs b/src/types/mod.rs index fb92102682b..e1dea39954a 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -79,9 +79,9 @@ macro_rules! pyobject_native_type_named ( #[macro_export] macro_rules! pyobject_native_type ( - ($name: ty, $typeobject: expr, $checkfunction: path $(,$type_param: ident)*) => { + ($name: ty, $typeobject: expr, $module: expr, $checkfunction: path $(,$type_param: ident)*) => { pyobject_native_type_named!($name $(,$type_param)*); - pyobject_native_type_convert!($name, $typeobject, $checkfunction $(,$type_param)*); + pyobject_native_type_convert!($name, $typeobject, $module, $checkfunction $(,$type_param)*); impl<'a, $($type_param,)*> ::std::convert::From<&'a $name> for &'a $crate::types::PyAny { fn from(ob: &'a $name) -> Self { @@ -89,16 +89,20 @@ macro_rules! pyobject_native_type ( } } }; + ($name: ty, $typeobject: expr, $checkfunction: path $(,$type_param: ident)*) => { + pyobject_native_type!{$name, $typeobject, Some("builtins"), $checkfunction $(,$type_param)*} + }; ); #[macro_export] macro_rules! pyobject_native_type_convert( - ($name: ty, $typeobject: expr, $checkfunction: path $(,$type_param: ident)*) => { + ($name: ty, $typeobject: expr, $module: expr, $checkfunction: path $(,$type_param: ident)*) => { impl<$($type_param,)*> $crate::type_object::PyTypeInfo for $name { type Type = (); type BaseType = $crate::types::PyAny; const NAME: &'static str = stringify!($name); + const MODULE: Option<&'static str> = $module; const SIZE: usize = ::std::mem::size_of::<$crate::ffi::PyObject>(); const OFFSET: isize = 0; @@ -154,6 +158,9 @@ macro_rules! pyobject_native_type_convert( } } }; + ($name: ty, $typeobject: expr, $checkfunction: path $(,$type_param: ident)*) => { + pyobject_native_type_convert!{$name, $typeobject, Some("builtins"), $checkfunction $(,$type_param)*} + }; ); mod any; diff --git a/src/types/num.rs b/src/types/num.rs index 164a8b1005e..5fccb07e468 100644 --- a/src/types/num.rs +++ b/src/types/num.rs @@ -123,7 +123,12 @@ pub(super) const IS_LITTLE_ENDIAN: c_int = 0; #[repr(transparent)] pub struct PyLong(PyObject); -pyobject_native_type!(PyLong, ffi::PyLong_Type, ffi::PyLong_Check); +pyobject_native_type!( + PyLong, + ffi::PyLong_Type, + Some("builtins"), + ffi::PyLong_Check +); macro_rules! int_fits_c_long ( ($rust_type:ty) => ( diff --git a/src/types/set.rs b/src/types/set.rs index 1232290858e..43d9b39d51a 100644 --- a/src/types/set.rs +++ b/src/types/set.rs @@ -18,7 +18,7 @@ pub struct PySet(PyObject); #[repr(transparent)] pub struct PyFrozenSet(PyObject); -pyobject_native_type!(PySet, ffi::PySet_Type, ffi::PySet_Check); +pyobject_native_type!(PySet, ffi::PySet_Type, Some("builtins"), ffi::PySet_Check); pyobject_native_type!(PyFrozenSet, ffi::PyFrozenSet_Type, ffi::PyFrozenSet_Check); impl PySet { diff --git a/src/types/string.rs b/src/types/string.rs index 8e542c030ac..ad3aaa40735 100644 --- a/src/types/string.rs +++ b/src/types/string.rs @@ -24,7 +24,12 @@ pyobject_native_type!(PyString, ffi::PyUnicode_Type, ffi::PyUnicode_Check); #[repr(transparent)] pub struct PyBytes(PyObject); -pyobject_native_type!(PyBytes, ffi::PyBytes_Type, ffi::PyBytes_Check); +pyobject_native_type!( + PyBytes, + ffi::PyBytes_Type, + Some("builtins"), + ffi::PyBytes_Check +); impl PyString { /// Creates a new Python string object. diff --git a/tests/test_module.rs b/tests/test_module.rs index 61b31907694..4daa5266f87 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -6,7 +6,10 @@ use pyo3::types::IntoPyDict; mod common; #[pyclass] -struct EmptyClass {} +struct AnonClass {} + +#[pyclass(module = "module")] +struct LocatedClass {} fn sum_as_string(a: i64, b: i64) -> String { format!("{}", a + b).to_string() @@ -34,7 +37,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { Ok(42) } - m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); m.add("foo", "bar").unwrap(); @@ -63,7 +67,9 @@ fn test_module_with_functions() { run("assert module_with_functions.sum_as_string(1, 2) == '3'"); run("assert module_with_functions.no_parameters() == 42"); run("assert module_with_functions.foo == 'bar'"); - run("assert module_with_functions.EmptyClass != None"); + run("assert module_with_functions.AnonClass != None"); + run("assert module_with_functions.LocatedClass != None"); + run("assert module_with_functions.LocatedClass.__module__ == 'module'"); run("assert module_with_functions.double(3) == 6"); run("assert module_with_functions.double.__doc__ == 'Doubles the given value'"); run("assert module_with_functions.also_double(3) == 6");