Skip to content

Commit

Permalink
refactor: drop futures_util dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Dec 4, 2023
1 parent 8a674c2 commit 2ca9f59
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 32 deletions.
3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ unindent = { version = "0.2.1", optional = true }
# support crate for multiple-pymethods feature
inventory = { version = "0.3.0", optional = true }

# coroutine implementation
futures-util = "0.3"

# crate integrations that can be added using the eponymous features
anyhow = { version = "1.0", optional = true }
chrono = { version = "0.4.25", default-features = false, optional = true }
Expand Down
35 changes: 18 additions & 17 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
//! Python coroutine implementation, used notably when wrapping `async fn`
//! with `#[pyfunction]`/`#[pymethods]`.
use std::{
any::Any,
future::Future,
panic,
pin::Pin,
sync::Arc,
task::{Context, Poll},
task::{Context, Poll, Waker},
};

use futures_util::FutureExt;
use pyo3_macros::{pyclass, pymethods};

use crate::{
coroutine::waker::AsyncioWaker,
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
panic::PanicException,
pyclass::IterNextOutput,
Expand All @@ -24,20 +22,17 @@ use crate::{
pub(crate) mod cancel;
mod waker;

use crate::coroutine::cancel::ThrowCallback;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}

Expand Down Expand Up @@ -68,7 +63,7 @@ impl Coroutine {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
future: Some(Box::pin(wrap)),
waker: None,
}
}
Expand Down Expand Up @@ -98,22 +93,28 @@ impl Coroutine {
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
}
let waker = futures_util::task::waker(self.waker.clone().unwrap());
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
if let Poll::Ready(res) = future_rs.as_mut().poll(&mut Context::from_waker(&waker)) {
self.close();
return match res {
Ok(res) => Ok(IterNextOutput::Return(res?)),
Err(err) => Err(PanicException::from_panic_payload(err)),
};
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
return Ok(IterNextOutput::Return(res?));
}
Err(err) => {
self.close();
return Err(PanicException::from_panic_payload(err));
}
_ => {}
}
// otherwise, initialize the waker `asyncio.Future`
if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
// `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
// and will yield itself if its result has not been set in polling above
if let Some(future) = PyIterator::from_object(future).unwrap().next() {
// future has not been leaked into Python for now, and Rust code can only call
// `set_result(None)` in `ArcWake` implementation, so it's safe to unwrap
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
return Ok(IterNextOutput::Yield(future.unwrap().into()));
}
}
Expand Down
14 changes: 9 additions & 5 deletions src/coroutine/waker.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::sync::GILOnceCell;
use crate::types::PyCFunction;
use crate::{intern, wrap_pyfunction, Py, PyAny, PyObject, PyResult, Python};
use futures_util::task::ArcWake;
use pyo3_macros::pyfunction;
use std::sync::Arc;
use std::task::Wake;

/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`.
/// Lazy `asyncio.Future` wrapper, implementing [`Wake`] by calling `Future.set_result`.
///
/// asyncio future is let uninitialized until [`initialize_future`][1] is called.
/// If [`wake`][2] is called before future initialization (during Rust future polling),
Expand All @@ -31,10 +31,14 @@ impl AsyncioWaker {
}
}

impl ArcWake for AsyncioWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
impl Wake for AsyncioWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}

fn wake_by_ref(self: &Arc<Self>) {
Python::with_gil(|gil| {
if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) {
if let Some(loop_and_future) = self.0.get_or_init(gil, || None) {
loop_and_future
.set_result(gil)
.expect("unexpected error in coroutine waker");
Expand Down
45 changes: 38 additions & 7 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#![cfg(feature = "macros")]
#![cfg(not(target_arch = "wasm32"))]
use std::ops::Deref;
use std::{task::Poll, thread, time::Duration};
use std::{ops::Deref, task::Poll, thread, time::Duration};

use futures::{channel::oneshot, future::poll_fn, FutureExt};
use pyo3::coroutine::CancelHandle;
use pyo3::types::{IntoPyDict, PyType};
use pyo3::{prelude::*, py_run};
use pyo3::{
coroutine::CancelHandle,
prelude::*,
py_run,
types::{IntoPyDict, PyType},
};

#[path = "../src/tests/common.rs"]
mod common;
Expand Down Expand Up @@ -119,7 +121,7 @@ fn cancelled_coroutine() {
let test = r#"
import asyncio
async def main():
task = asyncio.create_task(sleep(1))
task = asyncio.create_task(sleep(999))
await asyncio.sleep(0)
task.cancel()
await task
Expand Down Expand Up @@ -155,7 +157,7 @@ fn coroutine_cancel_handle() {
let test = r#"
import asyncio;
async def main():
task = asyncio.create_task(cancellable_sleep(1))
task = asyncio.create_task(cancellable_sleep(999))
await asyncio.sleep(0)
task.cancel()
return await task
Expand Down Expand Up @@ -203,3 +205,32 @@ fn coroutine_is_cancelled() {
.unwrap();
})
}

#[test]
fn coroutine_panic() {
#[pyfunction]
async fn panic() {
panic!("test panic");
}
Python::with_gil(|gil| {
let panic = wrap_pyfunction!(panic, gil).unwrap();
let test = r#"
import asyncio
coro = panic()
try:
asyncio.run(coro)
except BaseException as err:
assert type(err).__name__ == "PanicException"
assert str(err) == "test panic"
else:
assert False
try:
coro.send(None)
except RuntimeError as err:
assert str(err) == "cannot reuse already awaited coroutine"
else:
assert False;
"#;
py_run!(gil, panic, &handle_windows(test));
})
}

0 comments on commit 2ca9f59

Please sign in to comment.