From dc20c20521c98e362e0ad9b5a912a666c00856a1 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Tue, 25 Jun 2024 18:50:00 +0100 Subject: [PATCH] Encapsulate Python sequence-like indexers This encapsulates a lot of the common logic around Python sequence-like indexers (`SliceOrInt`) into iterators that handle adapting negative indices and slices in `usize` for containers of a given size. These indexers now all implement `ExactSizeIterator` and `DoubleEndedIterator`, so they can be used with all `Iterator` methods, and can be used (with `Iterator::map` and friends) as inputs to `PyList::new_bound`, which makes code simpler at all points of use. The special-cased uses of this kind of thing from `CircuitData` are replaced with the new forms. This had no measurable impact on performance on my machine, and removes a lot noise from error-handling and highly specialised functions. --- Cargo.lock | 2 + Cargo.toml | 1 + crates/accelerate/Cargo.toml | 1 + .../src/euler_one_qubit_decomposer.rs | 55 +-- crates/accelerate/src/two_qubit_decompose.rs | 59 +-- crates/circuit/Cargo.toml | 1 + crates/circuit/src/circuit_data.rs | 307 +++++--------- crates/circuit/src/lib.rs | 12 +- crates/circuit/src/slice.rs | 375 ++++++++++++++++++ 9 files changed, 517 insertions(+), 296 deletions(-) create mode 100644 crates/circuit/src/slice.rs diff --git a/Cargo.lock b/Cargo.lock index 454823748e8d..ed94de91de8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1189,6 +1189,7 @@ dependencies = [ "rayon", "rustworkx-core", "smallvec", + "thiserror", ] [[package]] @@ -1201,6 +1202,7 @@ dependencies = [ "numpy", "pyo3", "smallvec", + "thiserror", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 13f43cfabcdc..a6ccf60f7f4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ num-complex = "0.4" ndarray = "^0.15.6" numpy = "0.21.0" smallvec = "1.13" +thiserror = "1.0" # Most of the crates don't need the feature `extension-module`, since only `qiskit-pyext` builds an # actual C extension (the feature disables linking in `libpython`, which is forbidden in Python diff --git a/crates/accelerate/Cargo.toml b/crates/accelerate/Cargo.toml index b377a9b38a6d..58bccb19ac24 100644 --- a/crates/accelerate/Cargo.toml +++ b/crates/accelerate/Cargo.toml @@ -23,6 +23,7 @@ rustworkx-core = "0.14" faer = "0.19.1" itertools = "0.13.0" qiskit-circuit.workspace = true +thiserror.workspace = true [dependencies.smallvec] workspace = true diff --git a/crates/accelerate/src/euler_one_qubit_decomposer.rs b/crates/accelerate/src/euler_one_qubit_decomposer.rs index 9f10f76de467..01725269bb84 100644 --- a/crates/accelerate/src/euler_one_qubit_decomposer.rs +++ b/crates/accelerate/src/euler_one_qubit_decomposer.rs @@ -21,9 +21,9 @@ use std::f64::consts::PI; use std::ops::Deref; use std::str::FromStr; -use pyo3::exceptions::{PyIndexError, PyValueError}; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyString; +use pyo3::types::{PyList, PyString}; use pyo3::wrap_pyfunction; use pyo3::Python; @@ -31,8 +31,8 @@ use ndarray::prelude::*; use numpy::PyReadonlyArray2; use pyo3::pybacked::PyBackedStr; +use qiskit_circuit::slice::{PySequenceIndex, SequenceIndex}; use qiskit_circuit::util::c64; -use qiskit_circuit::SliceOrInt; pub const ANGLE_ZERO_EPSILON: f64 = 1e-12; @@ -97,46 +97,15 @@ impl OneQubitGateSequence { Ok(self.gates.len()) } - fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult { - match idx { - SliceOrInt::Slice(slc) => { - let len = self.gates.len().try_into().unwrap(); - let indices = slc.indices(len)?; - let mut out_vec: Vec<(String, SmallVec<[f64; 3]>)> = Vec::new(); - // Start and stop will always be positive the slice api converts - // negatives to the index for example: - // list(range(5))[-1:-3:-1] - // will return start=4, stop=2, and step=-1 - let mut pos: isize = indices.start; - let mut cond = if indices.step < 0 { - pos > indices.stop - } else { - pos < indices.stop - }; - while cond { - if pos < len as isize { - out_vec.push(self.gates[pos as usize].clone()); - } - pos += indices.step; - if indices.step < 0 { - cond = pos > indices.stop; - } else { - cond = pos < indices.stop; - } - } - Ok(out_vec.into_py(py)) - } - SliceOrInt::Int(idx) => { - let len = self.gates.len() as isize; - if idx >= len || idx < -len { - Err(PyIndexError::new_err(format!("Invalid index, {idx}"))) - } else if idx < 0 { - let len = self.gates.len(); - Ok(self.gates[len - idx.unsigned_abs()].to_object(py)) - } else { - Ok(self.gates[idx as usize].to_object(py)) - } - } + fn __getitem__(&self, py: Python, idx: PySequenceIndex) -> PyResult { + match idx.with_len(self.gates.len())? { + SequenceIndex::Int(idx) => Ok(self.gates[idx].to_object(py)), + indices => Ok(PyList::new_bound( + py, + indices.iter().map(|pos| self.gates[pos].to_object(py)), + ) + .into_any() + .unbind()), } } } diff --git a/crates/accelerate/src/two_qubit_decompose.rs b/crates/accelerate/src/two_qubit_decompose.rs index 8637cb03c735..37061d5159f4 100644 --- a/crates/accelerate/src/two_qubit_decompose.rs +++ b/crates/accelerate/src/two_qubit_decompose.rs @@ -21,10 +21,6 @@ use approx::{abs_diff_eq, relative_eq}; use num_complex::{Complex, Complex64, ComplexFloat}; use num_traits::Zero; -use pyo3::exceptions::{PyIndexError, PyValueError}; -use pyo3::prelude::*; -use pyo3::wrap_pyfunction; -use pyo3::Python; use smallvec::{smallvec, SmallVec}; use std::f64::consts::{FRAC_1_SQRT_2, PI}; use std::ops::Deref; @@ -37,7 +33,11 @@ use ndarray::prelude::*; use ndarray::Zip; use numpy::PyReadonlyArray2; use numpy::{IntoPyArray, ToPyArray}; + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; +use pyo3::types::PyList; use crate::convert_2q_block_matrix::change_basis; use crate::euler_one_qubit_decomposer::{ @@ -52,8 +52,8 @@ use rand_distr::StandardNormal; use rand_pcg::Pcg64Mcg; use qiskit_circuit::gate_matrix::{CX_GATE, H_GATE, ONE_QUBIT_IDENTITY, SX_GATE, X_GATE}; +use qiskit_circuit::slice::{PySequenceIndex, SequenceIndex}; use qiskit_circuit::util::{c64, GateArray1Q, GateArray2Q, C_M_ONE, C_ONE, C_ZERO, IM, M_IM}; -use qiskit_circuit::SliceOrInt; const PI2: f64 = PI / 2.; const PI4: f64 = PI / 4.; @@ -1131,46 +1131,15 @@ impl TwoQubitGateSequence { Ok(self.gates.len()) } - fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult { - match idx { - SliceOrInt::Slice(slc) => { - let len = self.gates.len().try_into().unwrap(); - let indices = slc.indices(len)?; - let mut out_vec: TwoQubitSequenceVec = Vec::new(); - // Start and stop will always be positive the slice api converts - // negatives to the index for example: - // list(range(5))[-1:-3:-1] - // will return start=4, stop=2, and step=- - let mut pos: isize = indices.start; - let mut cond = if indices.step < 0 { - pos > indices.stop - } else { - pos < indices.stop - }; - while cond { - if pos < len as isize { - out_vec.push(self.gates[pos as usize].clone()); - } - pos += indices.step; - if indices.step < 0 { - cond = pos > indices.stop; - } else { - cond = pos < indices.stop; - } - } - Ok(out_vec.into_py(py)) - } - SliceOrInt::Int(idx) => { - let len = self.gates.len() as isize; - if idx >= len || idx < -len { - Err(PyIndexError::new_err(format!("Invalid index, {idx}"))) - } else if idx < 0 { - let len = self.gates.len(); - Ok(self.gates[len - idx.unsigned_abs()].to_object(py)) - } else { - Ok(self.gates[idx as usize].to_object(py)) - } - } + fn __getitem__(&self, py: Python, idx: PySequenceIndex) -> PyResult { + match idx.with_len(self.gates.len())? { + SequenceIndex::Int(idx) => Ok(self.gates[idx].to_object(py)), + indices => Ok(PyList::new_bound( + py, + indices.iter().map(|pos| self.gates[pos].to_object(py)), + ) + .into_any() + .unbind()), } } } diff --git a/crates/circuit/Cargo.toml b/crates/circuit/Cargo.toml index dd7e878537d9..50160c7bac17 100644 --- a/crates/circuit/Cargo.toml +++ b/crates/circuit/Cargo.toml @@ -14,6 +14,7 @@ hashbrown.workspace = true num-complex.workspace = true ndarray.workspace = true numpy.workspace = true +thiserror.workspace = true [dependencies.pyo3] workspace = true diff --git a/crates/circuit/src/circuit_data.rs b/crates/circuit/src/circuit_data.rs index 07f4579a4cd3..10e0691021a1 100644 --- a/crates/circuit/src/circuit_data.rs +++ b/crates/circuit/src/circuit_data.rs @@ -22,11 +22,12 @@ use crate::imports::{BUILTIN_LIST, QUBIT}; use crate::interner::{IndexedInterner, Interner, InternerKey}; use crate::operations::{Operation, OperationType, Param, StandardGate}; use crate::parameter_table::{ParamEntry, ParamTable, GLOBAL_PHASE_INDEX}; -use crate::{Clbit, Qubit, SliceOrInt}; +use crate::slice::{PySequenceIndex, SequenceIndex}; +use crate::{Clbit, Qubit}; use pyo3::exceptions::{PyIndexError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyList, PySet, PySlice, PyTuple, PyType}; +use pyo3::types::{PyList, PySet, PyTuple, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use hashbrown::{HashMap, HashSet}; @@ -321,7 +322,7 @@ impl CircuitData { } pub fn append_inner(&mut self, py: Python, value: PyRef) -> PyResult { - let packed = self.pack(py, value)?; + let packed = self.pack(value)?; let new_index = self.data.len(); self.data.push(packed); self.update_param_table(py, new_index, None) @@ -744,184 +745,130 @@ impl CircuitData { } // Note: we also rely on this to make us iterable! - pub fn __getitem__(&self, py: Python, index: &Bound) -> PyResult { - // Internal helper function to get a specific - // instruction by index. - fn get_at( - self_: &CircuitData, - py: Python<'_>, - index: isize, - ) -> PyResult> { - let index = self_.convert_py_index(index)?; - if let Some(inst) = self_.data.get(index) { - let qubits = self_.qargs_interner.intern(inst.qubits_id); - let clbits = self_.cargs_interner.intern(inst.clbits_id); - Py::new( - py, - CircuitInstruction::new( - py, - inst.op.clone(), - self_.qubits.map_indices(qubits.value), - self_.clbits.map_indices(clbits.value), - inst.params.clone(), - inst.extra_attrs.clone(), - ), - ) - } else { - Err(PyIndexError::new_err(format!( - "No element at index {:?} in circuit data", - index - ))) - } - } - - if index.is_exact_instance_of::() { - let slice = self.convert_py_slice(index.downcast_exact::()?)?; - let result = slice - .into_iter() - .map(|i| get_at(self, py, i)) - .collect::>>()?; - Ok(result.into_py(py)) - } else { - Ok(get_at(self, py, index.extract()?)?.into_py(py)) + pub fn __getitem__(&self, py: Python, index: PySequenceIndex) -> PyResult { + // Get a single item, assuming the index is validated as in bounds. + let get_single = |index: usize| { + let inst = &self.data[index]; + let qubits = self.qargs_interner.intern(inst.qubits_id); + let clbits = self.cargs_interner.intern(inst.clbits_id); + CircuitInstruction::new( + py, + inst.op.clone(), + self.qubits.map_indices(qubits.value), + self.clbits.map_indices(clbits.value), + inst.params.clone(), + inst.extra_attrs.clone(), + ) + .into_py(py) + }; + match index.with_len(self.data.len())? { + SequenceIndex::Int(index) => Ok(get_single(index)), + indices => Ok(PyList::new_bound(py, indices.iter().map(get_single)).into_py(py)), } } - pub fn __delitem__(&mut self, py: Python, index: SliceOrInt) -> PyResult<()> { - match index { - SliceOrInt::Slice(slice) => { - let slice = { - let mut s = self.convert_py_slice(&slice)?; - if s.len() > 1 && s.first().unwrap() < s.last().unwrap() { - // Reverse the order so we're sure to delete items - // at the back first (avoids messing up indices). - s.reverse() - } - s - }; - for i in slice.into_iter() { - self.__delitem__(py, SliceOrInt::Int(i))?; - } - self.reindex_parameter_table(py)?; - Ok(()) - } - SliceOrInt::Int(index) => { - let index = self.convert_py_index(index)?; - if self.data.get(index).is_some() { - if index == self.data.len() { - // For individual removal from param table before - // deletion - self.remove_from_parameter_table(py, index)?; - self.data.remove(index); - } else { - // For delete in the middle delete before reindexing - self.data.remove(index); - self.reindex_parameter_table(py)?; - } - Ok(()) - } else { - Err(PyIndexError::new_err(format!( - "No element at index {:?} in circuit data", - index - ))) - } - } - } + pub fn __delitem__(&mut self, py: Python, index: PySequenceIndex) -> PyResult<()> { + self.delitem(py, index.with_len(self.data.len())?) } pub fn setitem_no_param_table_update( &mut self, - py: Python<'_>, - index: isize, - value: &Bound, + index: usize, + value: PyRef, ) -> PyResult<()> { - let index = self.convert_py_index(index)?; - let value: PyRef = value.downcast()?.borrow(); - let mut packed = self.pack(py, value)?; + let mut packed = self.pack(value)?; std::mem::swap(&mut packed, &mut self.data[index]); Ok(()) } - pub fn __setitem__( - &mut self, - py: Python<'_>, - index: SliceOrInt, - value: &Bound, - ) -> PyResult<()> { - match index { - SliceOrInt::Slice(slice) => { - let indices = slice.indices(self.data.len().try_into().unwrap())?; - let slice = self.convert_py_slice(&slice)?; - let values = value.iter()?.collect::>>>()?; - if indices.step != 1 && slice.len() != values.len() { - // A replacement of a different length when step isn't exactly '1' - // would result in holes. - return Err(PyValueError::new_err(format!( - "attempt to assign sequence of size {:?} to extended slice of size {:?}", - values.len(), - slice.len(), - ))); - } + pub fn __setitem__(&mut self, index: PySequenceIndex, value: &Bound) -> PyResult<()> { + fn set_single(slf: &mut CircuitData, index: usize, value: &Bound) -> PyResult<()> { + let py = value.py(); + let mut packed = slf.pack(value.downcast::()?.borrow())?; + slf.remove_from_parameter_table(py, index)?; + std::mem::swap(&mut packed, &mut slf.data[index]); + slf.update_param_table(py, index, None)?; + Ok(()) + } - for (i, v) in slice.iter().zip(values.iter()) { - self.__setitem__(py, SliceOrInt::Int(*i), v)?; + let py = value.py(); + match index.with_len(self.data.len())? { + SequenceIndex::Int(index) => set_single(self, index, value), + indices @ SequenceIndex::PosRange { + start, + stop, + step: 1, + } => { + // `list` allows setting a slice with step +1 to an arbitrary length. + let values = value.iter()?.collect::>>()?; + for (index, value) in indices.iter().zip(values.iter()) { + set_single(self, index, value)?; } - - if slice.len() > values.len() { - // Delete any extras. - let slice = PySlice::new_bound( + if indices.len() > values.len() { + self.delitem( py, - indices.start + values.len() as isize, - indices.stop, - 1isize, - ); - self.__delitem__(py, SliceOrInt::Slice(slice))?; + SequenceIndex::PosRange { + start: start + values.len(), + stop, + step: 1, + }, + )? } else { - // Insert any extra values. - for v in values.iter().skip(slice.len()).rev() { - let v: PyRef = v.extract()?; - self.insert(py, indices.stop, v)?; + for value in values[indices.len()..].iter().rev() { + self.insert(stop as isize, value.downcast()?.borrow())?; } } - Ok(()) } - SliceOrInt::Int(index) => { - let index = self.convert_py_index(index)?; - let value: PyRef = value.extract()?; - let mut packed = self.pack(py, value)?; - self.remove_from_parameter_table(py, index)?; - std::mem::swap(&mut packed, &mut self.data[index]); - self.update_param_table(py, index, None)?; - Ok(()) + indices => { + let values = value.iter()?.collect::>>()?; + if indices.len() == values.len() { + for (index, value) in indices.iter().zip(values.iter()) { + set_single(self, index, value)?; + } + Ok(()) + } else { + Err(PyValueError::new_err(format!( + "attempt to assign sequence of size {:?} to extended slice of size {:?}", + values.len(), + indices.len(), + ))) + } } } } - pub fn insert( - &mut self, - py: Python<'_>, - index: isize, - value: PyRef, - ) -> PyResult<()> { - let index = self.convert_py_index_clamped(index); - let old_len = self.data.len(); - let packed = self.pack(py, value)?; + pub fn insert(&mut self, mut index: isize, value: PyRef) -> PyResult<()> { + // `list.insert` has special-case extra clamping logic for its index argument. + let index = { + if index < 0 { + // This can't exceed `isize::MAX` because `self.data[0]` is larger than a byte. + index += self.data.len() as isize; + } + if index < 0 { + 0 + } else if index as usize > self.data.len() { + self.data.len() + } else { + index as usize + } + }; + let py = value.py(); + let packed = self.pack(value)?; self.data.insert(index, packed); - if index == old_len { - self.update_param_table(py, old_len, None)?; + if index == self.data.len() - 1 { + self.update_param_table(py, index, None)?; } else { self.reindex_parameter_table(py)?; } Ok(()) } - pub fn pop(&mut self, py: Python<'_>, index: Option) -> PyResult { - let index = - index.unwrap_or_else(|| std::cmp::max(0, self.data.len() as isize - 1).into_py(py)); - let item = self.__getitem__(py, index.bind(py))?; - - self.__delitem__(py, index.bind(py).extract()?)?; + pub fn pop(&mut self, py: Python<'_>, index: Option) -> PyResult { + let index = index.unwrap_or(PySequenceIndex::Int(-1)); + let native_index = index.with_len(self.data.len())?; + let item = self.__getitem__(py, index)?; + self.delitem(py, native_index)?; Ok(item) } @@ -931,7 +878,7 @@ impl CircuitData { value: &Bound, params: Option)>>, ) -> PyResult { - let packed = self.pack(py, value.try_borrow()?)?; + let packed = self.pack(value.try_borrow()?)?; let new_index = self.data.len(); self.data.push(packed); self.update_param_table(py, new_index, params) @@ -1175,56 +1122,22 @@ impl CircuitData { } impl CircuitData { - /// Converts a Python slice to a `Vec` of indices into - /// the instruction listing, [CircuitData.data]. - fn convert_py_slice(&self, slice: &Bound) -> PyResult> { - let indices = slice.indices(self.data.len().try_into().unwrap())?; - if indices.step > 0 { - Ok((indices.start..indices.stop) - .step_by(indices.step as usize) - .collect()) - } else { - let mut out = Vec::with_capacity(indices.slicelength as usize); - let mut x = indices.start; - while x > indices.stop { - out.push(x); - x += indices.step; - } - Ok(out) + /// Native internal driver of `__delitem__` that uses a Rust-space version of the + /// `SequenceIndex`. This assumes that the `SequenceIndex` contains only in-bounds indices, and + /// panics if not. + fn delitem(&mut self, py: Python, indices: SequenceIndex) -> PyResult<()> { + // We need to delete in reverse order so we don't invalidate higher indices with a deletion. + for index in indices.descending() { + self.data.remove(index); } - } - - /// Converts a Python index to an index into the instruction listing, - /// or one past its end. - /// If the resulting index would be < 0, clamps to 0. - /// If the resulting index would be > len(data), clamps to len(data). - fn convert_py_index_clamped(&self, index: isize) -> usize { - let index = if index < 0 { - index + self.data.len() as isize - } else { - index - }; - std::cmp::min(std::cmp::max(0, index), self.data.len() as isize) as usize - } - - /// Converts a Python index to an index into the instruction listing. - fn convert_py_index(&self, index: isize) -> PyResult { - let index = if index < 0 { - index + self.data.len() as isize - } else { - index - }; - - if index < 0 || index >= self.data.len() as isize { - return Err(PyIndexError::new_err(format!( - "Index {:?} is out of bounds.", - index, - ))); + if !indices.is_empty() { + self.reindex_parameter_table(py)?; } - Ok(index as usize) + Ok(()) } - fn pack(&mut self, py: Python, inst: PyRef) -> PyResult { + fn pack(&mut self, inst: PyRef) -> PyResult { + let py = inst.py(); let qubits = Interner::intern( &mut self.qargs_interner, InternerKey::Value(self.qubits.map_bits(inst.qubits.bind(py))?.collect()), diff --git a/crates/circuit/src/lib.rs b/crates/circuit/src/lib.rs index 9fcaa36480cf..9f0a8017bf21 100644 --- a/crates/circuit/src/lib.rs +++ b/crates/circuit/src/lib.rs @@ -17,23 +17,13 @@ pub mod gate_matrix; pub mod imports; pub mod operations; pub mod parameter_table; +pub mod slice; pub mod util; mod bit_data; mod interner; use pyo3::prelude::*; -use pyo3::types::PySlice; - -/// A private enumeration type used to extract arguments to pymethod -/// that may be either an index or a slice -#[derive(FromPyObject)] -pub enum SliceOrInt<'a> { - // The order here defines the order the variants are tried in the FromPyObject` derivation. - // `Int` is _much_ more common, so that should be first. - Int(isize), - Slice(Bound<'a, PySlice>), -} pub type BitType = u32; #[derive(Copy, Clone, Debug, Hash, Ord, PartialOrd, Eq, PartialEq)] diff --git a/crates/circuit/src/slice.rs b/crates/circuit/src/slice.rs new file mode 100644 index 000000000000..056adff0a282 --- /dev/null +++ b/crates/circuit/src/slice.rs @@ -0,0 +1,375 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2024 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use thiserror::Error; + +use pyo3::exceptions::PyIndexError; +use pyo3::prelude::*; +use pyo3::types::PySlice; + +use self::sealed::{Descending, SequenceIndexIter}; + +/// A Python-space indexer for the standard `PySequence` type; a single integer or a slice. +/// +/// These come in as `isize`s from Python space, since Python typically allows negative indices. +/// Use `with_len` to specialize the index to a valid Rust-space indexer into a collection of the +/// given length. +pub enum PySequenceIndex<'py> { + Int(isize), + Slice(Bound<'py, PySlice>), +} + +impl<'py> FromPyObject<'py> for PySequenceIndex<'py> { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + // `slice` can't be subclassed in Python, so it's safe (and faster) to check for it exactly. + // The `downcast_exact` check is just a pointer comparison, so while `slice` is the less + // common input, doing that first has little-to-no impact on the speed of the `isize` path, + // while the reverse makes `slice` inputs significantly slower. + if let Ok(slice) = ob.downcast_exact::() { + return Ok(Self::Slice(slice.clone())); + } + Ok(Self::Int(ob.extract()?)) + } +} + +impl<'py> PySequenceIndex<'py> { + /// Specialize this index to a collection of the given `len`, returning a Rust-native type. + pub fn with_len(&self, len: usize) -> Result { + match self { + PySequenceIndex::Int(index) => { + let index = if *index >= 0 { + let index = *index as usize; + if index >= len { + return Err(PySequenceIndexError::OutOfRange); + } + index + } else { + len.checked_sub(index.unsigned_abs()) + .ok_or(PySequenceIndexError::OutOfRange)? + }; + Ok(SequenceIndex::Int(index)) + } + PySequenceIndex::Slice(slice) => { + let indices = slice + .indices(len as ::std::os::raw::c_long) + .map_err(PySequenceIndexError::from)?; + if indices.step > 0 { + Ok(SequenceIndex::PosRange { + start: indices.start as usize, + stop: indices.stop as usize, + step: indices.step as usize, + }) + } else { + Ok(SequenceIndex::NegRange { + // `indices.start` can be negative if the collection length is 0. + start: (indices.start >= 0).then_some(indices.start as usize), + // `indices.stop` can be negative if the 0 index should be output. + stop: (indices.stop >= 0).then_some(indices.stop as usize), + step: indices.step.unsigned_abs(), + }) + } + } + } + } +} + +/// Error type for problems encountered when calling methods on `PySequenceIndex`. +#[derive(Error, Debug)] +pub enum PySequenceIndexError { + #[error("index out of range")] + OutOfRange, + #[error(transparent)] + InnerPy(#[from] PyErr), +} +impl From for PyErr { + fn from(value: PySequenceIndexError) -> PyErr { + match value { + PySequenceIndexError::OutOfRange => PyIndexError::new_err("index out of range"), + PySequenceIndexError::InnerPy(inner) => inner, + } + } +} + +/// Rust-native version of a Python sequence-like indexer. +/// +/// Typically this is constructed by a call to `PySequenceIndex::with_len`, which guarantees that +/// all the indices will be in bounds for a collection of the given length. +/// +/// This splits the positive- and negative-step versions of the slice in two so it can be translated +/// more easily into static dispatch. This type can be converted into several types of iterator. +#[derive(Clone, Copy, Debug)] +pub enum SequenceIndex { + Int(usize), + PosRange { + start: usize, + stop: usize, + step: usize, + }, + NegRange { + start: Option, + stop: Option, + step: usize, + }, +} + +impl SequenceIndex { + /// The number of indices this refers to. + pub fn len(&self) -> usize { + match self { + Self::Int(_) => 1, + Self::PosRange { start, stop, step } => { + let gap = stop.saturating_sub(*start); + gap / *step + (gap % *step != 0) as usize + } + Self::NegRange { start, stop, step } => 'arm: { + let Some(start) = start else { break 'arm 0 }; + let gap = stop + .map(|stop| start.saturating_sub(stop)) + .unwrap_or(*start + 1); + gap / step + (gap % step != 0) as usize + } + } + } + + pub fn is_empty(&self) -> bool { + // This is just to keep clippy happy; the length is already fairly inexpensive to calculate. + self.len() == 0 + } + + /// Get an iterator over the indices. This will be a single-item iterator for the case of + /// `Self::Int`, but you probably wanted to destructure off that case beforehand anyway. + pub fn iter(&self) -> SequenceIndexIter { + match self { + Self::Int(value) => SequenceIndexIter::Int(Some(*value)), + Self::PosRange { start, step, .. } => SequenceIndexIter::PosRange { + lowest: *start, + step: *step, + indices: 0..self.len(), + }, + Self::NegRange { start, step, .. } => SequenceIndexIter::NegRange { + // We can unwrap `highest` to an arbitrary value if `None`, because in that case the + // `len` is 0 and the iterator will not yield any objects. + highest: start.unwrap_or_default(), + step: *step, + indices: 0..self.len(), + }, + } + } + + // Get an iterator over the contained indices that is guaranteed to iterate from the highest + // index to the lowest. + pub fn descending(&self) -> Descending { + Descending(self.iter()) + } +} + +impl IntoIterator for SequenceIndex { + type Item = usize; + type IntoIter = SequenceIndexIter; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +// Private module to make it impossible to construct or inspect the internals of the iterator types +// from outside this file, while still allowing them to be used. +mod sealed { + /// Custom iterator for indices for Python sequence-likes. + /// + /// In the range types, the `indices ` are `Range` objects that run from 0 to the length of the + /// iterator. In theory, we could generate the iterators ourselves, but that ends up with a lot of + /// boilerplate. + #[derive(Clone, Debug)] + pub enum SequenceIndexIter { + Int(Option), + PosRange { + lowest: usize, + step: usize, + indices: ::std::ops::Range, + }, + NegRange { + highest: usize, + // The step of the iterator, but note that this is a negative range, so the forwards method + // steps downwards from `upper` towards `lower`. + step: usize, + indices: ::std::ops::Range, + }, + } + impl Iterator for SequenceIndexIter { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + match self { + Self::Int(value) => value.take(), + Self::PosRange { + lowest, + step, + indices, + } => indices.next().map(|idx| *lowest + idx * *step), + Self::NegRange { + highest, + step, + indices, + } => indices.next().map(|idx| *highest - idx * *step), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + Self::Int(None) => (0, Some(0)), + Self::Int(Some(_)) => (1, Some(1)), + Self::PosRange { indices, .. } | Self::NegRange { indices, .. } => { + indices.size_hint() + } + } + } + } + impl DoubleEndedIterator for SequenceIndexIter { + #[inline] + fn next_back(&mut self) -> Option { + match self { + Self::Int(value) => value.take(), + Self::PosRange { + lowest, + step, + indices, + } => indices.next_back().map(|idx| *lowest + idx * *step), + Self::NegRange { + highest, + step, + indices, + } => indices.next_back().map(|idx| *highest - idx * *step), + } + } + } + impl ExactSizeIterator for SequenceIndexIter {} + + pub struct Descending(pub SequenceIndexIter); + impl Iterator for Descending { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + match self.0 { + SequenceIndexIter::Int(_) | SequenceIndexIter::NegRange { .. } => self.0.next(), + SequenceIndexIter::PosRange { .. } => self.0.next_back(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } + } + impl DoubleEndedIterator for Descending { + #[inline] + fn next_back(&mut self) -> Option { + match self.0 { + SequenceIndexIter::Int(_) | SequenceIndexIter::NegRange { .. } => { + self.0.next_back() + } + SequenceIndexIter::PosRange { .. } => self.0.next(), + } + } + } + impl ExactSizeIterator for Descending {} +} + +#[cfg(test)] +mod test { + use super::*; + + /// Get a set of test parametrisations for iterator methods. The second argument is the + /// expected values from a normal forward iteration. + fn index_iterator_cases() -> impl Iterator)> { + let pos = |start, stop, step| SequenceIndex::PosRange { start, stop, step }; + let neg = |start, stop, step| SequenceIndex::NegRange { start, stop, step }; + + [ + (SequenceIndex::Int(3), vec![3]), + (pos(0, 5, 2), vec![0, 2, 4]), + (pos(2, 10, 1), vec![2, 3, 4, 5, 6, 7, 8, 9]), + (pos(1, 15, 3), vec![1, 4, 7, 10, 13]), + (neg(Some(3), None, 1), vec![3, 2, 1, 0]), + (neg(Some(3), None, 2), vec![3, 1]), + (neg(Some(2), Some(0), 1), vec![2, 1]), + (neg(Some(2), Some(0), 2), vec![2]), + (neg(Some(2), Some(0), 3), vec![2]), + (neg(Some(10), Some(2), 3), vec![10, 7, 4]), + (neg(None, None, 1), vec![]), + (neg(None, None, 3), vec![]), + ] + .into_iter() + } + + /// Test that the index iterator's implementation of `ExactSizeIterator` is correct. + #[test] + fn index_iterator() { + for (index, forwards) in index_iterator_cases() { + // We're testing that all the values are the same, and the `size_hint` is correct at + // every single point. + let mut actual = Vec::new(); + let mut sizes = Vec::new(); + let mut iter = index.iter(); + loop { + sizes.push(iter.size_hint().0); + if let Some(next) = iter.next() { + actual.push(next); + } else { + break; + } + } + assert_eq!( + actual, forwards, + "values for {:?}\nActual : {:?}\nExpected: {:?}", + index, actual, forwards, + ); + let expected_sizes = (0..=forwards.len()).rev().collect::>(); + assert_eq!( + sizes, expected_sizes, + "sizes for {:?}\nActual : {:?}\nExpected: {:?}", + index, sizes, expected_sizes, + ); + } + } + + /// Test that the index iterator's implementation of `DoubleEndedIterator` is correct. + #[test] + fn reversed_index_iterator() { + for (index, forwards) in index_iterator_cases() { + let actual = index.iter().rev().collect::>(); + let expected = forwards.into_iter().rev().collect::>(); + assert_eq!( + actual, expected, + "reversed {:?}\nActual : {:?}\nExpected: {:?}", + index, actual, expected, + ); + } + } + + /// Test that `descending` produces its values in reverse-sorted order. + #[test] + fn descending() { + for (index, mut expected) in index_iterator_cases() { + let actual = index.descending().collect::>(); + expected.sort_by(|left, right| right.cmp(left)); + assert_eq!( + actual, expected, + "descending {:?}\nActual : {:?}\nExpected: {:?}", + index, actual, expected, + ); + } + } +}