From 3f15d2a9464aeb6574d4e70976f14322b3331f03 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Thu, 30 Nov 2023 03:10:26 +0100 Subject: [PATCH] feat: allow async methods to accept `&self`/`&mut self` --- guide/src/async-await.md | 3 +- newsfragments/3609.changed.md | 1 + pyo3-macros-backend/src/method.rs | 48 ++++++++++++------- src/impl_/coroutine.rs | 78 +++++++++++++++++++++++++++++-- src/pycell.rs | 10 ++++ tests/test_coroutine.rs | 53 +++++++++++++++++++++ 6 files changed, 171 insertions(+), 22 deletions(-) create mode 100644 newsfragments/3609.changed.md diff --git a/guide/src/async-await.md b/guide/src/async-await.md index 3649a9a0fed..847fed2f47d 100644 --- a/guide/src/async-await.md +++ b/guide/src/async-await.md @@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + ' As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`. -It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.* - +However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self` ## Implicit GIL holding diff --git a/newsfragments/3609.changed.md b/newsfragments/3609.changed.md new file mode 100644 index 00000000000..7979ea71960 --- /dev/null +++ b/newsfragments/3609.changed.md @@ -0,0 +1 @@ +Allow async methods to accept `&self`/`&mut self` \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 1cbb9304cda..dd392a5f905 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -1,18 +1,19 @@ use std::fmt::Display; -use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue}; -use crate::deprecations::{Deprecation, Deprecations}; -use crate::params::impl_arg_params; -use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes}; -use crate::pyfunction::{PyFunctionOptions, SignatureAttribute}; -use crate::quotes; -use crate::utils::{self, PythonDoc}; use proc_macro2::{Span, TokenStream}; -use quote::ToTokens; -use quote::{quote, quote_spanned}; -use syn::ext::IdentExt; -use syn::spanned::Spanned; -use syn::{Ident, Result}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::{ext::IdentExt, spanned::Spanned, Ident, Result}; + +use crate::{ + attributes::{TextSignatureAttribute, TextSignatureAttributeValue}, + deprecations::{Deprecation, Deprecations}, + params::impl_arg_params, + pyfunction::{ + FunctionSignature, PyFunctionArgPyO3Attributes, PyFunctionOptions, SignatureAttribute, + }, + quotes, + utils::{self, PythonDoc}, +}; #[derive(Clone, Debug)] pub struct FnArg<'a> { @@ -473,8 +474,7 @@ impl<'a> FnSpec<'a> { } let rust_call = |args: Vec| { - let mut call = quote! { function(#self_arg #(#args),*) }; - if self.asyncness.is_some() { + let call = if self.asyncness.is_some() { let throw_callback = if cancel_handle.is_some() { quote! { Some(__throw_callback) } } else { @@ -485,8 +485,19 @@ impl<'a> FnSpec<'a> { Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)), None => quote!(None), }; - call = quote! {{ - let future = #call; + let future = match self.tp { + FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{ + let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?; + async move { function(::std::ops::Deref::deref(&__guard), #(#args),*).await } + }}, + FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{ + let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?; + async move { function(::std::ops::DerefMut::deref_mut(&mut __guard), #(#args),*).await } + }}, + _ => quote! { function(#self_arg #(#args),*) }, + }; + let mut call = quote! {{ + let future = #future; _pyo3::impl_::coroutine::new_coroutine( _pyo3::intern!(py, stringify!(#python_name)), #qualname_prefix, @@ -501,7 +512,10 @@ impl<'a> FnSpec<'a> { #call }}; } - } + call + } else { + quote! { function(#self_arg #(#args),*) } + }; quotes::map_result_into_ptr(quotes::ok_wrap(call)) }; diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index c8b2cdcce49..c0e80f00e42 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -1,7 +1,15 @@ -use std::future::Future; +use std::{ + future::Future, + mem, + ops::{Deref, DerefMut}, +}; -use crate::coroutine::cancel::ThrowCallback; -use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject}; +use crate::{ + coroutine::{cancel::ThrowCallback, Coroutine}, + pyclass::boolean_struct::False, + types::PyString, + IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python, +}; pub fn new_coroutine( name: &PyString, @@ -16,3 +24,67 @@ where { Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future) } + +fn get_ptr(obj: &Py) -> *mut T { + // SAFETY: Py can be casted as *const PyCell + unsafe { &*(obj.as_ptr() as *const PyCell) }.get_ptr() +} + +pub struct RefGuard(Py); + +impl RefGuard { + pub fn new(obj: &PyAny) -> PyResult { + let ref_: PyRef<'_, T> = obj.extract()?; + // SAFETY: `PyRef::as_ptr` returns a borrowed reference + let guard = RefGuard(unsafe { Py::::from_borrowed_ptr(obj.py(), ref_.as_ptr()) }); + mem::forget(ref_); + Ok(guard) + } +} + +impl Deref for RefGuard { + type Target = T; + fn deref(&self) -> &Self::Target { + // SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees + unsafe { &*get_ptr(&self.0) } + } +} + +impl Drop for RefGuard { + fn drop(&mut self) { + Python::with_gil(|gil| self.0.as_ref(gil).release_ref()) + } +} + +pub struct RefMutGuard>(Py); + +impl> RefMutGuard { + pub fn new(obj: &PyAny) -> PyResult { + let mut_: PyRefMut<'_, T> = obj.extract()?; + // // SAFETY: `PyRefMut::as_ptr` returns a borrowed reference + let guard = RefMutGuard(unsafe { Py::::from_borrowed_ptr(obj.py(), mut_.as_ptr()) }); + mem::forget(mut_); + Ok(guard) + } +} + +impl> Deref for RefMutGuard { + type Target = T; + fn deref(&self) -> &Self::Target { + // SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees + unsafe { &*get_ptr(&self.0) } + } +} + +impl> DerefMut for RefMutGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees + unsafe { &mut *get_ptr(&self.0) } + } +} + +impl> Drop for RefMutGuard { + fn drop(&mut self) { + Python::with_gil(|gil| self.0.as_ref(gil).release_mut()) + } +} diff --git a/src/pycell.rs b/src/pycell.rs index 8a4ceb6b374..8b85bfec8e2 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -516,6 +516,16 @@ impl PyCell { #[allow(clippy::useless_conversion)] offset.try_into().expect("offset should fit in Py_ssize_t") } + + #[cfg(feature = "macros")] + pub(crate) fn release_ref(&self) { + self.borrow_checker().release_borrow(); + } + + #[cfg(feature = "macros")] + pub(crate) fn release_mut(&self) { + self.borrow_checker().release_borrow_mut(); + } } impl PyCell { diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index cf975423c25..01b84ca8e94 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -234,3 +234,56 @@ fn coroutine_panic() { py_run!(gil, panic, &handle_windows(test)); }) } + +#[test] +fn test_async_method_receiver() { + #[pyclass] + struct Counter(usize); + #[pymethods] + impl Counter { + #[new] + fn new() -> Self { + Self(0) + } + async fn get(&self) -> usize { + self.0 + } + async fn incr(&mut self) -> usize { + self.0 += 1; + self.0 + } + } + Python::with_gil(|gil| { + let test = r#" + import asyncio + + obj = Counter() + coro1 = obj.get() + coro2 = obj.get() + try: + obj.incr() # borrow checking should fail + except RuntimeError as err: + pass + else: + assert False + assert asyncio.run(coro1) == 0 + coro2.close() + coro3 = obj.incr() + try: + obj.incr() # borrow checking should fail + except RuntimeError as err: + pass + else: + assert False + try: + obj.get() == 42 # borrow checking should fail + except RuntimeError as err: + pass + else: + assert False + assert asyncio.run(coro3) == 1 + "#; + let locals = [("Counter", gil.get_type::())].into_py_dict(gil); + py_run!(gil, *locals, test); + }) +}