Skip to content

Commit

Permalink
convert tuple iteration to use Py2 in list
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Nov 21, 2023
1 parent b1ece19 commit c90da02
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 35 deletions.
9 changes: 8 additions & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ impl<'a> Input<'a> for PyAny {
fn extract_generic_iterable(&'a self) -> ValResult<GenericIterable<'a>> {
// Handle concrete non-overlapping types first, then abstract types
if let Ok(iterable) = self.downcast::<PyList>() {
Ok(GenericIterable::List(iterable))
Ok(GenericIterable::List(Py2::borrowed_from_gil_ref(&iterable).clone()))
} else if let Ok(iterable) = self.downcast::<PyTuple>() {
Ok(GenericIterable::Tuple(iterable))
} else if let Ok(iterable) = self.downcast::<PySet>() {
Expand Down Expand Up @@ -746,6 +746,13 @@ impl BorrowInput for &'_ PyAny {
}
}

impl BorrowInput for Py2<'_, PyAny> {
type Input<'a> = PyAny where Self: 'a;
fn borrow_input(&self) -> &Self::Input<'_> {
self.as_gil_ref()
}
}

/// Best effort check of whether it's likely to make sense to inspect obj for attributes and iterate over it
/// with `obj.dir()`
fn from_attributes_applicable(obj: &PyAny) -> bool {
Expand Down
69 changes: 39 additions & 30 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::tools::py_err;
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

use super::input_string::StringMapping;
use super::{py_error_on_minusone, Input};
use super::{py_error_on_minusone, BorrowInput, Input};

pub struct ValidationMatch<T>(T, Exactness);

Expand Down Expand Up @@ -67,7 +67,7 @@ impl<T> ValidationMatch<T> {
/// This mostly matches python's definition of `Collection`.
#[cfg_attr(debug_assertions, derive(Debug))]
pub enum GenericIterable<'a> {
List(&'a PyList),
List(Py2<'a, PyList>),
Tuple(&'a PyTuple),
Set(&'a PySet),
FrozenSet(&'a PyFrozenSet),
Expand All @@ -92,33 +92,37 @@ impl<'a, 'py: 'a> GenericIterable<'a> {
pub fn as_sequence_iterator(
&self,
py: Python<'py>,
) -> PyResult<Box<dyn Iterator<Item = PyResult<&'a PyAny>> + 'a>> {
) -> PyResult<Box<dyn Iterator<Item = PyResult<Py2<'a, PyAny>>> + 'a>> {
match self {
GenericIterable::List(iter) => Ok(Box::new(iter.iter().map(Ok))),
GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(Ok))),
GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(Ok))),
GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(Ok))),
GenericIterable::List(iter) => Ok(Box::new(iter.clone().into_iter().map(Ok))),
GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(Ok).map(pyresult_ref_to_pyresult_py2))),
GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(Ok).map(pyresult_ref_to_pyresult_py2))),
GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(Ok).map(pyresult_ref_to_pyresult_py2))),
// Note that this iterates over only the keys, just like doing iter({}) in Python
GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(|(k, _)| Ok(k)))),
GenericIterable::DictKeys(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::DictValues(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::DictItems(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::Dict(iter) => Ok(Box::new(
iter.iter().map(|(k, _)| Ok(k)).map(pyresult_ref_to_pyresult_py2),
)),
GenericIterable::DictKeys(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::DictValues(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::DictItems(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
// Note that this iterates over only the keys, just like doing iter({}) in Python
GenericIterable::Mapping(iter) => Ok(Box::new(iter.keys()?.iter()?)),
GenericIterable::PyString(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::Bytes(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::PyByteArray(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::Sequence(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::Iterator(iter) => Ok(Box::new(iter.iter()?)),
GenericIterable::Mapping(iter) => Ok(Box::new(iter.keys()?.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::PyString(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::Bytes(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::PyByteArray(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::Sequence(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::Iterator(iter) => Ok(Box::new(iter.iter()?.map(pyresult_ref_to_pyresult_py2))),
GenericIterable::JsonArray(iter) => Ok(Box::new(iter.iter().map(move |v| {
let v = v.to_object(py);
Ok(v.into_ref(py))
let v = v.to_object(py).attach_into(py);
Ok(v)
}))),
// Note that this iterates over only the keys, just like doing iter({}) in Python, just for consistency
GenericIterable::JsonObject(iter) => Ok(Box::new(
iter.iter().map(move |(k, _)| Ok(k.to_object(py).into_ref(py))),
iter.iter().map(move |(k, _)| Ok(k.to_object(py).attach_into(py))),
)),
GenericIterable::JsonString(s) => Ok(Box::new(PyString::new(py, s).iter()?)),
GenericIterable::JsonString(s) => {
Ok(Box::new(PyString::new(py, s).iter()?.map(pyresult_ref_to_pyresult_py2)))
}
}
}
}
Expand Down Expand Up @@ -188,7 +192,7 @@ macro_rules! any_next_error {
#[allow(clippy::too_many_arguments)]
fn validate_iter_to_vec<'a, 's>(
py: Python<'a>,
iter: impl Iterator<Item = PyResult<&'a (impl Input<'a> + 'a)>>,
iter: impl Iterator<Item = PyResult<impl BorrowInput + 'a>>,
capacity: usize,
mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>,
validator: &'s CombinedValidator,
Expand All @@ -198,7 +202,7 @@ fn validate_iter_to_vec<'a, 's>(
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
match validator.validate(py, item, state) {
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
max_length_check.incr()?;
output.push(item);
Expand Down Expand Up @@ -253,7 +257,7 @@ impl BuildSet for &PyFrozenSet {
fn validate_iter_to_set<'a, 's>(
py: Python<'a>,
set: impl BuildSet,
iter: impl Iterator<Item = PyResult<&'a (impl Input<'a> + 'a)>>,
iter: impl Iterator<Item = PyResult<impl BorrowInput + 'a>>,
input: &'a (impl Input<'a> + 'a),
field_type: &'static str,
max_length: Option<usize>,
Expand All @@ -263,7 +267,7 @@ fn validate_iter_to_set<'a, 's>(
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
match validator.validate(py, item, state) {
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
set.build_add(item)?;
if let Some(max_length) = max_length {
Expand Down Expand Up @@ -301,14 +305,14 @@ fn validate_iter_to_set<'a, 's>(
fn no_validator_iter_to_vec<'a, 's>(
py: Python<'a>,
input: &'a (impl Input<'a> + 'a),
iter: impl Iterator<Item = PyResult<&'a (impl Input<'a> + 'a)>>,
iter: impl Iterator<Item = PyResult<impl BorrowInput + 'a>>,
mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>,
) -> ValResult<Vec<PyObject>> {
iter.enumerate()
.map(|(index, result)| {
let v = result.map_err(|e| any_next_error!(py, e, input, index))?;
max_length_check.incr()?;
Ok(v.to_object(py))
Ok(v.borrow_input().to_object(py))
})
.collect()
}
Expand Down Expand Up @@ -360,7 +364,7 @@ impl<'a> GenericIterable<'a> {
}

match self {
GenericIterable::List(collection) => validate!(collection.iter().map(Ok)),
GenericIterable::List(collection) => validate!(collection.clone().into_iter().map(Ok)),
GenericIterable::Tuple(collection) => validate!(collection.iter().map(Ok)),
GenericIterable::Set(collection) => validate!(collection.iter().map(Ok)),
GenericIterable::FrozenSet(collection) => validate!(collection.iter().map(Ok)),
Expand Down Expand Up @@ -389,7 +393,7 @@ impl<'a> GenericIterable<'a> {
}

match self {
GenericIterable::List(collection) => validate_set!(collection.iter().map(Ok)),
GenericIterable::List(collection) => validate_set!(collection.clone().into_iter().map(Ok)),
GenericIterable::Tuple(collection) => validate_set!(collection.iter().map(Ok)),
GenericIterable::Set(collection) => validate_set!(collection.iter().map(Ok)),
GenericIterable::FrozenSet(collection) => validate_set!(collection.iter().map(Ok)),
Expand All @@ -412,7 +416,7 @@ impl<'a> GenericIterable<'a> {

match self {
GenericIterable::List(collection) => {
no_validator_iter_to_vec(py, input, collection.iter().map(Ok), max_length_check)
no_validator_iter_to_vec(py, input, collection.clone().into_iter().map(Ok), max_length_check)
}
GenericIterable::Tuple(collection) => {
no_validator_iter_to_vec(py, input, collection.iter().map(Ok), max_length_check)
Expand Down Expand Up @@ -1039,3 +1043,8 @@ impl ToPyObject for Int {
}
}
}

/// Backwards-compatibility helper while migrating PyO3 API
fn pyresult_ref_to_pyresult_py2(result: PyResult<&'_ PyAny>) -> PyResult<Py2<'_, PyAny>> {
result.map(|any| Py2::borrowed_from_gil_ref(&any).clone())
}
9 changes: 5 additions & 4 deletions src/validators/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple};

use crate::build_tools::is_strict;
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::BorrowInput;
use crate::input::{GenericIterable, Input};
use crate::tools::SchemaDict;
use crate::validators::Exactness;
Expand Down Expand Up @@ -113,7 +114,7 @@ impl BuildValidator for TuplePositionalValidator {
}

#[allow(clippy::too_many_arguments)]
fn validate_tuple_positional<'s, 'data, T: Iterator<Item = PyResult<&'data I>>, I: Input<'data> + 'data>(
fn validate_tuple_positional<'s, 'data, T: Iterator<Item = PyResult<impl BorrowInput + 'data>>>(
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
Expand All @@ -126,7 +127,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator<Item = PyResult<&'data I>>,
) -> ValResult<()> {
for (index, validator) in items_validators.iter().enumerate() {
match collection_iter.next() {
Some(result) => match validator.validate(py, result?, state) {
Some(result) => match validator.validate(py, result?.borrow_input(), state) {
Ok(item) => output.push(item),
Err(ValError::LineErrors(line_errors)) => {
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into())));
Expand All @@ -145,7 +146,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator<Item = PyResult<&'data I>>,
for (index, result) in collection_iter.enumerate() {
let item = result?;
match extras_validator {
Some(ref extras_validator) => match extras_validator.validate(py, item, state) {
Some(ref extras_validator) => match extras_validator.validate(py, item.borrow_input(), state) {
Ok(item) => output.push(item),
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
Expand Down Expand Up @@ -222,7 +223,7 @@ impl Validator for TuplePositionalValidator {
}

match collection {
GenericIterable::List(collection_iter) => iter!(collection_iter.iter().map(Ok)),
GenericIterable::List(collection_iter) => iter!(collection_iter.clone().into_iter().map(Ok)),
GenericIterable::Tuple(collection_iter) => iter!(collection_iter.iter().map(Ok)),
GenericIterable::JsonArray(collection_iter) => iter!(collection_iter.iter().map(Ok)),
other => iter!(other.as_sequence_iterator(py)?),
Expand Down

0 comments on commit c90da02

Please sign in to comment.