Skip to content

Commit

Permalink
fix edge cases with exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 16, 2023
1 parent 8f4a26a commit 3a434e6
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 60 deletions.
1 change: 1 addition & 0 deletions newsfragments/3455.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`Err` returned from `#[pyfunction]` will now have a non-None `__context__` if called from inside a `catch` block.
1 change: 1 addition & 0 deletions newsfragments/3455.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `IterNextOutput::Return` not returning a value on PyPy.
1 change: 1 addition & 0 deletions pytests/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
hypothesis>=3.55
pytest>=6.0
pytest-asyncio>=0.21
pytest-benchmark>=3.4
psutil>=5.6
typing_extensions>=4.0.0
87 changes: 87 additions & 0 deletions pytests/src/awaitable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//! The following classes are examples of objects which implement Python's
//! awaitable protocol.
//!
//! Both IterAwaitable and FutureAwaitable will return a value immediately
//! when awaited, see guide examples related to pyo3-asyncio for ways
//! to suspend tasks and await results.
use pyo3::{prelude::*, pyclass::IterNextOutput};

#[pyclass]
#[derive(Debug)]
pub(crate) struct IterAwaitable {
result: Option<PyResult<PyObject>>,
}

#[pymethods]
impl IterAwaitable {
#[new]
fn new(result: PyObject) -> Self {
IterAwaitable {
result: Some(Ok(result)),
}
}

fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
match self.result.take() {
Some(res) => match res {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(py.None())),
}
}
}

#[pyclass]
pub(crate) struct FutureAwaitable {
#[pyo3(get, set, name = "_asyncio_future_blocking")]
py_block: bool,
result: Option<PyResult<PyObject>>,
}

#[pymethods]
impl FutureAwaitable {
#[new]
fn new(result: PyObject) -> Self {
FutureAwaitable {
py_block: false,
result: Some(Ok(result)),
}
}

fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
pyself
}

fn __next__(
mut pyself: PyRefMut<'_, Self>,
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
match pyself.result {
Some(_) => match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(pyself)),
}
}
}

#[pymodule]
pub fn awaitable(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<IterAwaitable>()?;
m.add_class::<FutureAwaitable>()?;
Ok(())
}
3 changes: 3 additions & 0 deletions pytests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pymodule;

pub mod awaitable;
pub mod buf_and_str;
pub mod comparisons;
pub mod datetime;
Expand All @@ -17,6 +18,7 @@ pub mod subclassing;

#[pymodule]
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(awaitable::awaitable))?;
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?;
Expand All @@ -37,6 +39,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {

let sys = PyModule::import(py, "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
sys_modules.set_item("pyo3_pytests.awaitable", m.getattr("awaitable")?)?;
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
Expand Down
4 changes: 4 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ impl AssertingBaseClass {
}
}

#[pyclass]
struct ClassWithoutConstructor;

#[pymodule]
pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<AssertingBaseClass>()?;
m.add_class::<ClassWithoutConstructor>()?;
Ok(())
}
13 changes: 13 additions & 0 deletions pytests/tests/test_awaitable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from pyo3_pytests.awaitable import IterAwaitable, FutureAwaitable


@pytest.mark.asyncio
async def test_iter_awaitable():
assert await IterAwaitable(5) == 5


@pytest.mark.asyncio
async def test_future_awaitable():
assert await FutureAwaitable(5) == 5
26 changes: 25 additions & 1 deletion pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Type

import pytest
from pyo3_pytests import pyclasses

Expand Down Expand Up @@ -32,7 +34,29 @@ class AssertingSubClass(pyclasses.AssertingBaseClass):


def test_new_classmethod():
# The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass.
# The `AssertingBaseClass` constructor errors if it is not passed the
# relevant subclass.
_ = AssertingSubClass(expected_type=AssertingSubClass)
with pytest.raises(ValueError):
_ = AssertingSubClass(expected_type=str)


class ClassWithoutConstructorPy:
def __new__(cls):
raise TypeError("No constructor defined")


@pytest.mark.parametrize(
"cls", [pyclasses.ClassWithoutConstructor, ClassWithoutConstructorPy]
)
def test_no_constructor_defined_propagates_cause(cls: Type):
original_error = ValueError("Original message")
with pytest.raises(Exception) as exc_info:
try:
raise original_error
except Exception:
cls() # should raise TypeError("No constructor defined")

assert exc_info.type is TypeError
assert exc_info.value.args == ("No constructor defined",)
assert exc_info.value.__context__ is original_error
133 changes: 74 additions & 59 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,16 @@ impl PyErrState {
}

#[cfg(not(Py_3_12))]
pub(crate) fn into_ffi_tuple(
self,
py: Python<'_>,
) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
match self {
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
let (mut ptype, mut pvalue, mut ptraceback) = match self {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 {
PyErrState::lazy(
PyTypeError::type_object(py),
"exceptions must derive from BaseException",
)
.into_ffi_tuple(py)
} else {
(ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut())
}
// To be consistent with 3.12 logic, go via raise_lazy.
raise_lazy(py, lazy);
let mut ptype = std::ptr::null_mut();
let mut pvalue = std::ptr::null_mut();
let mut ptraceback = std::ptr::null_mut();
unsafe { ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback) };
(ptype, pvalue, ptraceback)
}
PyErrState::FfiTuple {
ptype,
Expand All @@ -123,21 +117,8 @@ impl PyErrState {
pvalue.map_or(std::ptr::null_mut(), Py::into_ptr),
ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
),
PyErrState::Normalized(PyErrStateNormalized {
ptype,
pvalue,
ptraceback,
}) => (
ptype.into_ptr(),
pvalue.into_ptr(),
ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
),
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
let (mut ptype, mut pvalue, mut ptraceback) = self.into_ffi_tuple(py);
PyErrState::Normalized(normalized) => return normalized,
};

unsafe {
ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
Expand All @@ -151,41 +132,75 @@ impl PyErrState {

#[cfg(Py_3_12)]
pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
self.restore(py);
// Safety: self.restore(py) will set the raised exception
let pvalue = unsafe { Py::from_owned_ptr(py, ffi::PyErr_GetRaisedException()) };
PyErrStateNormalized { pvalue }
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = self.into_ffi_tuple(py);
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
match self {
PyErrState::Lazy(lazy) => {
// See note on raise_lazy about possible future efficiency gain
raise_lazy(py, lazy);
// Safety: raise_lazy will set the raised exception
let pvalue = unsafe { Py::from_owned_ptr(py, ffi::PyErr_GetRaisedException()) };
PyErrStateNormalized { pvalue }
}
PyErrState::Normalized(normalized) => normalized,
}
}

#[cfg(Py_3_12)]
pub(crate) fn restore(self, py: Python<'_>) {
match self {
PyErrState::Lazy(lazy) => {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
unsafe {
if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
ffi::PyErr_SetString(
PyTypeError::type_object_raw(py).cast(),
"exceptions must derive from BaseException\0"
.as_ptr()
.cast(),
)
} else {
ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
}
}
PyErrState::Lazy(lazy) => raise_lazy(py, lazy),
#[cfg(not(Py_3_12))]
PyErrState::FfiTuple {
ptype,
pvalue,
ptraceback,
} => {
let ptype = ptype.into_ptr();
let pvalue = pvalue.map_or(std::ptr::null_mut(), Py::into_ptr);
let ptraceback = ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr);
unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
}
PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
ffi::PyErr_SetRaisedException(pvalue.into_ptr())
PyErrState::Normalized(PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype,
pvalue,
#[cfg(not(Py_3_12))]
ptraceback,
}) => unsafe {
#[cfg(not(Py_3_12))]
{
ffi::PyErr_Restore(
ptype.into_ptr(),
pvalue.into_ptr(),
ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
)
}
#[cfg(Py_3_12)]
{
ffi::PyErr_SetRaisedException(pvalue.into_ptr())
}
},
}
}
}

/// Raises a "lazy" exception state into the Python interpreter.
///
/// In principle this could be split in two; first a function to create an exception
/// in a normalized state, and then a call to `PyErr_SetRaisedException` to raise it.
///
/// This would require either moving some logic from C to Rust, or requesting a new
/// API in CPython.
fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
unsafe {
if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
ffi::PyErr_SetString(
PyTypeError::type_object_raw(py).cast(),
"exceptions must derive from BaseException\0"
.as_ptr()
.cast(),
)
} else {
ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
}
}
}

0 comments on commit 3a434e6

Please sign in to comment.