From 34d76d0ff8a506af1149c5102697c029a17c3360 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 15 Feb 2022 22:51:37 +0000 Subject: [PATCH] pymethods: more tests for magic methods --- pyo3-macros-backend/src/pymethod.rs | 4 ++ tests/test_arithmetics.rs | 33 +++++++++++++ tests/test_proto_methods.rs | 74 ++++++++++++++++++++++++----- tests/test_sequence.rs | 19 ++++---- 4 files changed, 108 insertions(+), 22 deletions(-) diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 297f206b370..b5687d2f5c7 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -51,6 +51,8 @@ impl PyMethodKind { "__anext__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ANEXT__)), "__len__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__LEN__)), "__contains__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONTAINS__)), + "__concat__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONCAT__)), + "__repeat__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPEAT__)), "__getitem__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETITEM__)), "__pos__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__POS__)), "__neg__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEG__)), @@ -602,6 +604,8 @@ const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySs const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc") .arguments(&[Ty::Object]) .ret_ty(Ty::Int); +const __CONCAT__: SlotDef = SlotDef::new("Py_sq_concat", "binaryfunc").arguments(&[Ty::Object]); +const __REPEAT__: SlotDef = SlotDef::new("Py_sq_repeat", "ssizeargfunc").arguments(&[Ty::PySsizeT]); const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]); const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc"); diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index ed1870bbcbe..548ed91360d 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -52,6 +52,39 @@ fn unary_arithmetic() { py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'"); } +#[pyclass] +struct Indexable(i32); + +#[pymethods] +impl Indexable { + fn __index__(&self) -> i32 { + self.0 + } + + fn __int__(&self) -> i32 { + self.0 + } + + fn __float__(&self) -> f64 { + f64::from(self.0) + } + + fn __invert__(&self) -> Self { + Self(!self.0) + } +} + +#[test] +fn indexable() { + Python::with_gil(|py| { + let i = PyCell::new(py, Indexable(5)).unwrap(); + py_run!(py, i, "assert int(i) == 5"); + py_run!(py, i, "assert [0, 1, 2, 3, 4, 5][i] == 5"); + py_run!(py, i, "assert float(i) == 5.0"); + py_run!(py, i, "assert int(~i) == -6"); + }) +} + #[pyclass] struct InPlaceOperations { value: u32, diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index a6d621c9dfc..45455e195a3 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -608,6 +608,7 @@ fn getattr_doesnt_override_member() { /// Wraps a Python future and yield it once. #[pyclass] +#[derive(Debug)] struct OnceFuture { future: PyObject, polled: bool, @@ -645,26 +646,75 @@ fn test_await() { let gil = Python::acquire_gil(); let py = gil.python(); let once = py.get_type::(); - let source = pyo3::indoc::indoc!( - r#" + let source = r#" import asyncio -import sys - async def main(): res = await Once(await asyncio.sleep(0.1)) - return res + assert res is None + # For an odd error similar to https://bugs.python.org/issue38563 if sys.platform == "win32" and sys.version_info >= (3, 8, 0): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -# get_event_loop can raise an error: https://github.com/PyO3/pyo3/pull/961#issuecomment-645238579 -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -assert loop.run_until_complete(main()) is None -loop.close() -"# - ); + +asyncio.run(main()) +"#; + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Once", once).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +} + +#[pyclass] +struct AsyncIterator { + future: Option>, +} + +#[pymethods] +impl AsyncIterator { + #[new] + fn new(future: Py) -> Self { + Self { + future: Some(future), + } + } + + fn __aiter__(slf: PyRef) -> PyRef { + slf + } + + fn __anext__(&mut self) -> Option> { + self.future.take() + } +} + +#[test] +fn test_anext_aiter() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let once = py.get_type::(); + let source = r#" +import asyncio + +async def main(): + count = 0 + async for result in AsyncIterator(Once(await asyncio.sleep(0.1))): + # The Once is awaited as part of the `async for` and produces None + assert result is None + count +=1 + assert count == 1 + +# For an odd error similar to https://bugs.python.org/issue38563 +if sys.platform == "win32" and sys.version_info >= (3, 8, 0): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +asyncio.run(main()) +"#; let globals = PyModule::import(py, "__main__").unwrap().dict(); globals.set_item("Once", once).unwrap(); + globals + .set_item("AsyncIterator", py.get_type::()) + .unwrap(); py.run(source, Some(globals), None) .map_err(|e| e.print(py)) .unwrap(); diff --git a/tests/test_sequence.rs b/tests/test_sequence.rs index e156b9ff919..e60cc2cec7e 100644 --- a/tests/test_sequence.rs +++ b/tests/test_sequence.rs @@ -1,7 +1,5 @@ #![cfg(feature = "macros")] -#![cfg(feature = "pyproto")] // FIXME: change this to use #[pymethods] once supports sequence protocol -use pyo3::class::PySequenceProtocol; use pyo3::exceptions::{PyIndexError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyList}; @@ -32,10 +30,7 @@ impl ByteSequence { }) } } -} -#[pyproto] -impl PySequenceProtocol for ByteSequence { fn __len__(&self) -> usize { self.elements.len() } @@ -51,8 +46,12 @@ impl PySequenceProtocol for ByteSequence { self.elements[idx as usize] = value; } - fn __delitem__(&mut self, idx: isize) -> PyResult<()> { - if (idx < self.elements.len() as isize) && (idx >= 0) { + fn __delitem__(&mut self, mut idx: isize) -> PyResult<()> { + let self_len = self.elements.len() as isize; + if idx < 0 { + idx += self_len; + } + if (idx < self_len) && (idx >= 0) { self.elements.remove(idx as usize); Ok(()) } else { @@ -67,7 +66,7 @@ impl PySequenceProtocol for ByteSequence { } } - fn __concat__(&self, other: PyRef<'p, Self>) -> Self { + fn __concat__(&self, other: PyRef) -> Self { let mut elements = self.elements.clone(); elements.extend_from_slice(&other.elements); Self { elements } @@ -274,8 +273,8 @@ struct OptionList { items: Vec>, } -#[pyproto] -impl PySequenceProtocol for OptionList { +#[pymethods] +impl OptionList { fn __getitem__(&self, idx: isize) -> PyResult> { match self.items.get(idx as usize) { Some(x) => Ok(*x),