From c489fff85faa9d83eb63f2dcb9917e153b00a29c Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 19 Dec 2024 10:43:35 +0800 Subject: [PATCH] updates --- brainstate/compile/_progress_bar.py | 2 + brainstate/event/_csr.py | 362 +++++++++++++++++++++- brainstate/graph/_graph_node.py | 5 +- brainstate/graph/_graph_operation.py | 14 +- brainstate/nn/_dyn_impl/_inputs.py | 129 +++++++- brainstate/nn/_dynamics/_dynamics_base.py | 10 + brainstate/nn/_interaction/_linear.py | 15 +- brainstate/random/_rand_funs.py | 10 +- pyproject.toml | 2 +- requirements.txt | 2 +- setup.py | 2 +- 11 files changed, 515 insertions(+), 38 deletions(-) diff --git a/brainstate/compile/_progress_bar.py b/brainstate/compile/_progress_bar.py index 8db3baa..a2a2430 100644 --- a/brainstate/compile/_progress_bar.py +++ b/brainstate/compile/_progress_bar.py @@ -35,6 +35,8 @@ class ProgressBar(object): def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs): self.print_freq = freq + if isinstance(freq, int): + assert freq > 0, "Print rate should be > 0." self.print_count = count if self.print_freq is not None and self.print_count is not None: raise ValueError("Cannot specify both count and freq.") diff --git a/brainstate/event/_csr.py b/brainstate/event/_csr.py index 67178d9..137b4a9 100644 --- a/brainstate/event/_csr.py +++ b/brainstate/event/_csr.py @@ -13,13 +13,21 @@ # limitations under the License. # ============================================================================== + +from __future__ import annotations + +import operator from typing import Callable import brainunit as u import jax import jax.numpy as jnp -from brainunit.sparse.csr import _csr_matvec as csr_matvec, _csr_matmat as csr_matmat -from brainunit.sparse.csr import _csr_to_coo as csr_to_coo +import numpy as np +from brainunit.sparse._csr import ( + _csr_matvec as csr_matvec, + _csr_matmat as csr_matmat, + _csr_to_coo as csr_to_coo +) from jax.experimental.sparse import JAXSparse from jax.interpreters import ad @@ -32,10 +40,148 @@ ] -class CSR(u.sparse.CSR): +@jax.tree_util.register_pytree_node_class +class CSR(u.sparse.SparseMatrix): """ - Event-driven sparse matrix in CSR format. + Event-driven and Unit-aware CSR matrix. """ + data: jax.Array | u.Quantity + indices: jax.Array + indptr: jax.Array + shape: tuple[int, int] + nse = property(lambda self: self.data.size) + dtype = property(lambda self: self.data.dtype) + _bufs = property(lambda self: (self.data, self.indices, self.indptr)) + + def __init__(self, args, *, shape): + self.data, self.indices, self.indptr = map(u.math.asarray, args) + super().__init__(args, shape=shape) + + @classmethod + def fromdense(cls, mat, *, nse=None, index_dtype=np.int32): + if nse is None: + nse = (u.get_mantissa(mat) != 0).sum() + return u.sparse.csr_fromdense(mat, nse=nse, index_dtype=index_dtype) + + def with_data(self, data: jax.Array | u.Quantity) -> CSR: + assert data.shape == self.data.shape + assert data.dtype == self.data.dtype + assert u.get_unit(data) == u.get_unit(self.data) + return CSR((data, self.indices, self.indptr), shape=self.shape) + + def todense(self): + return u.sparse.csr_todense(self) + + def transpose(self, axes=None): + assert axes is None + return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1]) + + def __abs__(self): + return CSR((abs(self.data), self.indices, self.indptr), shape=self.shape) + + def __neg__(self): + return CSR((-self.data, self.indices, self.indptr), shape=self.shape) + + def __pos__(self): + return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape) + + def _binary_op(self, other, op): + if isinstance(other, CSR): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSR( + (op(self.data, other.data), + self.indices, + self.indptr), + shape=self.shape + ) + if isinstance(other, JAXSparse): + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + + other = u.math.asarray(other) + if other.size == 1: + return CSR( + (op(self.data, other), self.indices, self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + rows, cols = csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSR( + (op(self.data, other), + self.indices, + self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def _binary_rop(self, other, op): + if isinstance(other, CSR): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSR( + (op(other.data, self.data), + self.indices, + self.indptr), + shape=self.shape + ) + if isinstance(other, JAXSparse): + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + + other = u.math.asarray(other) + if other.size == 1: + return CSR( + (op(other, self.data), + self.indices, + self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + rows, cols = csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSR( + (op(other, self.data), + self.indices, + self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def __mul__(self, other: jax.Array | u.Quantity) -> CSR: + return self._binary_op(other, operator.mul) + + def __rmul__(self, other: jax.Array | u.Quantity) -> CSR: + return self._binary_rop(other, operator.mul) + + def __div__(self, other: jax.Array | u.Quantity) -> CSR: + return self._binary_op(other, operator.truediv) + + def __rdiv__(self, other: jax.Array | u.Quantity) -> CSR: + return self._binary_rop(other, operator.truediv) + + def __truediv__(self, other) -> CSR: + return self.__div__(other) + + def __rtruediv__(self, other) -> CSR: + return self.__rdiv__(other) + + def __add__(self, other) -> CSR: + return self._binary_op(other, operator.add) + + def __radd__(self, other) -> CSR: + return self._binary_rop(other, operator.add) + + def __sub__(self, other) -> CSR: + return self._binary_op(other, operator.sub) + + def __rsub__(self, other) -> CSR: + return self._binary_rop(other, operator.sub) + + def __mod__(self, other) -> CSR: + return self._binary_op(other, operator.mod) + + def __rmod__(self, other) -> CSR: + return self._binary_rop(other, operator.mod) def __matmul__(self, other): if isinstance(other, JAXSparse): @@ -89,11 +235,177 @@ def __rmatmul__(self, other): else: raise NotImplementedError(f"matmul with object of shape {other.shape}") + def tree_flatten(self): + return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr} + + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, = children + if aux_data.keys() != {'shape', 'indices', 'indptr'}: + raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj -class CSC(u.sparse.CSC): + +@jax.tree_util.register_pytree_node_class +class CSC(u.sparse.SparseMatrix): """ - Event-driven sparse matrix in CSC format. + Event-driven and Unit-aware CSC matrix. """ + data: jax.Array + indices: jax.Array + indptr: jax.Array + shape: tuple[int, int] + nse = property(lambda self: self.data.size) + dtype = property(lambda self: self.data.dtype) + + def __init__(self, args, *, shape): + self.data, self.indices, self.indptr = map(u.math.asarray, args) + super().__init__(args, shape=shape) + + @classmethod + def fromdense(cls, mat, *, nse=None, index_dtype=np.int32): + if nse is None: + nse = (u.get_mantissa(mat) != 0).sum() + return u.sparse.csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T + + @classmethod + def _empty(cls, shape, *, dtype=None, index_dtype='int32'): + """Create an empty CSC instance. Public method is sparse.empty().""" + shape = tuple(shape) + if len(shape) != 2: + raise ValueError(f"CSC must have ndim=2; got {shape=}") + data = jnp.empty(0, dtype) + indices = jnp.empty(0, index_dtype) + indptr = jnp.zeros(shape[1] + 1, index_dtype) + return cls((data, indices, indptr), shape=shape) + + @classmethod + def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'): + return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T + + def with_data(self, data: jax.Array | u.Quantity) -> CSC: + assert data.shape == self.data.shape + assert data.dtype == self.data.dtype + assert u.get_unit(data) == u.get_unit(self.data) + return CSC((data, self.indices, self.indptr), shape=self.shape) + + def todense(self): + return u.sparse.csr_todense(self.T).T + + def transpose(self, axes=None): + assert axes is None + return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1]) + + def __abs__(self): + return CSC((abs(self.data), self.indices, self.indptr), shape=self.shape) + + def __neg__(self): + return CSC((-self.data, self.indices, self.indptr), shape=self.shape) + + def __pos__(self): + return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape) + + def _binary_op(self, other, op): + if isinstance(other, CSC): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSC( + (op(self.data, other.data), + self.indices, + self.indptr), + shape=self.shape + ) + if isinstance(other, JAXSparse): + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + + other = u.math.asarray(other) + if other.size == 1: + return CSC( + (op(self.data, other), + self.indices, + self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + cols, rows = csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSC( + (op(self.data, other), + self.indices, + self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def _binary_rop(self, other, op): + if isinstance(other, CSC): + if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr): + return CSC( + (op(other.data, self.data), + self.indices, + self.indptr), + shape=self.shape + ) + if isinstance(other, JAXSparse): + raise NotImplementedError(f"binary operation {op} between two sparse objects.") + + other = u.math.asarray(other) + if other.size == 1: + return CSC( + (op(other, self.data), + self.indices, + self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + cols, rows = csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSC( + (op(other, self.data), + self.indices, + self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def __mul__(self, other: jax.Array | u.Quantity) -> 'CSC': + return self._binary_op(other, operator.mul) + + def __rmul__(self, other: jax.Array | u.Quantity) -> 'CSC': + return self._binary_rop(other, operator.mul) + + def __div__(self, other: jax.Array | u.Quantity) -> CSC: + return self._binary_op(other, operator.truediv) + + def __rdiv__(self, other: jax.Array | u.Quantity) -> CSC: + return self._binary_rop(other, operator.truediv) + + def __truediv__(self, other) -> CSC: + return self.__div__(other) + + def __rtruediv__(self, other) -> CSC: + return self.__rdiv__(other) + + def __add__(self, other) -> CSC: + return self._binary_op(other, operator.add) + + def __radd__(self, other) -> CSC: + return self._binary_rop(other, operator.add) + + def __sub__(self, other) -> CSC: + return self._binary_op(other, operator.sub) + + def __rsub__(self, other) -> CSC: + return self._binary_rop(other, operator.sub) + + def __mod__(self, other) -> CSC: + return self._binary_op(other, operator.mod) + + def __rmod__(self, other) -> CSC: + return self._binary_rop(other, operator.mod) def __matmul__(self, other): if isinstance(other, JAXSparse): @@ -148,6 +460,18 @@ def __rmatmul__(self, other): else: raise NotImplementedError(f"matmul with object of shape {other.shape}") + def tree_flatten(self): + return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr} + + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, = children + if aux_data.keys() != {'shape', 'indices', 'indptr'}: + raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + def _csr_matvec( data: jax.Array | u.Quantity, @@ -177,7 +501,6 @@ def _csr_matvec( """ data, unitd = u.split_mantissa_unit(data) v, unitv = u.split_mantissa_unit(v) - # res = csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose) res = event_csrmv_p_call( data, indices, indptr, v, shape=shape, @@ -194,7 +517,8 @@ def _csr_matmat( B: jax.Array | u.Quantity, *, shape: Shape, - transpose: bool = False + transpose: bool = False, + float_as_event: bool = True, ) -> jax.Array | u.Quantity: """ Product of CSR sparse matrix and a dense matrix. @@ -215,7 +539,15 @@ def _csr_matmat( """ data, unitd = u.split_mantissa_unit(data) B, unitb = u.split_mantissa_unit(B) - res = csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose) + res = event_csrmm_p_call( + data, + indices, + indptr, + B, + shape=shape, + transpose=transpose, + float_as_event=float_as_event, + )[0] return u.maybe_decimal(res * (unitd * unitb)) @@ -483,9 +815,9 @@ def event_csrmv_p_call( indptr, v, *, - shape, - transpose, - float_as_event, + shape: Shape, + transpose: bool, + float_as_event: bool, ): if jax.default_backend() == 'cpu': return event_csrmv_p( @@ -541,9 +873,9 @@ def event_csrmm_p_call( indptr, B, *, - shape, - transpose, - float_as_event, + shape: Shape, + transpose: bool, + float_as_event: bool, ): if jax.default_backend() == 'cpu': return event_csrmm_p( diff --git a/brainstate/graph/_graph_node.py b/brainstate/graph/_graph_node.py index f882ef1..0b06395 100644 --- a/brainstate/graph/_graph_node.py +++ b/brainstate/graph/_graph_node.py @@ -173,8 +173,9 @@ def _to_shape_dtype(value): def _node_flatten( node: Node ) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[Type]]: - graph_invisible_attrs = getattr(node, 'graph_invisible_attrs', ()) - graph_invisible_attrs = tuple(graph_invisible_attrs) + ('_trace_state',) + # graph_invisible_attrs = getattr(node, 'graph_invisible_attrs', ()) + # graph_invisible_attrs = tuple(graph_invisible_attrs) + ('_trace_state',) + graph_invisible_attrs = ('_trace_state',) nodes = sorted( (key, value) for key, value in vars(node).items() if (key not in graph_invisible_attrs) diff --git a/brainstate/graph/_graph_operation.py b/brainstate/graph/_graph_operation.py index 554ba70..dcc95bb 100644 --- a/brainstate/graph/_graph_operation.py +++ b/brainstate/graph/_graph_operation.py @@ -608,9 +608,9 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache): if isinstance(value, TreefyState): variable.update_from_ref(value) elif isinstance(value, State): - if value._been_writen: + if value._been_writen: variable.write_value(value.value) - else: + else: variable.restore_value(value.value) else: raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.') @@ -1600,10 +1600,12 @@ def _iter_graph_leaf( visited_.add(id(node_)) node_dict = _get_node_impl(node_).node_dict(node_) for key, value in node_dict.items(): - yield from _iter_graph_leaf(value, - visited_, - (*path_parts_, key), - level_ + 1 if _is_graph_node(value) else level_) + yield from _iter_graph_leaf( + value, + visited_, + (*path_parts_, key), + level_ + 1 if _is_graph_node(value) else level_ + ) else: if level_ >= allowed_hierarchy[0]: yield path_parts_, node_ diff --git a/brainstate/nn/_dyn_impl/_inputs.py b/brainstate/nn/_dyn_impl/_inputs.py index b98c46d..6e882a1 100644 --- a/brainstate/nn/_dyn_impl/_inputs.py +++ b/brainstate/nn/_dyn_impl/_inputs.py @@ -17,17 +17,23 @@ from typing import Union, Optional, Sequence, Callable import brainunit as u +import jax +import numpy as np from brainstate import environ, init, random from brainstate._state import ShortTermState -from brainstate.compile import while_loop -from brainstate.nn._dynamics._dynamics_base import Dynamics +from brainstate._state import State +from brainstate.compile import while_loop, cond +from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch +from brainstate.nn._module import Module from brainstate.typing import ArrayLike, Size, DTypeLike __all__ = [ 'SpikeTime', 'PoissonSpike', 'PoissonEncoder', + 'PoissonInput', + 'poisson_input', ] @@ -152,3 +158,122 @@ def update(self, freqs: ArrayLike): spikes = random.rand(*self.varshape) <= (freqs * environ.get_dt()) spikes = u.math.asarray(spikes, dtype=self.spk_type) return spikes + + +class PoissonInput(Module): + """ + Poisson Input to the given :py:class:`brainstate.State`. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. + + Args: + target: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. + num_input: The number of inputs. + freq: The frequency of each of the inputs. Must be a scalar. + weight: The synaptic weight. Must be a scalar. + name: The target name. + """ + + def __init__( + self, + target: Prefetch, + indices: Union[np.ndarray, jax.Array], + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + name: Optional[str] = None, + ): + super().__init__(name=name) + + self.target = target + self.indices = indices + self.num_input = num_input + self.freq = freq + self.weight = weight + + def update(self): + p = self.freq * environ.get_dt() + a = self.num_input * p + b = self.num_input * (1 - p) + + target = self.target() + target_state = getattr(self.target.module, self.target.item) + + # generate Poisson input + inp = cond( + u.math.logical_and(a > 5, b > 5), + lambda: random.normal(a, b * p, self.indices.shape), + lambda: random.binomial(self.num_input, p, self.indices.shape).astype(float) + ) + + # update target variable + target_state.value = target.at[self.indices].add(inp * self.weight) + + +def poisson_input( + freq: ArrayLike, + num_input: int, + weight: ArrayLike, + target: State, + indices: Optional[Union[np.ndarray, jax.Array]] = None, +): + """ + Poisson Input to the given :py:class:`brainstate.State`. + """ + assert isinstance(target, State), 'The target must be a State.' + p = freq * environ.get_dt() + a = num_input * p + b = num_input * (1 - p) + tar_val = target.value + if indices is None: + # generate Poisson input + inp = cond( + u.math.logical_and(a > 5, b > 5), + lambda: jax.tree.map( + lambda tar: random.normal(a, b * p, tar.shape), + tar_val, + is_leaf=u.math.is_quantity + ), + lambda: jax.tree.map( + lambda tar: random.binomial(num_input, p, tar.shape).astype(float), + tar_val, + is_leaf=u.math.is_quantity + ) + ) + + # update target variable + target.value = jax.tree.map( + lambda x: x * weight, + inp, + is_leaf=u.math.is_quantity + ) + + else: + # generate Poisson input + inp = cond( + u.math.logical_and(a > 5, b > 5), + lambda: jax.tree.map( + lambda tar: random.normal(a, b * p, tar[indices].shape), + tar_val, + is_leaf=u.math.is_quantity + ), + lambda: jax.tree.map( + lambda tar: random.binomial(num_input, p, tar[indices].shape).astype(float), + tar_val, + is_leaf=u.math.is_quantity + ) + ) + + # update target variable + target.value = jax.tree.map( + lambda x, tar: tar.at[indices].add(x * weight), + inp, + tar_val, + is_leaf=u.math.is_quantity + ) diff --git a/brainstate/nn/_dynamics/_dynamics_base.py b/brainstate/nn/_dynamics/_dynamics_base.py index 2c181d3..158a737 100644 --- a/brainstate/nn/_dynamics/_dynamics_base.py +++ b/brainstate/nn/_dynamics/_dynamics_base.py @@ -445,6 +445,16 @@ def __call__(self, *args, **kwargs): item = _get_prefetch_item(self) return item.value if isinstance(item, State) else item + def get_item_value(self): + item = _get_prefetch_item(self) + return item.value if isinstance(item, State) else item + + def get_item(self): + """ + Get + """ + return _get_prefetch_item(self) + class PrefetchDelay(Node): def __init__(self, module: Dynamics, item: str): diff --git a/brainstate/nn/_interaction/_linear.py b/brainstate/nn/_interaction/_linear.py index d897ebb..9cc996f 100644 --- a/brainstate/nn/_interaction/_linear.py +++ b/brainstate/nn/_interaction/_linear.py @@ -199,7 +199,7 @@ class SparseLinear(Module): ``brainunit.sparse.CSC``, ``brainunit.sparse.COO``, or any other sparse matrix). Args: - weight: SparseMatrix. The sparse weight matrix. + spar_mat: SparseMatrix. The sparse weight matrix. in_size: Size. The input size. name: str. The object name. """ @@ -207,7 +207,7 @@ class SparseLinear(Module): def __init__( self, - weight: u.sparse.SparseMatrix, + spar_mat: u.sparse.SparseMatrix, b_init: Optional[Union[Callable, ArrayLike]] = None, in_size: Size = None, name: Optional[str] = None, @@ -217,7 +217,7 @@ def __init__( # input and output shape if in_size is not None: self.in_size = in_size - self.out_size = weight.shape[-1] + self.out_size = spar_mat.shape[-1] if in_size is not None: assert self.in_size[:-1] == self.out_size[:-1], ( 'The first n-1 dimensions of "in_size" ' @@ -225,15 +225,16 @@ def __init__( ) # weights - assert isinstance(weight, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.' - params = dict(weight=weight) + assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.' + self.spar_mat = spar_mat + params = dict(weight=spar_mat.data) if b_init is not None: params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False) self.weight = ParamState(params) def update(self, x): - weight = self.weight.value['weight'] - y = x @ weight + data = self.weight.value['weight'] + y = x @ self.spar_mat.with_data(data) if 'bias' in self.weight.value: y = y + self.weight.value['bias'] return y diff --git a/brainstate/random/_rand_funs.py b/brainstate/random/_rand_funs.py index 3749991..7738f54 100644 --- a/brainstate/random/_rand_funs.py +++ b/brainstate/random/_rand_funs.py @@ -959,9 +959,13 @@ def logistic(loc=None, scale=None, size: Optional[Size] = None, return DEFAULT.logistic(loc, scale, size, key=key, dtype=dtype) -def normal(loc=None, scale=None, size: Optional[Size] = None, - key: Optional[SeedOrKey] = None, - dtype: DTypeLike = None): +def normal( + loc=None, + scale=None, + size: Optional[Size] = None, + key: Optional[SeedOrKey] = None, + dtype: DTypeLike = None +): r""" Draw random samples from a normal (Gaussian) distribution. diff --git a/pyproject.toml b/pyproject.toml index 6546884..8149b36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ 'jax', 'jaxlib', 'numpy', - 'brainunit>=0.0.3.post20241214', + 'brainunit>=0.0.4', ] dynamic = ['version'] diff --git a/requirements.txt b/requirements.txt index 666390b..a37929a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy jax jaxlib -brainunit>=0.0.3 +brainunit>=0.0.4 diff --git a/setup.py b/setup.py index 6c8d6a8..48648d1 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.9', - install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainunit>=0.0.3.post20241214'], + install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainunit>=0.0.4'], url='https://github.com/chaobrain/brainstate', project_urls={ "Bug Tracker": "https://github.com/chaobrain/brainstate/issues",