From fc528109d1b5249ebe014f8d31a8cd34fbb1eea9 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 2 May 2022 10:58:13 +0800 Subject: [PATCH] update neuron models --- brainpy/dyn/neurons/reduced_models.py | 27 ++++++++++++++----------- brainpy/dyn/synapses/abstract_models.py | 2 +- brainpy/tools/checking.py | 7 ++++--- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py index 2ea2dd6cb..8a15f3d26 100644 --- a/brainpy/dyn/neurons/reduced_models.py +++ b/brainpy/dyn/neurons/reduced_models.py @@ -89,6 +89,7 @@ def __init__( V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] = None, noise_type: str = 'value', + keep_size: bool=False, method: str = 'exp_auto', name: str = None ): @@ -96,29 +97,31 @@ def __init__( super(LIF, self).__init__(size=size, name=name) # parameters + self.keep_size = keep_size self.noise_type = noise_type if noise_type not in ['func', 'value']: raise ValueError(f'noise_type only supports `func` and `value`, but we got {noise_type}') - self.V_rest = init_param(V_rest, self.num, allow_none=False) - self.V_reset = init_param(V_reset, self.num, allow_none=False) - self.V_th = init_param(V_th, self.num, allow_none=False) - self.tau = init_param(tau, self.num, allow_none=False) - self.tau_ref = init_param(tau_ref, self.num, allow_none=False) + size = self.size if keep_size else self.num + self.V_rest = init_param(V_rest, size, allow_none=False) + self.V_reset = init_param(V_reset, size, allow_none=False) + self.V_th = init_param(V_th, size, allow_none=False) + self.tau = init_param(tau, size, allow_none=False) + self.tau_ref = init_param(tau_ref, size, allow_none=False) if noise_type == 'func': self.noise = noise else: - self.noise = init_param(noise, self.num, allow_none=True) + self.noise = init_param(noise, size, allow_none=True) # initializers check_initializer(V_initializer, 'V_initializer') self._V_initializer = V_initializer # variables - self.V = bm.Variable(init_param(V_initializer, (self.num,))) - self.input = bm.Variable(bm.zeros(self.num)) - self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) - self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) + self.V = bm.Variable(init_param(V_initializer, size)) + self.input = bm.Variable(bm.zeros(size)) + self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + self.t_last_spike = bm.Variable(bm.ones(size) * -1e7) + self.refractory = bm.Variable(bm.zeros(size, dtype=bool)) # integral f = lambda V, t, I_ext: (-V + self.V_rest + I_ext) / self.tau @@ -129,7 +132,7 @@ def __init__( self.integral = odeint(method=method, f=f) def reset(self): - self.V.value = init_param(self._V_initializer, (self.num,)) + self.V.value = init_param(self._V_initializer, self.size if self.keep_size else self.num) self.input[:] = 0 self.spike[:] = False self.t_last_spike[:] = -1e7 diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index b034287cb..39692468d 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -178,7 +178,7 @@ def update(self, t, dt): # update outputs target = getattr(self.post, self.post_key) if self.post_has_ref: - target += post_vs * (1. - self.post.refractory) + target += post_vs * bm.logical_not(self.post.refractory) else: target += post_vs diff --git a/brainpy/tools/checking.py b/brainpy/tools/checking.py index b929d9c98..39cd5eea1 100644 --- a/brainpy/tools/checking.py +++ b/brainpy/tools/checking.py @@ -289,13 +289,14 @@ def check_integer(value: int, name=None, min_bound=None, max_bound=None, allow_n else: raise ValueError(f'{name} must be an int, but got None') if not isinstance(value, int): - raise ValueError(f'{name} must be an int, but got {type(value)}') + if hasattr(value, 'dtype') and not jnp.issubdtype(value.dtype, jnp.integer): + raise ValueError(f'{name} must be an int, but got {value}') if min_bound is not None: - if value < min_bound: + if jnp.any(value < min_bound): raise ValueError(f"{name} must be an int bigger than {min_bound}, " f"while we got {value}") if max_bound is not None: - if value > max_bound: + if jnp.any(value > max_bound): raise ValueError(f"{name} must be an int smaller than {max_bound}, " f"while we got {value}")