Skip to content

Commit

Permalink
feat: support multiple slice in LengthDelay
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed May 4, 2022
1 parent a064ead commit 4ed3aca
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions brainpy/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def reset(

# time variables
if self.idx is None:
self.idx = Variable(jnp.asarray([0], dtype=jnp.int32))
self.idx = Variable(jnp.asarray([0]))
else:
self.idx.value = jnp.asarray([0], dtype=jnp.int32)
self.idx.value = jnp.asarray([0])

# delay data
if self.data is None:
Expand All @@ -349,7 +349,7 @@ def _check_delay(self, delay_len, transforms):
f'maximum delay {self.num_delay_step}. But we '
f'got {delay_len}')

def __call__(self, delay_len, indices=None):
def __call__(self, delay_len, *indices):
# check
if check.is_checking():
id_tap(self._check_delay, delay_len)
Expand All @@ -358,10 +358,10 @@ def __call__(self, delay_len, indices=None):
if delay_idx.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')
# the delay data
if indices is None:
return self.data[delay_idx]
if len(indices) > 0:
return self.data[delay_idx, *indices]
else:
return self.data[delay_idx, indices]
return self.data[delay_idx]

def update(self, value: Union[float, JaxArray, jnp.DeviceArray]):
if jnp.shape(value) != self.shape:
Expand Down

0 comments on commit 4ed3aca

Please sign in to comment.