diff --git a/brainunit/sparse/_coo.py b/brainunit/sparse/_coo.py index be1d0aa..88d912c 100644 --- a/brainunit/sparse/_coo.py +++ b/brainunit/sparse/_coo.py @@ -222,37 +222,63 @@ def __pos__(self): ) def _binary_op(self, other, op): + if isinstance(other, COO): + if id(self.row) == id(other.row) and id(self.col) == id(other.col): + return COO( + (op(self.data, other.data), self.row, self.col), + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted + ) if isinstance(other, JAXSparse): - raise NotImplementedError("mul between two sparse objects.") + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + other = asarray(other) if other.size == 1: return COO( (op(self.data, other), self.row, self.col), - shape=self.shape + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted ) elif other.ndim == 2 and other.shape == self.shape: other = other[self.row, self.col] return COO( (op(self.data, other), self.row, self.col), - shape=self.shape + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted ) else: raise NotImplementedError(f"mul with object of shape {other.shape}") def _binary_rop(self, other, op): + if isinstance(other, COO): + if id(self.row) == id(other.row) and id(self.col) == id(other.col): + return COO( + (op(other.data, self.data), self.row, self.col), + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted + ) if isinstance(other, JAXSparse): - raise NotImplementedError("mul between two sparse objects.") + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + other = asarray(other) if other.size == 1: return COO( (op(other, self.data), self.row, self.col), - shape=self.shape + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted ) elif other.ndim == 2 and other.shape == self.shape: other = other[self.row, self.col] return COO( (op(other, self.data), self.row, self.col), - shape=self.shape + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted ) else: raise NotImplementedError(f"mul with object of shape {other.shape}") diff --git a/brainunit/sparse/_coo_test.py b/brainunit/sparse/_coo_test.py index 0a4beaa..89be4d7 100644 --- a/brainunit/sparse/_coo_test.py +++ b/brainunit/sparse/_coo_test.py @@ -302,6 +302,9 @@ def f(sp, x): grads = jax.grad(f)(sp, xs) + sp = sp + grads * 1e-3 + sp = sp + 1e-3 * grads + def test_jit(self): @jax.jit def f(sp, x): diff --git a/brainunit/sparse/_csr.py b/brainunit/sparse/_csr.py index 2257bcd..612ba1b 100644 --- a/brainunit/sparse/_csr.py +++ b/brainunit/sparse/_csr.py @@ -120,8 +120,15 @@ def __pos__(self): return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape) def _binary_op(self, other, op): + if isinstance(other, CSR): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSR( + (op(self.data, other.data), self.indices, self.indptr), + shape=self.shape + ) if isinstance(other, JAXSparse): - raise NotImplementedError("mul between two sparse objects.") + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + other = asarray(other) if other.size == 1: return CSR( @@ -139,8 +146,15 @@ def _binary_op(self, other, op): raise NotImplementedError(f"mul with object of shape {other.shape}") def _binary_rop(self, other, op): + if isinstance(other, CSR): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSR( + (op(other.data, self.data), self.indices, self.indptr), + shape=self.shape + ) if isinstance(other, JAXSparse): - raise NotImplementedError("mul between two sparse objects.") + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + other = asarray(other) if other.size == 1: return CSR( @@ -294,8 +308,15 @@ def __pos__(self): return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape) def _binary_op(self, other, op): + if isinstance(other, CSC): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSC( + (op(self.data, other.data), self.indices, self.indptr), + shape=self.shape + ) if isinstance(other, JAXSparse): - raise NotImplementedError("mul between two sparse objects.") + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + other = asarray(other) if other.size == 1: return CSC( @@ -313,8 +334,15 @@ def _binary_op(self, other, op): raise NotImplementedError(f"mul with object of shape {other.shape}") def _binary_rop(self, other, op): + if isinstance(other, CSC): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSC( + (op(other.data, self.data), self.indices, self.indptr), + shape=self.shape + ) if isinstance(other, JAXSparse): - raise NotImplementedError("mul between two sparse objects.") + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + other = asarray(other) if other.size == 1: return CSC( diff --git a/brainunit/sparse/_csr_test.py b/brainunit/sparse/_csr_test.py index 981a4cf..8fe144a 100644 --- a/brainunit/sparse/_csr_test.py +++ b/brainunit/sparse/_csr_test.py @@ -301,6 +301,9 @@ def f(csr, x): grads = jax.grad(f)(csr, xs) + csr = csr + grads * 1e-3 + csr = csr + 1e-3 * grads + def test_jit(self): @jax.jit def f(csr, x): @@ -596,6 +599,9 @@ def f(csc, x): grads = jax.grad(f)(csc, xs) + csc = csc + grads * 1e-3 + csc = csc + 1e-3 * grads + def test_jit(self): @jax.jit