Skip to content

Commit

Permalink
Match x + x behaviour to x += x
Browse files Browse the repository at this point in the history
This puts the `self is other` check into the `__add__` and `__sub__`
methods so that the behaviour of `x + x` is consistent with `x += x`,
with regards to the addition being done as a scalar multiplication
instead of concatentation.  Both forms are mathematically correct, but
this makes sure they're aligned.
  • Loading branch information
jakelishman committed Oct 29, 2024
1 parent 9c2983f commit 4fec00d
Showing 1 changed file with 52 additions and 37 deletions.
89 changes: 52 additions & 37 deletions crates/accelerate/src/sparse_observable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1206,16 +1206,27 @@ impl SparseObservable {
slf.borrow().eq(&other.borrow())
}

fn __add__(&self, other: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
let py = other.py();
fn __add__(slf_: &Bound<Self>, other: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
let py = slf_.py();
if slf_.is(other) {
// This fast path is for consistency with the in-place `__iadd__`, which would otherwise
// struggle to do the addition to itself.
return Ok(<&SparseObservable as ::std::ops::Mul<_>>::mul(
&slf_.borrow(),
Complex64::new(2.0, 0.0),
)
.into_py(py));
}
let Some(other) = coerce_to_observable(other)? else {
return Ok(py.NotImplemented());
};
let slf_ = slf_.borrow();
let other = other.borrow();
self.check_equal_qubits(&other)?;
Ok((self + &other).into_py(py))
slf_.check_equal_qubits(&other)?;
Ok(<&SparseObservable as ::std::ops::Add>::add(&slf_, &other).into_py(py))
}
fn __radd__(&self, other: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
// No need to handle the `self is other` case here, because `__add__` will get it.
let py = other.py();
let Some(other) = coerce_to_observable(other)? else {
return Ok(py.NotImplemented());
Expand All @@ -1227,32 +1238,36 @@ impl SparseObservable {
fn __iadd__(slf_: Bound<SparseObservable>, other: &Bound<PyAny>) -> PyResult<()> {
if slf_.is(other) {
*slf_.borrow_mut() *= Complex64::new(2.0, 0.0);
} else {
let mut slf_ = slf_.borrow_mut();
let Some(other) = coerce_to_observable(other)? else {
// This is not well behaved - we _should_ return `NotImplemented` to Python space
// without an exception, but limitations in PyO3 prevent this at the moment. See
// https://github.com/PyO3/pyo3/issues/4605.
return Err(PyTypeError::new_err(format!(
"invalid object for in-place addition of 'SparseObservable': {}",
other.repr()?
)));
};
let other = other.borrow();
slf_.check_equal_qubits(&other)?;
*slf_ += &other;
return Ok(());
}
let mut slf_ = slf_.borrow_mut();
let Some(other) = coerce_to_observable(other)? else {
// This is not well behaved - we _should_ return `NotImplemented` to Python space
// without an exception, but limitations in PyO3 prevent this at the moment. See
// https://github.com/PyO3/pyo3/issues/4605.
return Err(PyTypeError::new_err(format!(
"invalid object for in-place addition of 'SparseObservable': {}",
other.repr()?
)));
};
let other = other.borrow();
slf_.check_equal_qubits(&other)?;
*slf_ += &other;
Ok(())
}

fn __sub__(&self, other: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
let py = other.py();
fn __sub__(slf_: &Bound<Self>, other: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
let py = slf_.py();
if slf_.is(other) {
return Ok(SparseObservable::zero(slf_.borrow().num_qubits).into_py(py));
}
let Some(other) = coerce_to_observable(other)? else {
return Ok(py.NotImplemented());
};
let slf_ = slf_.borrow();
let other = other.borrow();
self.check_equal_qubits(&other)?;
Ok((self - &other).into_py(py))
slf_.check_equal_qubits(&other)?;
Ok(<&SparseObservable as ::std::ops::Sub>::sub(&slf_, &other).into_py(py))
}
fn __rsub__(&self, other: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
let py = other.py();
Expand All @@ -1268,22 +1283,22 @@ impl SparseObservable {
// This is not strictly the same thing as `a - a` if `a` contains non-finite
// floating-point values (`inf - inf` is `NaN`, for example); we don't really have a
// clear view on what floating-point guarantees we're going to make right now.
slf_.borrow_mut().clear()
} else {
let mut slf_ = slf_.borrow_mut();
let Some(other) = coerce_to_observable(other)? else {
// This is not well behaved - we _should_ return `NotImplemented` to Python space
// without an exception, but limitations in PyO3 prevent this at the moment. See
// https://github.com/PyO3/pyo3/issues/4605.
return Err(PyTypeError::new_err(format!(
"invalid object for in-place subtraction of 'SparseObservable': {}",
other.repr()?
)));
};
let other = other.borrow();
slf_.check_equal_qubits(&other)?;
*slf_ -= &other;
slf_.borrow_mut().clear();
return Ok(());
}
let mut slf_ = slf_.borrow_mut();
let Some(other) = coerce_to_observable(other)? else {
// This is not well behaved - we _should_ return `NotImplemented` to Python space
// without an exception, but limitations in PyO3 prevent this at the moment. See
// https://github.com/PyO3/pyo3/issues/4605.
return Err(PyTypeError::new_err(format!(
"invalid object for in-place subtraction of 'SparseObservable': {}",
other.repr()?
)));
};
let other = other.borrow();
slf_.check_equal_qubits(&other)?;
*slf_ -= &other;
Ok(())
}

Expand Down

0 comments on commit 4fec00d

Please sign in to comment.