diff --git a/guide/src/async-await.md b/guide/src/async-await.md index 9376a488185..af20547ef1f 100644 --- a/guide/src/async-await.md +++ b/guide/src/async-await.md @@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + ' As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`. -It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.* - +However, an exception is done for method receiver, so async methods can accept `&self`/`&mut self` ## Implicit GIL holding diff --git a/newsfragments/3069.changed.md b/newsfragments/3069.changed.md new file mode 100644 index 00000000000..7979ea71960 --- /dev/null +++ b/newsfragments/3069.changed.md @@ -0,0 +1 @@ +Allow async methods to accept `&self`/`&mut self` \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index a211ea6b481..2e1b9e3f4eb 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -472,8 +472,7 @@ impl<'a> FnSpec<'a> { } let rust_call = |args: Vec| { - let mut call = quote! { function(#self_arg #(#args),*) }; - if self.asyncness.is_some() { + let call = if self.asyncness.is_some() { let throw_callback = if cancel_handle.is_some() { quote! { Some(__throw_callback) } } else { @@ -484,8 +483,23 @@ impl<'a> FnSpec<'a> { Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)), None => quote!(None), }; - call = quote! {{ - let future = #call; + let future = match self.tp { + FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! { + _pyo3::impl_::coroutine::ref_method_future( + py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf), + move |__self| function(__self, #(#args),*) + )? + }, + FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! { + _pyo3::impl_::coroutine::mut_method_future( + py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf), + move |__self| function(__self, #(#args),*) + )? + }, + _ => quote! { function(#self_arg #(#args),*) }, + }; + let mut call = quote! {{ + let future = #future; _pyo3::impl_::coroutine::new_coroutine( _pyo3::intern!(py, stringify!(#python_name)), #qualname_prefix, @@ -500,7 +514,10 @@ impl<'a> FnSpec<'a> { #call }}; } - } + call + } else { + quote! { function(#self_arg #(#args),*) } + }; quotes::map_result_into_ptr(quotes::ok_wrap(call)) }; diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index c8b2cdcce49..49a04fde828 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -1,7 +1,12 @@ use std::future::Future; +use std::mem; use crate::coroutine::cancel::ThrowCallback; -use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject}; +use crate::pyclass::boolean_struct::False; +use crate::{ + coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, + PyRef, PyRefMut, PyResult, Python, +}; pub fn new_coroutine( name: &PyString, @@ -16,3 +21,46 @@ where { Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future) } + +fn get_ptr(obj: &Py) -> *mut T { + // SAFETY: Py can be casted as *const PyCell + unsafe { &*(obj.as_ptr() as *const PyCell) }.get_ptr() +} + +struct RefGuard(Py); + +impl Drop for RefGuard { + fn drop(&mut self) { + Python::with_gil(|gil| self.0.as_ref(gil).release_ref()) + } +} + +pub unsafe fn ref_method_future<'a, T: PyClass, F: Future + 'a>( + self_: &PyAny, + fut: impl FnOnce(&'a T) -> F, +) -> PyResult> { + let ref_: PyRef<'_, T> = self_.extract()?; + // SAFETY: `PyRef::as_ptr` returns a borrowed reference + let guard = RefGuard(unsafe { Py::::from_borrowed_ptr(self_.py(), ref_.as_ptr()) }); + mem::forget(ref_); + Ok(async move { fut(unsafe { &*get_ptr(&guard.0) }).await }) +} + +struct RefMutGuard(Py); + +impl Drop for RefMutGuard { + fn drop(&mut self) { + Python::with_gil(|gil| self.0.as_ref(gil).release_mut()) + } +} + +pub fn mut_method_future<'a, T: PyClass, F: Future + 'a>( + self_: &PyAny, + fut: impl FnOnce(&'a mut T) -> F, +) -> PyResult> { + let mut_: PyRefMut<'_, T> = self_.extract()?; + // SAFETY: `PyRefMut::as_ptr` returns a borrowed reference + let guard = RefMutGuard(unsafe { Py::::from_borrowed_ptr(self_.py(), mut_.as_ptr()) }); + mem::forget(mut_); + Ok(async move { fut(unsafe { &mut *get_ptr(&guard.0) }).await }) +} diff --git a/src/pycell.rs b/src/pycell.rs index 8a4ceb6b374..2124d57695a 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -516,6 +516,14 @@ impl PyCell { #[allow(clippy::useless_conversion)] offset.try_into().expect("offset should fit in Py_ssize_t") } + + pub(crate) fn release_ref(&self) { + self.borrow_checker().release_borrow(); + } + + pub(crate) fn release_mut(&self) { + self.borrow_checker().release_borrow_mut(); + } } impl PyCell { diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 0f2341d228c..8eba40787d7 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -45,7 +45,6 @@ fn test_coroutine_qualname() { fn new() -> Self { Self } - // TODO use &self when possible async fn my_method(_self: Py) {} #[classmethod] async fn my_classmethod(_cls: Py) {} @@ -173,3 +172,14 @@ fn coroutine_cancel_handle() { .unwrap(); }) } + +#[test] +fn test_async_method_receiver() { + #[pyclass] + struct MyClass; + #[pymethods] + impl MyClass { + async fn method(&self) {} + async fn method_mut(&mut self) {} + } +}