diff --git a/brainstate/__init__.py b/brainstate/__init__.py index b781a9c..d64fdae 100644 --- a/brainstate/__init__.py +++ b/brainstate/__init__.py @@ -17,7 +17,7 @@ A ``State``-based Transformation System for Brain Dynamics Programming """ -__version__ = "0.0.1" +__version__ = "0.0.1.1" from . import environ from . import functional diff --git a/brainstate/_module.py b/brainstate/_module.py index 321768d..02ea235 100644 --- a/brainstate/_module.py +++ b/brainstate/_module.py @@ -61,11 +61,9 @@ from ._utils import set_module_as from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn from .transform import jit_error +from .typing import Size, ArrayLike, PyTree from .util import unique_name, DictManager, get_unique_name -Shape = Union[int, Sequence[int]] -PyTree = Any -ArrayLike = jax.typing.ArrayLike delay_identifier = '_*_delay_of_' _DELAY_ROTATE = 'rotation' @@ -805,7 +803,7 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn): def __init__( self, - size: Shape, + size: Size, keep_size: bool = False, name: Optional[str] = None, mode: Optional[Mode] = None, @@ -1275,25 +1273,25 @@ def _check_delay(t_now, t_delay): if self.interp_method == _INTERP_LINEAR: # "linear" interpolation # def _interp(target): - # if len(indices) > 0: - # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.') - # if self.delay_method == _DELAY_ROTATE: - # i = environ.get(environ.I, desc='The time step index.') - # _interp_fun = partial(jnp.interp, period=self.max_length) - # for dim in range(1, target.ndim, 1): - # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) - # di = i - jnp.arange(self.max_length) - # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32) - # return _interp_fun(float_time_step, delay_idx, target) - # - # elif self.delay_method == _DELAY_CONCAT: - # _interp_fun = partial(jnp.interp, period=self.max_length) - # for dim in range(1, target.ndim, 1): - # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) - # return _interp_fun(float_time_step, jnp.arange(self.max_length), target) - # - # else: - # raise ValueError(f'Unknown delay updating method "{self.delay_method}"') + # if len(indices) > 0: + # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.') + # if self.delay_method == _DELAY_ROTATE: + # i = environ.get(environ.I, desc='The time step index.') + # _interp_fun = partial(jnp.interp, period=self.max_length) + # for dim in range(1, target.ndim, 1): + # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) + # di = i - jnp.arange(self.max_length) + # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32) + # return _interp_fun(float_time_step, delay_idx, target) + # + # elif self.delay_method == _DELAY_CONCAT: + # _interp_fun = partial(jnp.interp, period=self.max_length) + # for dim in range(1, target.ndim, 1): + # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) + # return _interp_fun(float_time_step, jnp.arange(self.max_length), target) + # + # else: + # raise ValueError(f'Unknown delay updating method "{self.delay_method}"') # return jax.tree.map(_interp, self.history.value) data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices) diff --git a/brainstate/_state.py b/brainstate/_state.py index 992e19c..92618e0 100644 --- a/brainstate/_state.py +++ b/brainstate/_state.py @@ -15,18 +15,16 @@ import contextlib import threading -from typing import Any, Tuple, Dict, List, Callable +from typing import Any, Tuple, Dict, List, Callable, Optional import jax import numpy as np from jax.api_util import shaped_abstractify from jax.extend import source_info_util +from .typing import ArrayLike, PyTree from .util import DictManager -PyTree = Any -max_int = np.iinfo(np.int32) - __all__ = [ 'State', 'ShortTermState', 'LongTermState', 'ParamState', 'StateDictManager', @@ -36,6 +34,7 @@ ] _pytree_registered_objects = set() +max_int = np.iinfo(np.int32) def _register_pytree_cls(cls): @@ -108,9 +107,9 @@ class MyState(State): value: PyTree. It can be anything as a pyTree. """ __module__ = 'brainstate' - __slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree') + __slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree') - def __init__(self, value: PyTree): + def __init__(self, value: PyTree[ArrayLike], name: Optional[str] = None): if isinstance(value, State): value = value.value self._value = value @@ -118,9 +117,24 @@ def __init__(self, value: PyTree): self._check_tree = False self._level = len(thread_local_stack.stack) self._source_info = source_info_util.current() + self._name = name + + @property + def name(self) -> Optional[str]: + """ + The name of the state. + """ + return self._name + + @name.setter + def name(self, name: str) -> None: + """ + Set the name of the state. + """ + self._name = name @property - def value(self) -> PyTree: + def value(self) -> PyTree[ArrayLike]: """ The data and its value. """ @@ -210,7 +224,10 @@ def __repr__(self): leaves, tree = jax.tree.flatten(self._value) leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves] tree_info = jax.tree.unflatten(tree, leaves_info) - return f'{self.__class__.__name__}({tree_info})' + if self.name is None: + return f'{self.__class__.__name__}({tree_info})' + else: + return f'{self.__class__.__name__}({self.name}: {tree_info})' class ShapeDtype: diff --git a/brainstate/functional/_activations.py b/brainstate/functional/_activations.py index d421f44..63ff6bb 100644 --- a/brainstate/functional/_activations.py +++ b/brainstate/functional/_activations.py @@ -25,8 +25,8 @@ import jax import jax.numpy as jnp from jax.scipy.special import logsumexp -from jax.typing import ArrayLike +from brainstate.typing import ArrayLike from .. import random __all__ = [ diff --git a/brainstate/functional/_others.py b/brainstate/functional/_others.py index 5ca52ce..e958ba8 100644 --- a/brainstate/functional/_others.py +++ b/brainstate/functional/_others.py @@ -16,12 +16,11 @@ from __future__ import annotations from functools import partial -from typing import Any import jax import jax.numpy as jnp -PyTree = Any +from brainstate.typing import PyTree __all__ = [ 'clip_grad_norm', diff --git a/brainstate/init/_generic.py b/brainstate/init/_generic.py index f753833..c81b59a 100644 --- a/brainstate/init/_generic.py +++ b/brainstate/init/_generic.py @@ -22,8 +22,8 @@ import numpy as np from brainstate._state import State +from brainstate.typing import ArrayLike from ._base import to_size -from ..typing import ArrayLike __all__ = [ 'param', @@ -83,7 +83,7 @@ def _expand_params_to_match_sizes(params, sizes): def param( - parameter: Union[Callable, ArrayLike], + parameter: Union[Callable, ArrayLike, State], sizes: Union[int, Sequence[int]], batch_size: Optional[int] = None, allow_none: bool = True, diff --git a/brainstate/init/_random_inits.py b/brainstate/init/_random_inits.py index 762aa0e..90f598e 100644 --- a/brainstate/init/_random_inits.py +++ b/brainstate/init/_random_inits.py @@ -22,8 +22,8 @@ import numpy as np from brainstate import environ, random +from brainstate.typing import ArrayLike from ._base import Initializer, to_size -from ..typing import ArrayLike __all__ = [ 'Normal', diff --git a/brainstate/mixin.py b/brainstate/mixin.py index 76ffada..8ab6b41 100644 --- a/brainstate/mixin.py +++ b/brainstate/mixin.py @@ -15,14 +15,15 @@ # -*- coding: utf-8 -*- -from typing import Sequence, Optional, TypeVar, Any +from typing import Sequence, Optional, TypeVar from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias) -T = TypeVar('T') -PyTree = Any +from .typing import PyTree +T = TypeVar('T') State = None + __all__ = [ 'Mixin', 'DelayedInit', @@ -207,7 +208,7 @@ def __subclasscheck__(self, subclass): @_SpecialForm def JointTypes(self, parameters): - """All of types; AllOfTypes[X, Y] means both X and Y. + """Joint types; JointTypes[X, Y] means both X and Y. To define a union, use e.g. Union[int, str]. @@ -216,28 +217,28 @@ def JointTypes(self, parameters): - None as an argument is a special case and is replaced by `type(None)`. - Unions of unions are flattened, e.g.:: - AllOfTypes[AllOfTypes[int, str], float] == AllOfTypes[int, str, float] + JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float] - Unions of a single argument vanish, e.g.:: - AllOfTypes[int] == int # The constructor actually returns int + JointTypes[int] == int # The constructor actually returns int - Redundant arguments are skipped, e.g.:: - AllOfTypes[int, str, int] == AllOfTypes[int, str] + JointTypes[int, str, int] == JointTypes[int, str] - When comparing unions, the argument order is ignored, e.g.:: - AllOfTypes[int, str] == AllOfTypes[str, int] + JointTypes[int, str] == JointTypes[str, int] - - You cannot subclass or instantiate a AllOfTypes. - - You can use Optional[X] as a shorthand for AllOfTypes[X, None]. + - You cannot subclass or instantiate a JointTypes. + - You can use Optional[X] as a shorthand for JointTypes[X, None]. """ if parameters == (): raise TypeError("Cannot take a Joint of no types.") if not isinstance(parameters, tuple): parameters = (parameters,) - msg = "AllOfTypes[arg, ...]: each arg must be a type." + msg = "JointTypes[arg, ...]: each arg must be a type." parameters = tuple(_type_check(p, msg) for p in parameters) parameters = _remove_dups_flatten(parameters) if len(parameters) == 1: diff --git a/brainstate/mixin_test.py b/brainstate/mixin_test.py index 90b0a12..534ed21 100644 --- a/brainstate/mixin_test.py +++ b/brainstate/mixin_test.py @@ -30,6 +30,8 @@ def test_mixin(self): self.assertTrue(bc.mixin.Training) + + class TestMode(unittest.TestCase): def test_JointMode(self): a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training()) diff --git a/brainstate/nn/_elementwise.py b/brainstate/nn/_elementwise.py index 638f49a..1e10b57 100644 --- a/brainstate/nn/_elementwise.py +++ b/brainstate/nn/_elementwise.py @@ -1139,13 +1139,13 @@ def __init__( name: Optional[str] = None ) -> None: super().__init__(mode=mode, name=name) - assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}." + assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}." self.prob = prob def __call__(self, x): dtype = bu.math.get_dtype(x) fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.') - if fit_phase: + if fit_phase and self.prob < 1.: keep_mask = random.bernoulli(self.prob, x.shape) return jnp.where(keep_mask, jnp.asarray(x / self.prob, dtype=dtype), diff --git a/brainstate/random.py b/brainstate/random.py index c2ff756..1f12a2d 100644 --- a/brainstate/random.py +++ b/brainstate/random.py @@ -1167,23 +1167,32 @@ def default_rng(seed_or_key=None, clone: bool = True) -> RandomState: return RandomState(seed_or_key) -def seed(seed: int = None): +def seed(seed_or_key: int = None): """Sets a new random seed. Parameters ---------- - seed: int, optional - The random seed. + seed_or_key: int, optional + The random seed (an integer) or jax random key. """ with jax.ensure_compile_time_eval(): - if seed is None: - seed = np.random.randint(0, 100000) - np.random.seed(seed) - DEFAULT.seed(seed) + if seed_or_key is None: + seed_or_key = np.random.randint(0, 100000) + + # numpy random seed + if np.size(seed_or_key) == 1: # seed + np.random.seed(seed_or_key) + elif np.size(seed_or_key) == 2: # jax random key + np.random.seed(seed_or_key[0]) + else: + raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.") + + # jax random seed + DEFAULT.seed(seed_or_key) @contextmanager -def seed_context(seed: int): +def seed_context(seed_or_key: SeedOrKey): """ A context manager that sets the random seed for the duration of the block. @@ -1206,16 +1215,19 @@ def seed_context(seed: int): The context manager does not only set the seed for the AX random state, but also for the numpy random state. Args: - seed: The seed (an integer). + seed_or_key: The seed (an integer) or jax random key. - Returns: - The random state. """ old_jrand_key = DEFAULT.value old_np_state = np.random.get_state() try: - np.random.seed(seed) - DEFAULT.seed(seed) + if np.size(seed_or_key) == 1: # seed + np.random.seed(seed_or_key) + elif np.size(seed_or_key) == 2: # jax random key + np.random.seed(seed_or_key[0]) + else: + raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.") + DEFAULT.seed(seed_or_key) yield finally: np.random.set_state(old_np_state) @@ -1223,7 +1235,8 @@ def seed_context(seed: int): def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None): - r"""Random values in a given shape. + r""" + Random values in a given shape. .. note:: This is a convenience function for users porting code from Matlab, @@ -4796,10 +4809,9 @@ def _size2shape(size): def _check_shape(name, shape, *param_shapes): - shape = core.as_named_shape(shape) if param_shapes: - shape_ = lax.broadcast_shapes(shape.positional, *param_shapes) - if shape.positional != shape_: + shape_ = lax.broadcast_shapes(shape, *param_shapes) + if shape != shape_: msg = ("{} parameter shapes must be broadcast-compatible with shape " "argument, and the result of broadcasting the shapes must equal " "the shape argument, but got result {} for shape argument {}.") diff --git a/brainstate/transform/_control.py b/brainstate/transform/_control.py index e5ef89f..a50e051 100644 --- a/brainstate/transform/_control.py +++ b/brainstate/transform/_control.py @@ -25,7 +25,7 @@ import numpy as np from brainstate._utils import set_module_as -from ._jit_error import jit_error +from ._jit_error import jit_error, remove_vmap from ._make_jaxpr import StatefulFunction, _assign_state_values from ._progress_bar import ProgressBar @@ -347,7 +347,7 @@ def _wrap_fun_with_pbar(fun, pbar_runner): def new_fun(new_carry, inputs): i, old_carry = new_carry old_carry, old_outputs = fun(old_carry, inputs) - pbar_runner(i) + pbar_runner(remove_vmap(i, op='none')) return (i + 1, old_carry), old_outputs return new_fun diff --git a/brainstate/transform/_jit_error_test.py b/brainstate/transform/_jit_error_test.py index b554d4d..ebe7437 100644 --- a/brainstate/transform/_jit_error_test.py +++ b/brainstate/transform/_jit_error_test.py @@ -16,8 +16,8 @@ import unittest import jax -import jaxlib.xla_extension import jax.numpy as jnp +import jaxlib.xla_extension import brainstate as bst diff --git a/brainstate/transform/_make_jaxpr.py b/brainstate/transform/_make_jaxpr.py index 2d3a87e..40c811b 100644 --- a/brainstate/transform/_make_jaxpr.py +++ b/brainstate/transform/_make_jaxpr.py @@ -71,8 +71,8 @@ from brainstate._state import State, StateTrace from brainstate._utils import set_module_as +from brainstate.typing import PyTree -PyTree = Any AxisName = Hashable __all__ = [ diff --git a/brainstate/typing.py b/brainstate/typing.py index 24cffb6..b32bdc0 100644 --- a/brainstate/typing.py +++ b/brainstate/typing.py @@ -14,13 +14,16 @@ # ============================================================================== -from typing import Any, Sequence, Protocol, Union +import functools as ft +import typing +from typing import Sequence, Protocol, Union, Any, Generic, TypeVar import brainunit as bu import jax import numpy as np __all__ = [ + 'PyTree', 'Size', 'Axes', 'SeedOrKey', @@ -29,6 +32,151 @@ 'DTypeLike', ] +_T = TypeVar("_T") + + +class _FakePyTree(Generic[_T]): + pass + + +_FakePyTree.__name__ = "PyTree" +_FakePyTree.__qualname__ = "PyTree" +_FakePyTree.__module__ = "builtins" + + +class _MetaPyTree(type): + def __call__(self, *args, **kwargs): + raise RuntimeError("PyTree cannot be instantiated") + + # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do + # the custom __instancecheck__ that we want. + # We can't add that __instancecheck__ via subclassing, e.g. + # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms + # isn't allowed. + # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that + # has __module__ "types", e.g. we get types.PyTree[int]. + @ft.lru_cache(maxsize=None) + def __getitem__(cls, item): + if isinstance(item, tuple): + if len(item) == 2: + + class X(PyTree): + leaftype = item[0] + structure = item[1].strip() + + if not isinstance(X.structure, str): + raise ValueError( + "The structure annotation `struct` in " + "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, " + f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'." + ) + pieces = X.structure.split() + if len(pieces) == 0: + raise ValueError( + "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` " + "cannot be the empty string." + ) + for piece_index, piece in enumerate(pieces): + if (piece_index == 0) or (piece_index == len(pieces) - 1): + if piece == "...": + continue + if not piece.isidentifier(): + raise ValueError( + "The string `struct` in " + "`brainstate.typing.PyTree[leaftype, struct]` must be be a " + "whitespace-separated sequence of identifiers, e.g. " + "`brainstate.typing.PyTree[leaftype, 'T']` or " + "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n" + "(Here, 'identifier' is used in the same sense as in " + "regular Python, i.e. a valid variable name.)\n" + f"Got piece '{piece}' in overall structure '{X.structure}'." + ) + name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]' + else: + raise ValueError( + "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a " + "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and " + "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length " + f"{len(item)}." + ) + else: + name = str(_FakePyTree[item]) + + class X(PyTree): + leaftype = item + structure = None + + X.__name__ = name + X.__qualname__ = name + if getattr(typing, "GENERATING_DOCUMENTATION", False): + X.__module__ = "builtins" + else: + X.__module__ = "brainstate.typing" + return X + + +# Can't do `class PyTree(Generic[_T]): ...` because we need to override the +# instancecheck for PyTree[foo], but subclassing +# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed. +PyTree = _MetaPyTree("PyTree", (), {}) +if getattr(typing, "GENERATING_DOCUMENTATION", False): + PyTree.__module__ = "builtins" +else: + PyTree.__module__ = "brainstate.typing" +PyTree.__doc__ = """Represents a PyTree. + +Annotations of the following sorts are supported: +```python +a: PyTree +b: PyTree[LeafType] +c: PyTree[LeafType, "T"] +d: PyTree[LeafType, "S T"] +e: PyTree[LeafType, "... T"] +f: PyTree[LeafType, "T ..."] +``` + +These correspond to: + +a. A plain `PyTree` can be used an annotation, in which case `PyTree` is simply a + suggestively-named alternative to `Any`. + ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html)) + +b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For + example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`. + +c. A structure name can also be passed. In this case + `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name. + This can be used to mark that multiple PyTrees all have the same structure: + ```python + def f(x: PyTree[int, "T"], y: PyTree[int, "T"]): + ... + ``` + +d. A composite structure can be declared. In this case the variable must have a PyTree + structure each to the composition of multiple previously-bound PyTree structures. + For example: + ```python + def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]): + ... + + x = (1, 2) + y = {"key": 3} + z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `z` + f(x, y, z) + ``` + When performing runtime type-checking, all the individual pieces must have already + been bound to structures, otherwise the composite structure check will throw an error. + +e. A structure can begin with a `...`, to denote that the lower levels of the PyTree + must match the declared structure, but the upper levels can be arbitrary. As in the + previous case, all named pieces must already have been seen and their structures + bound. + +f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the + declared structure, but the lower levels can be arbitrary. As in the previous two + cases, all named pieces must already have been seen and their structures bound. +""" # noqa: E501 + Size = Union[int, Sequence[int]] Axes = Union[int, Sequence[int]] SeedOrKey = Union[int, jax.Array, np.ndarray] @@ -44,7 +192,7 @@ np.ndarray, # NumPy array type np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types - bu.Quantity, # quantity + bu.Quantity, # Quantity ] # --- Dtype --- # diff --git a/docs/api.rst b/docs/api.rst index 4fb71a9..48b58ee 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -7,12 +7,12 @@ API Documentation apis/changelog.md apis/brainstate.rst apis/init.rst - apis/math.rst apis/mixin.rst apis/optim.rst apis/random.rst apis/surrogate.rst apis/transform.rst apis/util.rst + apis/typing.rst apis/brainstate.nn.rst apis/brainstate.functional.rst diff --git a/docs/apis/init.rst b/docs/apis/init.rst index cf4398f..b770fb8 100644 --- a/docs/apis/init.rst +++ b/docs/apis/init.rst @@ -4,6 +4,11 @@ .. currentmodule:: brainstate.init .. automodule:: brainstate.init + + +Helper Functions +---------------- + .. autosummary:: :toctree: generated/ @@ -11,6 +16,15 @@ state noise to_size + +Initialization Classes +---------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + Initializer ZeroInit Constant diff --git a/docs/apis/math.rst b/docs/apis/math.rst deleted file mode 100644 index eaa654b..0000000 --- a/docs/apis/math.rst +++ /dev/null @@ -1,25 +0,0 @@ -``brainstate.math`` module -========================== - -.. currentmodule:: brainstate.math -.. automodule:: brainstate.math - -.. autosummary:: - :toctree: generated/ - - get_dtype - is_float - is_int - exprel - flatten - unflatten - remove_diag - clip_by_norm - from_numpy - as_numpy - tree_zeros_like - tree_ones_like - einreduce - einrearrange - einrepeat - einshape diff --git a/docs/apis/mixin.rst b/docs/apis/mixin.rst index d8b30a6..4b419eb 100644 --- a/docs/apis/mixin.rst +++ b/docs/apis/mixin.rst @@ -8,8 +8,8 @@ :toctree: generated/ Mixin - Delayed DelayedInit + DelayedInitializer AlignPost BindCondData UpdateReturn @@ -17,5 +17,5 @@ JointMode Batching Training - AllOfTypes + JointTypes OneOfTypes diff --git a/docs/apis/random.rst b/docs/apis/random.rst index ee99c67..978ecb5 100644 --- a/docs/apis/random.rst +++ b/docs/apis/random.rst @@ -25,6 +25,7 @@ Random Helper Functions :template: classtemplate.rst seed + seed_context default_rng split_key split_keys diff --git a/docs/apis/typing.rst b/docs/apis/typing.rst new file mode 100644 index 0000000..e4ca1bb --- /dev/null +++ b/docs/apis/typing.rst @@ -0,0 +1,16 @@ +``brainstate.typing`` module +============================ + +.. currentmodule:: brainstate.typing +.. automodule:: brainstate.typing + +.. autosummary:: + :toctree: generated/ + + PyTree + DType + Size + Axes + SeedOrKey + ArrayLike + DTypeLike diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 0dface5..b3f5333 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -473,12 +473,16 @@ def main(): # filename='apis/auto/util.rst', # header='``brainstate.util`` module') - _write_module(module_name='brainstate.optim', - filename='apis/optim.rst', - header='``brainstate.optim`` module') + # _write_module(module_name='brainstate.typing', + # filename='apis/typing.rst', + # header='``brainstate.typing`` module') + + # _write_module(module_name='brainstate.optim', + # filename='apis/optim.rst', + # header='``brainstate.optim`` module') _write_module(module_name='brainstate.init', - filename='apis/init.rst', + filename='apis/auto/init.rst', header='``brainstate.init`` module') # module_and_name = [ @@ -491,35 +495,35 @@ def main(): # submodule_names=[k[0] for k in module_and_name], # section_names=[k[1] for k in module_and_name]) - module_and_name = [ - ('_activations', 'Activation Functions'), - ('_normalization', 'Normalization Functions'), - ('_spikes', 'Spiking Operations'), - ] - _write_submodules(module_name='brainstate.functional', - filename='apis/brainstate.functional.rst', - header='``brainstate.functional`` module', - submodule_names=[k[0] for k in module_and_name], - section_names=[k[1] for k in module_and_name]) + # module_and_name = [ + # ('_activations', 'Activation Functions'), + # ('_normalization', 'Normalization Functions'), + # ('_spikes', 'Spiking Operations'), + # ] + # _write_submodules(module_name='brainstate.functional', + # filename='apis/brainstate.functional.rst', + # header='``brainstate.functional`` module', + # submodule_names=[k[0] for k in module_and_name], + # section_names=[k[1] for k in module_and_name]) - module_and_name = [ - ('_base', 'Base Classes'), - ('_projection', 'Synaptic Projections'), - ('_connections', 'Connection Layers'), - ('_dynamics', 'Neuronal/Synaptic Dynamics'), - ('_rate_rnns', 'Rate RNNs'), - ('_readout', 'Readout Layers'), - ('_synouts', 'Synaptic Outputs'), - ('_elementwise', 'Element-wise Layers'), - ('_normalizations', 'Normalization Layers'), - ('_poolings', 'Pooling Layers'), - ('_others', 'Other Layers'), - ] - _write_submodules(module_name='brainstate.nn', - filename='apis/brainstate.nn.rst', - header='``brainstate.nn`` module', - submodule_names=[k[0] for k in module_and_name], - section_names=[k[1] for k in module_and_name]) + # module_and_name = [ + # ('_base', 'Base Classes'), + # ('_projection', 'Synaptic Projections'), + # ('_connections', 'Connection Layers'), + # ('_dynamics', 'Neuronal/Synaptic Dynamics'), + # ('_rate_rnns', 'Rate RNNs'), + # ('_readout', 'Readout Layers'), + # ('_synouts', 'Synaptic Outputs'), + # ('_elementwise', 'Element-wise Layers'), + # ('_normalizations', 'Normalization Layers'), + # ('_poolings', 'Pooling Layers'), + # ('_others', 'Other Layers'), + # ] + # _write_submodules(module_name='brainstate.nn', + # filename='apis/brainstate.nn.rst', + # header='``brainstate.nn`` module', + # submodule_names=[k[0] for k in module_and_name], + # section_names=[k[1] for k in module_and_name]) if __name__ == '__main__': diff --git a/setup.py b/setup.py index 45344cc..15e6729 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ description='A ``State``-based Transformation System for Brain Dynamics Programming.', long_description=README, long_description_content_type="text/markdown", - author='BrainPy Team', + author='BDP', author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.9', @@ -70,9 +70,11 @@ 'cuda12': ['jaxlib[cuda12_pip]'], 'tpu': ['jaxlib[tpu]'], }, - keywords=('computational neuroscience, ' - 'brain-inspired computation, ' - 'brain dynamics programming'), + keywords=( + 'computational neuroscience, ' + 'brain-inspired computation, ' + 'brain dynamics programming' + ), classifiers=[ 'Natural Language :: English', 'Operating System :: OS Independent',