From 4fec00d41b2cd0d0b7df9db4e87710ab8b5bfa83 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Tue, 29 Oct 2024 14:36:55 +0000 Subject: [PATCH] Match `x + x` behaviour to `x += x` This puts the `self is other` check into the `__add__` and `__sub__` methods so that the behaviour of `x + x` is consistent with `x += x`, with regards to the addition being done as a scalar multiplication instead of concatentation. Both forms are mathematically correct, but this makes sure they're aligned. --- crates/accelerate/src/sparse_observable.rs | 89 +++++++++++++--------- 1 file changed, 52 insertions(+), 37 deletions(-) diff --git a/crates/accelerate/src/sparse_observable.rs b/crates/accelerate/src/sparse_observable.rs index 4533504d2324..14e386f3a2cb 100644 --- a/crates/accelerate/src/sparse_observable.rs +++ b/crates/accelerate/src/sparse_observable.rs @@ -1206,16 +1206,27 @@ impl SparseObservable { slf.borrow().eq(&other.borrow()) } - fn __add__(&self, other: &Bound) -> PyResult> { - let py = other.py(); + fn __add__(slf_: &Bound, other: &Bound) -> PyResult> { + let py = slf_.py(); + if slf_.is(other) { + // This fast path is for consistency with the in-place `__iadd__`, which would otherwise + // struggle to do the addition to itself. + return Ok(<&SparseObservable as ::std::ops::Mul<_>>::mul( + &slf_.borrow(), + Complex64::new(2.0, 0.0), + ) + .into_py(py)); + } let Some(other) = coerce_to_observable(other)? else { return Ok(py.NotImplemented()); }; + let slf_ = slf_.borrow(); let other = other.borrow(); - self.check_equal_qubits(&other)?; - Ok((self + &other).into_py(py)) + slf_.check_equal_qubits(&other)?; + Ok(<&SparseObservable as ::std::ops::Add>::add(&slf_, &other).into_py(py)) } fn __radd__(&self, other: &Bound) -> PyResult> { + // No need to handle the `self is other` case here, because `__add__` will get it. let py = other.py(); let Some(other) = coerce_to_observable(other)? else { return Ok(py.NotImplemented()); @@ -1227,32 +1238,36 @@ impl SparseObservable { fn __iadd__(slf_: Bound, other: &Bound) -> PyResult<()> { if slf_.is(other) { *slf_.borrow_mut() *= Complex64::new(2.0, 0.0); - } else { - let mut slf_ = slf_.borrow_mut(); - let Some(other) = coerce_to_observable(other)? else { - // This is not well behaved - we _should_ return `NotImplemented` to Python space - // without an exception, but limitations in PyO3 prevent this at the moment. See - // https://github.com/PyO3/pyo3/issues/4605. - return Err(PyTypeError::new_err(format!( - "invalid object for in-place addition of 'SparseObservable': {}", - other.repr()? - ))); - }; - let other = other.borrow(); - slf_.check_equal_qubits(&other)?; - *slf_ += &other; + return Ok(()); } + let mut slf_ = slf_.borrow_mut(); + let Some(other) = coerce_to_observable(other)? else { + // This is not well behaved - we _should_ return `NotImplemented` to Python space + // without an exception, but limitations in PyO3 prevent this at the moment. See + // https://github.com/PyO3/pyo3/issues/4605. + return Err(PyTypeError::new_err(format!( + "invalid object for in-place addition of 'SparseObservable': {}", + other.repr()? + ))); + }; + let other = other.borrow(); + slf_.check_equal_qubits(&other)?; + *slf_ += &other; Ok(()) } - fn __sub__(&self, other: &Bound) -> PyResult> { - let py = other.py(); + fn __sub__(slf_: &Bound, other: &Bound) -> PyResult> { + let py = slf_.py(); + if slf_.is(other) { + return Ok(SparseObservable::zero(slf_.borrow().num_qubits).into_py(py)); + } let Some(other) = coerce_to_observable(other)? else { return Ok(py.NotImplemented()); }; + let slf_ = slf_.borrow(); let other = other.borrow(); - self.check_equal_qubits(&other)?; - Ok((self - &other).into_py(py)) + slf_.check_equal_qubits(&other)?; + Ok(<&SparseObservable as ::std::ops::Sub>::sub(&slf_, &other).into_py(py)) } fn __rsub__(&self, other: &Bound) -> PyResult> { let py = other.py(); @@ -1268,22 +1283,22 @@ impl SparseObservable { // This is not strictly the same thing as `a - a` if `a` contains non-finite // floating-point values (`inf - inf` is `NaN`, for example); we don't really have a // clear view on what floating-point guarantees we're going to make right now. - slf_.borrow_mut().clear() - } else { - let mut slf_ = slf_.borrow_mut(); - let Some(other) = coerce_to_observable(other)? else { - // This is not well behaved - we _should_ return `NotImplemented` to Python space - // without an exception, but limitations in PyO3 prevent this at the moment. See - // https://github.com/PyO3/pyo3/issues/4605. - return Err(PyTypeError::new_err(format!( - "invalid object for in-place subtraction of 'SparseObservable': {}", - other.repr()? - ))); - }; - let other = other.borrow(); - slf_.check_equal_qubits(&other)?; - *slf_ -= &other; + slf_.borrow_mut().clear(); + return Ok(()); } + let mut slf_ = slf_.borrow_mut(); + let Some(other) = coerce_to_observable(other)? else { + // This is not well behaved - we _should_ return `NotImplemented` to Python space + // without an exception, but limitations in PyO3 prevent this at the moment. See + // https://github.com/PyO3/pyo3/issues/4605. + return Err(PyTypeError::new_err(format!( + "invalid object for in-place subtraction of 'SparseObservable': {}", + other.repr()? + ))); + }; + let other = other.borrow(); + slf_.check_equal_qubits(&other)?; + *slf_ -= &other; Ok(()) }