diff --git a/README.md b/README.md index d56e18a..53d76df 100644 --- a/README.md +++ b/README.md @@ -37,12 +37,14 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt ## See also the BDP ecosystem -- [``brainpy``](https://github.com/brainpy/BrainPy): The solution for the general-purpose brain dynamics programming. +- [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming. -- [``brainstate``](https://github.com/brainpy/brainstate): The core system for the next generation of BrainPy framework. +- [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming. -- [``braintools``](https://github.com/brainpy/braintools): The tools for the brain dynamics simulation and analysis. +- [``braintaichi``](https://github.com/brainpy/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators. -- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning for biological spiking neural networks. +- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks. + +- [``braintools``](https://github.com/brainpy/braintools): The toolbox for the brain dynamics simulation, training and analysis. diff --git a/brainstate/environ.py b/brainstate/environ.py index 0047c3a..60f011e 100644 --- a/brainstate/environ.py +++ b/brainstate/environ.py @@ -18,7 +18,8 @@ __all__ = [ 'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform', - 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision', + 'get_host_device_count', 'get_platform', + 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision', 'tolerance', 'dftype', 'ditype', 'dutype', 'dctype', ] diff --git a/brainstate/functional/_activations.py b/brainstate/functional/_activations.py index bafe5c6..39fa118 100644 --- a/brainstate/functional/_activations.py +++ b/brainstate/functional/_activations.py @@ -27,7 +27,7 @@ from jax.scipy.special import logsumexp from jax.typing import ArrayLike -from brainstate import math, random +from .. import math, random __all__ = [ "tanh", diff --git a/brainstate/functional/_normalization.py b/brainstate/functional/_normalization.py index 39b23e9..7c52610 100644 --- a/brainstate/functional/_normalization.py +++ b/brainstate/functional/_normalization.py @@ -20,11 +20,14 @@ import jax import jax.numpy as jnp +from .._utils import set_module_as + __all__ = [ 'weight_standardization', ] +@set_module_as('brainstate.functional') def weight_standardization( w: jax.typing.ArrayLike, eps: float = 1e-4, diff --git a/brainstate/optim/_lr_scheduler_test.py b/brainstate/optim/_lr_scheduler_test.py index 01a9988..10cdbed 100644 --- a/brainstate/optim/_lr_scheduler_test.py +++ b/brainstate/optim/_lr_scheduler_test.py @@ -34,3 +34,16 @@ def test1(self): self.assertTrue(jnp.allclose(r, 0.001)) else: self.assertTrue(jnp.allclose(r, 0.0001)) + + def test2(self): + lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)) + for i in range(40): + r = lr(i) + if i < 10: + self.assertEqual(r, 0.1) + elif i < 20: + self.assertTrue(jnp.allclose(r, 0.01)) + elif i < 30: + self.assertTrue(jnp.allclose(r, 0.001)) + else: + self.assertTrue(jnp.allclose(r, 0.0001)) diff --git a/brainstate/transform/_jit.py b/brainstate/transform/_jit.py index 3330eaf..7573309 100644 --- a/brainstate/transform/_jit.py +++ b/brainstate/transform/_jit.py @@ -23,8 +23,8 @@ from jax._src import sharding_impls from jax.lib import xla_client as xc -from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values from brainstate._utils import set_module_as +from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values __all__ = ['jit'] @@ -33,10 +33,13 @@ class JittedFunction(Callable): """ A wrapped version of ``fun``, set up for just-in-time compilation. """ - origin_fun: Callable # the original function + origin_fun: Callable # the original function stateful_fun: StatefulFunction # the stateful function for extracting states jitted_fun: jax.stages.Wrapped # the jitted function - clear_cache: Callable # clear the cache of the jitted function + clear_cache: Callable # clear the cache of the jitted function + + def __call__(self, *args, **kwargs): + pass def _get_jitted_fun( @@ -85,12 +88,16 @@ def clear_cache(): jit_fun.clear_cache() jitted_fun: JittedFunction + # the original function jitted_fun.origin_fun = fun.fun + # the stateful function for extracting states jitted_fun.stateful_fun = fun + # the jitted function jitted_fun.jitted_fun = jit_fun + # clear cache jitted_fun.clear_cache = clear_cache @@ -99,18 +106,18 @@ def clear_cache(): @set_module_as('brainstate.transform') def jit( - fun: Callable = None, - in_shardings=sharding_impls.UNSPECIFIED, - out_shardings=sharding_impls.UNSPECIFIED, - static_argnums: int | Sequence[int] | None = None, - donate_argnums: int | Sequence[int] | None = None, - donate_argnames: str | Iterable[str] | None = None, - keep_unused: bool = False, - device: xc.Device | None = None, - backend: str | None = None, - inline: bool = False, - abstracted_axes: Any | None = None, - **kwargs + fun: Callable = None, + in_shardings=sharding_impls.UNSPECIFIED, + out_shardings=sharding_impls.UNSPECIFIED, + static_argnums: int | Sequence[int] | None = None, + donate_argnums: int | Sequence[int] | None = None, + donate_argnames: str | Iterable[str] | None = None, + keep_unused: bool = False, + device: xc.Device | None = None, + backend: str | None = None, + inline: bool = False, + abstracted_axes: Any | None = None, + **kwargs ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]: """ Sets up ``fun`` for just-in-time compilation with XLA. @@ -228,12 +235,31 @@ def jit( if fun is None: def wrapper(fun_again: Callable) -> JittedFunction: - return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums, - donate_argnums, donate_argnames, keep_unused, - device, backend, inline, abstracted_axes, **kwargs) + return _get_jitted_fun(fun_again, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + donate_argnames, + keep_unused, + device, + backend, + inline, + abstracted_axes, + **kwargs) + return wrapper else: - return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums, - donate_argnums, donate_argnames, keep_unused, - device, backend, inline, abstracted_axes, **kwargs) + return _get_jitted_fun(fun, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + donate_argnames, + keep_unused, + device, + backend, + inline, + abstracted_axes, + **kwargs) diff --git a/docs/_static/braincore.jpg b/docs/_static/braincore.jpg deleted file mode 100644 index 6eca184..0000000 Binary files a/docs/_static/braincore.jpg and /dev/null differ diff --git a/docs/api.rst b/docs/api.rst index 84d4b8f..4fb71a9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -5,10 +5,14 @@ API Documentation :maxdepth: 1 apis/changelog.md - apis/braincore.rst + apis/brainstate.rst + apis/init.rst apis/math.rst apis/mixin.rst + apis/optim.rst apis/random.rst apis/surrogate.rst apis/transform.rst apis/util.rst + apis/brainstate.nn.rst + apis/brainstate.functional.rst diff --git a/docs/apis/brainstate.functional.rst b/docs/apis/brainstate.functional.rst new file mode 100644 index 0000000..f8fd9c9 --- /dev/null +++ b/docs/apis/brainstate.functional.rst @@ -0,0 +1,76 @@ +``brainstate.functional`` module +================================ + +.. currentmodule:: brainstate.functional +.. automodule:: brainstate.functional + +Activation Functions +-------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + tanh + relu + squareplus + softplus + soft_sign + sigmoid + silu + swish + log_sigmoid + elu + leaky_relu + hard_tanh + celu + selu + gelu + glu + logsumexp + log_softmax + softmax + standardize + one_hot + relu6 + hard_sigmoid + hard_silu + hard_swish + hard_shrink + rrelu + mish + soft_shrink + prelu + tanh_shrink + softmin + + +Normalization Functions +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + weight_standardization + + +Spiking Operations +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + spike_bitwise_or + spike_bitwise_and + spike_bitwise_iand + spike_bitwise_not + spike_bitwise_xor + spike_bitwise_ixor + spike_bitwise + + diff --git a/docs/apis/brainstate.nn.rst b/docs/apis/brainstate.nn.rst new file mode 100644 index 0000000..72323e1 --- /dev/null +++ b/docs/apis/brainstate.nn.rst @@ -0,0 +1,215 @@ +``brainstate.nn`` module +======================== + +.. currentmodule:: brainstate.nn +.. automodule:: brainstate.nn + +Base Classes +------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ExplicitInOutSize + ElementWiseBlock + Sequential + DnnLayer + + +Synaptic Projections +-------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + HalfProjAlignPostMg + FullProjAlignPostMg + HalfProjAlignPost + FullProjAlignPost + FullProjAlignPreSDMg + FullProjAlignPreDSMg + FullProjAlignPreSD + FullProjAlignPreDS + HalfProjDelta + FullProjDelta + VanillaProj + + +Connection Layers +----------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + Linear + ScaledWSLinear + SignedWLinear + CSRLinear + Conv1d + Conv2d + Conv3d + ScaledWSConv1d + ScaledWSConv2d + ScaledWSConv3d + + +Neuronal/Synaptic Dynamics +-------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + Neuron + IF + LIF + ALIF + Synapse + Expon + STP + STD + + +Rate RNNs +--------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + RNNCell + ValinaRNNCell + GRUCell + MGUCell + LSTMCell + URLSTMCell + + +Readout Layers +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + LeakyRateReadout + LeakySpikeReadout + + +Synaptic Outputs +---------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + SynOut + COBA + CUBA + MgBlock + + +Element-wise Layers +------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + Threshold + ReLU + RReLU + Hardtanh + ReLU6 + Sigmoid + Hardsigmoid + Tanh + SiLU + Mish + Hardswish + ELU + CELU + SELU + GLU + GELU + Hardshrink + LeakyReLU + LogSigmoid + Softplus + Softshrink + PReLU + Softsign + Tanhshrink + Softmin + Softmax + Softmax2d + LogSoftmax + Dropout + Dropout1d + Dropout2d + Dropout3d + AlphaDropout + FeatureAlphaDropout + Identity + SpikeBitwise + + +Normalization Layers +-------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + BatchNorm1d + BatchNorm2d + BatchNorm3d + + +Pooling Layers +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + Flatten + Unflatten + AvgPool1d + AvgPool2d + AvgPool3d + MaxPool1d + MaxPool2d + MaxPool3d + AdaptiveAvgPool1d + AdaptiveAvgPool2d + AdaptiveAvgPool3d + AdaptiveMaxPool1d + AdaptiveMaxPool2d + AdaptiveMaxPool3d + + +Other Layers +------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + DropoutFixed + + diff --git a/docs/apis/braincore.rst b/docs/apis/brainstate.rst similarity index 92% rename from docs/apis/braincore.rst rename to docs/apis/brainstate.rst index e027fe6..f990bc2 100644 --- a/docs/apis/braincore.rst +++ b/docs/apis/brainstate.rst @@ -1,8 +1,8 @@ -``braincore`` module -==================== +``brainstate`` module +===================== -.. currentmodule:: braincore -.. automodule:: braincore +.. currentmodule:: brainstate +.. automodule:: brainstate ``State`` System ---------------- diff --git a/docs/apis/init.rst b/docs/apis/init.rst new file mode 100644 index 0000000..cf4398f --- /dev/null +++ b/docs/apis/init.rst @@ -0,0 +1,29 @@ +``brainstate.init`` module +========================== + +.. currentmodule:: brainstate.init +.. automodule:: brainstate.init + +.. autosummary:: + :toctree: generated/ + + param + state + noise + to_size + Initializer + ZeroInit + Constant + Identity + Normal + TruncatedNormal + Uniform + VarianceScaling + KaimingUniform + KaimingNormal + XavierUniform + XavierNormal + LecunUniform + LecunNormal + Orthogonal + DeltaOrthogonal diff --git a/docs/apis/math.rst b/docs/apis/math.rst index 35ca35c..eaa654b 100644 --- a/docs/apis/math.rst +++ b/docs/apis/math.rst @@ -1,8 +1,8 @@ -``braincore.math`` module -========================= +``brainstate.math`` module +========================== -.. currentmodule:: braincore.math -.. automodule:: braincore.math +.. currentmodule:: brainstate.math +.. automodule:: brainstate.math .. autosummary:: :toctree: generated/ @@ -19,7 +19,7 @@ as_numpy tree_zeros_like tree_ones_like - ein_reduce - ein_rearrange - ein_repeat - ein_shape + einreduce + einrearrange + einrepeat + einshape diff --git a/docs/apis/mixin.rst b/docs/apis/mixin.rst index 1be63a2..d8b30a6 100644 --- a/docs/apis/mixin.rst +++ b/docs/apis/mixin.rst @@ -1,8 +1,8 @@ -``braincore.mixin`` module -========================== +``brainstate.mixin`` module +=========================== -.. currentmodule:: braincore.mixin -.. automodule:: braincore.mixin +.. currentmodule:: brainstate.mixin +.. automodule:: brainstate.mixin .. autosummary:: :toctree: generated/ diff --git a/docs/apis/optim.rst b/docs/apis/optim.rst new file mode 100644 index 0000000..191d1f2 --- /dev/null +++ b/docs/apis/optim.rst @@ -0,0 +1,33 @@ +``brainstate.optim`` module +=========================== + +.. currentmodule:: brainstate.optim +.. automodule:: brainstate.optim + +.. autosummary:: + :toctree: generated/ + + to_same_dict_tree + LearningRateScheduler + ConstantLR + StepLR + MultiStepLR + CosineAnnealingLR + CosineAnnealingWarmRestarts + ExponentialLR + ExponentialDecayLR + InverseTimeDecayLR + PolynomialDecayLR + PiecewiseConstantLR + OptimState + Optimizer + SGD + Momentum + MomentumNesterov + Adagrad + Adadelta + RMSProp + Adam + LARS + Adan + AdamW diff --git a/docs/apis/random.rst b/docs/apis/random.rst index 8b26b8d..ee99c67 100644 --- a/docs/apis/random.rst +++ b/docs/apis/random.rst @@ -1,8 +1,8 @@ -``braincore.random`` module -=========================== +``brainstate.random`` module +============================ -.. currentmodule:: braincore.random -.. automodule:: braincore.random +.. currentmodule:: brainstate.random +.. automodule:: brainstate.random Random Number Generators diff --git a/docs/apis/surrogate.rst b/docs/apis/surrogate.rst index 0148b78..d1255b4 100644 --- a/docs/apis/surrogate.rst +++ b/docs/apis/surrogate.rst @@ -1,8 +1,8 @@ -``braincore.surrogate`` module -============================== +``brainstate.surrogate`` module +=============================== -.. currentmodule:: braincore.surrogate -.. automodule:: braincore.surrogate +.. currentmodule:: brainstate.surrogate +.. automodule:: brainstate.surrogate Surrogate Gradient Functions diff --git a/docs/apis/transform.rst b/docs/apis/transform.rst index 045ef23..7e8fe10 100644 --- a/docs/apis/transform.rst +++ b/docs/apis/transform.rst @@ -1,8 +1,8 @@ -``braincore.transform`` module -============================== +``brainstate.transform`` module +=============================== -.. currentmodule:: braincore.transform -.. automodule:: braincore.transform +.. currentmodule:: brainstate.transform +.. automodule:: brainstate.transform diff --git a/docs/apis/util.rst b/docs/apis/util.rst index c9dfa2b..0014912 100644 --- a/docs/apis/util.rst +++ b/docs/apis/util.rst @@ -1,8 +1,8 @@ -``braincore.util`` module -========================= +``brainstate.util`` module +========================== -.. currentmodule:: braincore.util -.. automodule:: braincore.util +.. currentmodule:: brainstate.util +.. automodule:: brainstate.util .. autosummary:: :toctree: generated/ diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 5d19805..0dface5 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -449,38 +449,75 @@ def generate_algorithm_docs(path='apis/auto/algorithms/'): def main(): os.makedirs('apis/auto/', exist_ok=True) - _write_module(module_name='brainstate.surrogate', - filename='apis/auto/surrogate.rst', - header='``brainstate.surrogate`` module') + # _write_module(module_name='brainstate.surrogate', + # filename='apis/auto/surrogate.rst', + # header='``brainstate.surrogate`` module') + + # _write_module(module_name='brainstate.random', + # filename='apis/auto/random.rst', + # header='``brainstate.random`` module') + + # _write_module(module_name='brainstate.mixin', + # filename='apis/auto/mixin.rst', + # header='``brainstate.mixin`` module') + + # _write_module(module_name='brainstate.transform', + # filename='apis/auto/transform.rst', + # header='``brainstate.transform`` module') + + # _write_module(module_name='brainstate.math', + # filename='apis/auto/math.rst', + # header='``brainstate.math`` module') + + # _write_module(module_name='brainstate.util', + # filename='apis/auto/util.rst', + # header='``brainstate.util`` module') + + _write_module(module_name='brainstate.optim', + filename='apis/optim.rst', + header='``brainstate.optim`` module') + + _write_module(module_name='brainstate.init', + filename='apis/init.rst', + header='``brainstate.init`` module') + + # module_and_name = [ + # ('_state', 'State System'), + # ('_module', 'Module & Container'), + # ] + # _write_submodules(module_name='brainstate', + # filename='apis/auto/brainstate.rst', + # header='``brainstate`` module', + # submodule_names=[k[0] for k in module_and_name], + # section_names=[k[1] for k in module_and_name]) - _write_module(module_name='brainstate.random', - filename='apis/auto/random.rst', - header='``brainstate.random`` module') - - _write_module(module_name='brainstate.mixin', - filename='apis/auto/mixin.rst', - header='``brainstate.mixin`` module') - - _write_module(module_name='brainstate.transform', - filename='apis/auto/transform.rst', - header='``brainstate.transform`` module') - - _write_module(module_name='brainstate.math', - filename='apis/auto/math.rst', - header='``brainstate.math`` module') - - _write_module(module_name='brainstate.util', - filename='apis/auto/util.rst', - header='``brainstate.util`` module') + module_and_name = [ + ('_activations', 'Activation Functions'), + ('_normalization', 'Normalization Functions'), + ('_spikes', 'Spiking Operations'), + ] + _write_submodules(module_name='brainstate.functional', + filename='apis/brainstate.functional.rst', + header='``brainstate.functional`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) module_and_name = [ - ('_state', 'State System'), - ('_module', 'Module & Container'), + ('_base', 'Base Classes'), ('_projection', 'Synaptic Projections'), + ('_connections', 'Connection Layers'), + ('_dynamics', 'Neuronal/Synaptic Dynamics'), + ('_rate_rnns', 'Rate RNNs'), + ('_readout', 'Readout Layers'), + ('_synouts', 'Synaptic Outputs'), + ('_elementwise', 'Element-wise Layers'), + ('_normalizations', 'Normalization Layers'), + ('_poolings', 'Pooling Layers'), + ('_others', 'Other Layers'), ] - _write_submodules(module_name='brainstate', - filename='apis/auto/brainstate.rst', - header='``brainstate`` module', + _write_submodules(module_name='brainstate.nn', + filename='apis/brainstate.nn.rst', + header='``brainstate.nn`` module', submodule_names=[k[0] for k in module_and_name], section_names=[k[1] for k in module_and_name]) diff --git a/docs/index.rst b/docs/index.rst index 2da5fb8..264c261 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,7 +33,7 @@ Installation .. code-block:: bash - pip install -U brainstate[tpu] + pip install -U brainstate[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ---- @@ -43,13 +43,15 @@ See also the BDP ecosystem ^^^^^^^^^^^^^^^^^^^^^^^^^^ -- `brainpy `_: The solution for the general-purpose brain dynamics programming. +- `brainstate `_: A ``State``-based transformation system for brain dynamics programming. -- `brainstate `_: The ``State``-based transformation system for brain dynamics programming. +- `brainunit `_: The unit system for brain dynamics programming. -- `braintools `_: The tools for the brain dynamics simulation and analysis. +- `braintaichi `_: Leveraging Taichi Lang to customize brain dynamics operators. -- `brainscale `_: The scalable online learning for biological spiking neural networks. +- `brainscale `_: The scalable online learning framework for biological neural networks. + +- `braintools `_: The toolbox for the brain dynamics simulation, training and analysis.