diff --git a/brainpy/_src/dyn/neurons/base.py b/brainpy/_src/dyn/neurons/base.py
index 02a457d0a..264ce8865 100644
--- a/brainpy/_src/dyn/neurons/base.py
+++ b/brainpy/_src/dyn/neurons/base.py
@@ -29,7 +29,7 @@ def __init__(
       scaling: Optional[bm.Scaling] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
   ):
@@ -43,18 +43,18 @@ def __init__(
     self.spk_reset = spk_reset
     self.spk_fun = is_callable(spk_fun)
     self.detach_spk = detach_spk
-    self._spk_type = spk_type
+    self._spk_dtype = spk_dtype
     if scaling is None:
       self.scaling = bm.get_membrane_scaling()
     else:
       self.scaling = scaling

   @property
-  def spk_type(self):
-    if self._spk_type is None:
+  def spk_dtype(self):
+    if self._spk_dtype is None:
       return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_
     else:
-      return self._spk_type
+      return self._spk_dtype

   def offset_scaling(self, x, bias=None, scale=None):
     s = self.scaling.offset_scaling(x, bias=bias, scale=scale)
diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py
index 018ad24a9..988c915ac 100644
--- a/brainpy/_src/dyn/neurons/lif.py
+++ b/brainpy/_src/dyn/neurons/lif.py
@@ -77,7 +77,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -99,7 +99,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)

@@ -124,7 +124,7 @@ def derivative(self, V, t, I):

   def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -206,7 +206,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -230,7 +230,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)

@@ -257,7 +257,7 @@ def derivative(self, V, t, I):

   def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -399,7 +399,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       detach_spk: bool = False,
       spk_reset: str = 'soft',
       method: str = 'exp_auto',
@@ -429,7 +429,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
@@ -673,7 +673,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -699,7 +699,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)

@@ -730,7 +730,7 @@ def derivative(self, V, t, I):

   def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -1001,7 +1001,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       detach_spk: bool = False,
       spk_reset: str = 'soft',
       method: str = 'exp_auto',
@@ -1033,7 +1033,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
@@ -1343,7 +1343,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -1373,7 +1373,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)
     # parameters
@@ -1416,7 +1416,7 @@ def derivative(self):
   def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
     self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -1672,7 +1672,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -1708,7 +1708,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
@@ -1991,7 +1991,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -2017,7 +2017,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)
     # parameters
@@ -2046,7 +2046,7 @@ def derivative(self, V, t, I):

   def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -2255,7 +2255,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -2287,7 +2287,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
@@ -2554,7 +2554,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -2583,7 +2583,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)
     # parameters
@@ -2624,7 +2624,7 @@ def derivative(self):
   def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
     self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -2856,7 +2856,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -2891,7 +2891,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
@@ -3201,7 +3201,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -3237,7 +3237,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)
     # parameters
@@ -3291,7 +3291,7 @@ def reset_state(self, batch_size=None, **kwargs):
     self.V_th = self.offset_scaling(self.init_variable(self._Vth_initializer, batch_size))
     self.I1 = self.std_scaling(self.init_variable(self._I1_initializer, batch_size))
     self.I2 = self.std_scaling(self.init_variable(self._I2_initializer, batch_size))
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -3581,7 +3581,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -3623,7 +3623,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
@@ -3952,7 +3952,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -3982,7 +3982,7 @@ def __init__(
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
                      method=method,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      scaling=scaling)
     # parameters
@@ -4031,7 +4031,7 @@ def reset_state(self, batch_size=None, **kwargs):
     self.V = self.offset_scaling(self.V)
     self.u = self.offset_scaling(self.init_variable(self._u_initializer, batch_size),
                                  bias=self.b * self.scaling.bias, scale=self.scaling.scale)
-    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

   def update(self, x=None):
     t = share.load('t')
@@ -4266,7 +4266,7 @@ def __init__(
       keep_size: bool = False,
       mode: Optional[bm.Mode] = None,
       spk_fun: Callable = bm.surrogate.InvSquareGrad(),
-      spk_type: Any = None,
+      spk_dtype: Any = None,
       spk_reset: str = 'soft',
       detach_spk: bool = False,
       method: str = 'exp_auto',
@@ -4302,7 +4302,7 @@ def __init__(
                      sharding=sharding,
                      spk_fun=spk_fun,
                      detach_spk=detach_spk,
-                     spk_type=spk_type,
+                     spk_dtype=spk_dtype,
                      spk_reset=spk_reset,
                      init_var=False,
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index 0c9bf8f54..b5d12d9ce 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -1518,6 +1518,8 @@ def float(self):
     return jnp.asarray(self.value, dtype=jnp.float32)

   def double(self):
     return jnp.asarray(self.value, dtype=jnp.float64)

+setattr(Array, "__array_priority__", 100)
+
 JaxArray = Array
 ndarray = Array
diff --git a/brainpy/_src/math/tests/test_jaxarray.py b/brainpy/_src/math/tests/test_ndarray.py
similarity index 89%
rename from brainpy/_src/math/tests/test_jaxarray.py
rename to brainpy/_src/math/tests/test_ndarray.py
index 9a227a071..09a6f791c 100644
--- a/brainpy/_src/math/tests/test_jaxarray.py
+++ b/brainpy/_src/math/tests/test_ndarray.py
@@ -111,3 +111,14 @@ def test_update(self):
     )
     self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10))
+
+
+class TestArrayPriority(unittest.TestCase):
+  def test1(self):
+    a = bm.Array(bm.zeros(10))
+    assert isinstance(a + bm.ones(1).value, bm.Array)
+    assert isinstance(a + np.ones(1), bm.Array)
+    assert isinstance(a * np.ones(1), bm.Array)
+    assert isinstance(np.ones(1) + a, bm.Array)
+    assert isinstance(np.ones(1) * a, bm.Array)
+
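The rename from spk_type to spk_dtype is purely cosmetic: the base-class property still resolves the spike variable's dtype from the computing mode when no explicit value is given. A minimal usage sketch follows, assuming the public brainpy.dyn.Lif wrapper forwards the mode and spk_dtype keywords to this constructor unchanged.

import brainpy as bp
import brainpy.math as bm

# With spk_dtype left as None, the spike dtype is resolved from the mode:
# float under TrainingMode (so surrogate gradients can flow through the
# spike variable), bool otherwise.
lif_train = bp.dyn.Lif(10, mode=bm.TrainingMode())  # spike dtype -> bm.float_
lif_sim = bp.dyn.Lif(10)                            # spike dtype -> bm.bool_

# The dtype can still be pinned explicitly through the renamed keyword.
lif_fixed = bp.dyn.Lif(10, spk_dtype=bm.float_)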
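Setting __array_priority__ = 100 on Array tells NumPy's ndarray operators to defer to Array's reflected methods when a plain ndarray sits on the left-hand side, so expressions like np.ones(1) + a come back as bm.Array rather than a raw ndarray; that is exactly what the new TestArrayPriority cases exercise. The sketch below illustrates the mechanism with a hypothetical stand-in class (Wrapped), not BrainPy's actual Array implementation.

import numpy as np

class Wrapped:
    # ndarray binary operators return NotImplemented when the other operand
    # is not an ndarray, defines the reflected method, and carries a higher
    # __array_priority__ than ndarray's default of 0.0, letting Python fall
    # back to that operand's __radd__/__rmul__.
    __array_priority__ = 100

    def __init__(self, value):
        self.value = np.asarray(value)

    def __add__(self, other):
        return Wrapped(self.value + np.asarray(other))

    __radd__ = __add__

    def __mul__(self, other):
        return Wrapped(self.value * np.asarray(other))

    __rmul__ = __mul__

a = Wrapped(np.zeros(10))
assert isinstance(a + np.ones(1), Wrapped)   # left operand handles it directly
assert isinstance(np.ones(1) + a, Wrapped)   # ndarray defers to a.__radd__
assert isinstance(np.ones(1) * a, Wrapped)   # ndarray defers to a.__rmul__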