Skip to content

Commit

Permalink
Implement __eq__ instead of __richcmp__.
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhartman committed Nov 1, 2023
1 parent d80a6c5 commit c3e6b96
Showing 1 changed file with 27 additions and 42 deletions.
69 changes: 27 additions & 42 deletions crates/accelerate/src/quantum_circuit/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ use crate::quantum_circuit::circuit_instruction::CircuitInstruction;
use crate::quantum_circuit::intern_context::{BitType, IndexType, InternContext};
use crate::quantum_circuit::py_ext;
use hashbrown::HashMap;
use pyo3::basic::CompareOp;
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyIterator, PyList, PySlice, PyTuple, PyType};
use pyo3::{AsPyPointer, PyObject, PyResult, PyTraverseError, PyVisit};
use pyo3::{PyObject, PyResult, PyTraverseError, PyVisit};
use std::hash::{Hash, Hasher};
use std::iter::zip;

Expand Down Expand Up @@ -371,16 +370,32 @@ impl CircuitData {
#[classattr]
const __hash__: Option<Py<PyAny>> = None;

fn __richcmp__(
slf: &PyCell<Self>,
other: &PyAny,
op: CompareOp,
py: Python<'_>,
) -> PyResult<PyObject> {
match op {
CompareOp::Eq => CircuitData::equals(slf, other).map(|r| r.into_py(py)),
CompareOp::Ne => CircuitData::equals(slf, other).map(|r| (!r).into_py(py)),
_ => Ok(py.NotImplemented()),
fn __eq__(slf: &PyCell<Self>, other: &PyAny) -> PyResult<bool> {
let slf: &PyAny = slf;
if slf.is(other) {
return Ok(true);
}
if slf.len()? != other.len()? {
return Ok(false);
}
// Implemented using generic iterators on both sides
// for simplicity.
let mut ours_itr = slf.iter()?;
let mut theirs_itr = other.iter()?;
loop {
match (ours_itr.next(), theirs_itr.next()) {
(Some(ours), Some(theirs)) => {
if !ours?.eq(theirs?)? {
return Ok(false);
}
}
(None, None) => {
return Ok(true);
}
_ => {
return Ok(false);
}
}
}
}

Expand Down Expand Up @@ -448,36 +463,6 @@ impl CircuitData {
Ok(index as usize)
}

fn equals(slf: &PyAny, other: &PyAny) -> PyResult<bool> {
let slf_len = slf.len()?;
let other_len = other.len();
if other_len.is_ok() && slf_len != other_len.unwrap() {
return Ok(false);
}
let mut ours_itr = slf.iter()?;
let mut theirs_itr = match other.iter() {
Ok(i) => i,
Err(_) => {
return Ok(false);
}
};
loop {
match (ours_itr.next(), theirs_itr.next()) {
(Some(ours), Some(theirs)) => {
if !ours?.eq(theirs?)? {
return Ok(false);
}
}
(None, None) => {
return Ok(true);
}
_ => {
return Ok(false);
}
}
}
}

fn get_or_cache(
&mut self,
py: Python<'_>,
Expand Down

0 comments on commit c3e6b96

Please sign in to comment.