Skip to content

Commit

Permalink
Allow #[new] to return existing instances
Browse files Browse the repository at this point in the history
fixes #2384
  • Loading branch information
alex committed Jul 2, 2023
1 parent 1a0c9be commit 98709c2
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 7 deletions.
1 change: 1 addition & 0 deletions newsfragments/3287.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`#[new]` methods may now return `Py<Self>` in order to return existing instances
31 changes: 24 additions & 7 deletions src/pyclass_init.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Contains initialization utilities for `#[pyclass]`.
use crate::callback::IntoPyCallbackOutput;
use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef};
use crate::{ffi, PyCell, PyClass, PyErr, PyResult, Python};
use crate::{ffi, IntoPyPointer, Py, PyCell, PyClass, PyErr, PyResult, Python};
use crate::{
ffi::PyTypeObject,
pycell::{
Expand Down Expand Up @@ -134,17 +134,22 @@ impl<T: PyTypeInfo> PyObjectInit<T> for PyNativeTypeInitializer<T> {
/// );
/// });
/// ```
pub struct PyClassInitializer<T: PyClass> {
init: T,
super_init: <T::BaseType as PyClassBaseType>::Initializer,
pub struct PyClassInitializer<T: PyClass>(PyClassInitializerImpl<T>);

enum PyClassInitializerImpl<T: PyClass> {
Existing(Py<T>),
New {
init: T,
super_init: <T::BaseType as PyClassBaseType>::Initializer,
},
}

impl<T: PyClass> PyClassInitializer<T> {
/// Constructs a new initializer from value `T` and base class' initializer.
///
/// It is recommended to use `add_subclass` instead of this method for most usage.
pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self {
Self { init, super_init }
Self(PyClassInitializerImpl::New { init, super_init })
}

/// Constructs a new initializer from an initializer for the base class.
Expand Down Expand Up @@ -242,13 +247,18 @@ impl<T: PyClass> PyObjectInit<T> for PyClassInitializer<T> {
contents: MaybeUninit<PyCellContents<T>>,
}

let obj = self.super_init.into_new_object(py, subtype)?;
let (init, super_init) = match self.0 {
PyClassInitializerImpl::Existing(value) => return Ok(value.into_ptr()),
PyClassInitializerImpl::New { init, super_init } => (init, super_init),
};

let obj = super_init.into_new_object(py, subtype)?;

let cell: *mut PartiallyInitializedPyCell<T> = obj as _;
std::ptr::write(
(*cell).contents.as_mut_ptr(),
PyCellContents {
value: ManuallyDrop::new(UnsafeCell::new(self.init)),
value: ManuallyDrop::new(UnsafeCell::new(init)),
borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(),
thread_checker: T::ThreadChecker::new(),
dict: T::Dict::INIT,
Expand Down Expand Up @@ -284,6 +294,13 @@ where
}
}

impl<T: PyClass> From<Py<T>> for PyClassInitializer<T> {
#[inline]
fn from(value: Py<T>) -> PyClassInitializer<T> {
PyClassInitializer(PyClassInitializerImpl::Existing(value))
}
}

// Implementation used by proc macros to allow anything convertible to PyClassInitializer<T> to be
// the return value of pyclass #[new] method (optionally wrapped in `Result<U, E>`).
impl<T, U> IntoPyCallbackOutput<PyClassInitializer<T>> for U
Expand Down
60 changes: 60 additions & 0 deletions tests/test_class_new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::IntoPyDict;

#[pyclass]
Expand Down Expand Up @@ -204,3 +205,62 @@ fn new_with_custom_error() {
assert_eq!(err.to_string(), "ValueError: custom error");
});
}

#[pyclass]
struct NewExisting {
#[pyo3(get)]
num: usize,
}

#[pymethods]
impl NewExisting {
#[new]
fn new(py: pyo3::Python<'_>, val: usize) -> pyo3::Py<NewExisting> {
static PRE_BUILT: GILOnceCell<[pyo3::Py<NewExisting>; 2]> = GILOnceCell::new();
let existing = PRE_BUILT.get_or_init(py, || {
[
pyo3::PyCell::new(py, NewExisting { num: 0 })
.unwrap()
.into(),
pyo3::PyCell::new(py, NewExisting { num: 1 })
.unwrap()
.into(),
]
});

if val < existing.len() {
return existing[val].clone_ref(py);
}

pyo3::PyCell::new(py, NewExisting { num: val })
.unwrap()
.into()
}
}

#[test]
fn test_new_existing() {
Python::with_gil(|py| {
let typeobj = py.get_type::<NewExisting>();

let obj1 = typeobj.call1((0,)).unwrap();
let obj2 = typeobj.call1((0,)).unwrap();
let obj3 = typeobj.call1((1,)).unwrap();
let obj4 = typeobj.call1((1,)).unwrap();
let obj5 = typeobj.call1((2,)).unwrap();
let obj6 = typeobj.call1((2,)).unwrap();

assert!(obj1.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
assert!(obj2.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
assert!(obj3.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
assert!(obj4.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
assert!(obj5.getattr("num").unwrap().extract::<u32>().unwrap() == 2);
assert!(obj6.getattr("num").unwrap().extract::<u32>().unwrap() == 2);

assert!(obj1.is(obj2));
assert!(obj3.is(obj4));
assert!(!obj1.is(obj3));
assert!(!obj1.is(obj5));
assert!(!obj5.is(obj6));
});
}

0 comments on commit 98709c2

Please sign in to comment.