Skip to content

Commit

Permalink
bug: fix bug of brainpy.math.fill_diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed May 2, 2022
1 parent fc52810 commit 0e9e09c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
7 changes: 5 additions & 2 deletions brainpy/math/numpy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,8 +1521,11 @@ def vander(x, N=None, increasing=False):


def fill_diagonal(a, val):
assert isinstance(a, JaxArray), f'Must be a JaxArray, but got {type(a)}'
assert a.ndim >= 2, f'Only support tensor has dimension >= 2, but got {a.shape}'
if not isinstance(a, JaxArray):
raise ValueError(f'Must be a JaxArray, but got {type(a)}')
if a.ndim < 2:
raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}')
val = _remove_jaxarray(val)
i, j = jnp.diag_indices(_min(a.shape[-2:]))
a._value = a.value.at[..., i, j].set(val)

Expand Down
14 changes: 14 additions & 0 deletions brainpy/math/tests/test_numpy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,18 @@ def test_remove_diag2(self):
with self.assertRaises(ValueError):
bm.remove_diag(a)

def test_fill_diagonal(self):
a = bm.arange(10)
with self.assertRaises(ValueError):
bm.fill_diagonal(a, 0.)

b = jnp.ones((10, 10))
with self.assertRaises(ValueError):
bm.fill_diagonal(b, 0)

bm.random.seed()
c = bm.random.rand(10, 10)
bm.fill_diagonal(c, 0)

bm.fill_diagonal(c, bm.arange(10))

0 comments on commit 0e9e09c

Please sign in to comment.