diff --git a/brainpy/math/numpy_ops.py b/brainpy/math/numpy_ops.py index 3538e84b6..fbc81be20 100644 --- a/brainpy/math/numpy_ops.py +++ b/brainpy/math/numpy_ops.py @@ -1521,8 +1521,11 @@ def vander(x, N=None, increasing=False): def fill_diagonal(a, val): - assert isinstance(a, JaxArray), f'Must be a JaxArray, but got {type(a)}' - assert a.ndim >= 2, f'Only support tensor has dimension >= 2, but got {a.shape}' + if not isinstance(a, JaxArray): + raise ValueError(f'Must be a JaxArray, but got {type(a)}') + if a.ndim < 2: + raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') + val = _remove_jaxarray(val) i, j = jnp.diag_indices(_min(a.shape[-2:])) a._value = a.value.at[..., i, j].set(val) diff --git a/brainpy/math/tests/test_numpy_ops.py b/brainpy/math/tests/test_numpy_ops.py index 0b7d41ba5..fc6e046f4 100644 --- a/brainpy/math/tests/test_numpy_ops.py +++ b/brainpy/math/tests/test_numpy_ops.py @@ -37,4 +37,18 @@ def test_remove_diag2(self): with self.assertRaises(ValueError): bm.remove_diag(a) + def test_fill_diagonal(self): + a = bm.arange(10) + with self.assertRaises(ValueError): + bm.fill_diagonal(a, 0.) + + b = jnp.ones((10, 10)) + with self.assertRaises(ValueError): + bm.fill_diagonal(b, 0) + + bm.random.seed() + c = bm.random.rand(10, 10) + bm.fill_diagonal(c, 0) + + bm.fill_diagonal(c, bm.arange(10))