add brainstate.nn.Embedding; support reset_state() (#2)
* add `brainstate.nn.Embedding`

* add `reset_state()` function in dynamics models

* add `reset_state()` function
chaoming0625 authored Jun 10, 2024
1 parent 2a1164d commit b430dea
Showing 10 changed files with 194 additions and 9 deletions.
43 changes: 41 additions & 2 deletions brainstate/_module.py
@@ -92,7 +92,7 @@
'call_order',

# state processing
'init_states', 'load_states', 'save_states', 'assign_state_values',
'init_states', 'reset_states', 'load_states', 'save_states', 'assign_state_values',
]


@@ -271,6 +271,12 @@ def init_state(self, *args, **kwargs):
"""
pass

def reset_state(self, *args, **kwargs):
"""
State resetting function.
"""
pass

def save_state(self, **kwargs) -> Dict:
"""Save states as a dictionary. """
return self.states(include_self=True, level=0, method='absolute')
@@ -1115,6 +1121,12 @@ def init_state(self, batch_size: int = None, **kwargs):
fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
self.history = State(jax.tree.map(fun, self.target_info))

def reset_state(self, batch_size: int = None, **kwargs):
if batch_size is not None:
assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
self.history.value = jax.tree.map(fun, self.target_info)

def register_entry(
self,
entry: str,
@@ -1344,7 +1356,7 @@ def wrap(fun: Callable):
@set_module_as('brainstate')
def init_states(target: Module, *args, **kwargs) -> Module:
"""
Reset states of all children nodes in the given target.
Initialize states of all child nodes in the given target.

Args:
  target: The target Module.
@@ -1368,6 +1380,33 @@ def init_states(target: Module, *args, **kwargs) -> Module:
return target


@set_module_as('brainstate')
def reset_states(target: Module, *args, **kwargs) -> Module:
"""
Reset states of all child nodes in the given target.

Args:
  target: The target Module.

Returns:
  The target Module.
"""
nodes_with_order = []

# first, reset nodes whose `reset_state` has no `call_order`
for node in list(target.nodes().values()):
if not hasattr(node.reset_state, 'call_order'):
node.reset_state(*args, **kwargs)
else:
nodes_with_order.append(node)

# then reset nodes whose `reset_state` has a `call_order`, in ascending order
for node in sorted(nodes_with_order, key=lambda x: x.reset_state.call_order):
node.reset_state(*args, **kwargs)

return target


@set_module_as('brainstate')
def load_states(target: Module, state_dict: Dict, **kwargs):
"""Copy parameters and buffers from :attr:`state_dict` into
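Taken together, the `_module.py` changes give every `Module` a `reset_state` hook that mirrors `init_state`, plus a tree-walking `brainstate.reset_states()` that honors `call_order`. A minimal sketch of the intended split, assuming `Module` and `ShortTermState` are exported at the package top level, as `set_module_as('brainstate')` suggests for the two functions (the `Counter` module is hypothetical):

```python
import jax.numpy as jnp
import brainstate


class Counter(brainstate.Module):
    """A toy stateful node, for illustration only."""

    def init_state(self, batch_size: int = None, **kwargs):
        # allocate the State object once
        self.count = brainstate.ShortTermState(jnp.zeros(()))

    def reset_state(self, batch_size: int = None, **kwargs):
        # write the initial value back in place; the State object survives,
        # so anything holding a reference to it stays valid
        self.count.value = jnp.zeros(())


net = Counter()
brainstate.init_states(net)   # calls init_state on every child node
# ... run the model, mutating net.count ...
brainstate.reset_states(net)  # calls reset_state on every child node
```

As with `init_states`, nodes whose `reset_state` carries a `call_order` attribute are reset after the unordered ones, in ascending order.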
17 changes: 17 additions & 0 deletions brainstate/_state.py
@@ -59,6 +59,23 @@ def __init__(self):
def check_state_value_tree() -> None:
"""
The context manager for checking whether the tree structure of the state value stays consistent.

Once a :py:class:`~.State` is created, the tree structure of its value is fixed. By default,
the tree structure is not checked on assignment, to avoid repeated evaluation.
If you want the tree structure to be verified each time a new value is assigned,
use this context manager.
Example::

  state = brainstate.ShortTermState(jnp.zeros((2, 3)))

  with check_state_value_tree():
    state.value = jnp.zeros((2, 3))

    # The following code will raise an error.
    state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
"""
try:
_global_context_to_check_state_tree.append(True)
4 changes: 4 additions & 0 deletions brainstate/nn/__init__.py
@@ -21,6 +21,8 @@
from ._dynamics import __all__ as dynamics_all
from ._elementwise import *
from ._elementwise import __all__ as elementwise_all
from ._embedding import *
from ._embedding import __all__ as embed_all
from ._misc import *
from ._misc import __all__ as _misc_all
from ._normalizations import *
@@ -43,6 +45,7 @@
connections_all +
dynamics_all +
elementwise_all +
embed_all +
normalizations_all +
others_all +
poolings_all +
@@ -58,6 +61,7 @@
connections_all,
dynamics_all,
elementwise_all,
embed_all,
normalizations_all,
others_all,
poolings_all,
17 changes: 10 additions & 7 deletions brainstate/nn/_base.py
@@ -55,22 +55,24 @@ class ExplicitInOutSize(Mixin):

@property
def in_size(self) -> Tuple[int, ...]:
if self._in_size is None:
raise ValueError(f"The input shape is not set in this node: {self} ")
return self._in_size

@in_size.setter
def in_size(self, in_size: Sequence[int]):
def in_size(self, in_size: Sequence[int] | int):
if isinstance(in_size, int):
in_size = (in_size,)
assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
self._in_size = tuple(in_size)

@property
def out_size(self) -> Tuple[int, ...]:
if self._out_size is None:
raise ValueError(f"The output shape is not set in this node: {self}")
return self._out_size

@out_size.setter
def out_size(self, out_size: Sequence[int]):
def out_size(self, out_size: Sequence[int] | int):
if isinstance(out_size, int):
out_size = (out_size,)
assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}"
self._out_size = tuple(out_size)


@@ -152,7 +154,8 @@ def __init__(self, first: ExplicitInOutSize, *modules_as_tuple, **modules_as_dict):
self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))

# the input and output shape
self.in_size = tuple(first.in_size)
if first._in_size is not None:  # check the backing field; the `in_size` property now raises when unset
self.in_size = first.in_size
self.out_size = tuple(in_size)

def _format_module(self, module, in_size):
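The setter change is small but convenient: `in_size` and `out_size` now accept a bare `int` and normalize it to a 1-tuple. A minimal sketch of the behavior, assuming `ExplicitInOutSize` can be instantiated on its own (imported here through the private module path shown in this diff):

```python
from brainstate.nn._base import ExplicitInOutSize  # private path, per this diff


class Probe(ExplicitInOutSize):
    """A bare node used only to exercise the setters (hypothetical)."""


layer = Probe()
layer.in_size = 8              # an int is normalized to a tuple
assert layer.in_size == (8,)

layer.out_size = (4, 8)        # sequences still pass through as tuples
assert layer.out_size == (4, 8)
```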
20 changes: 20 additions & 0 deletions brainstate/nn/_dynamics.py
@@ -103,6 +103,9 @@ def dv(self, v, t, x):
def init_state(self, batch_size: int = None, **kwargs):
self.V = ShortTermState(init.param(jnp.zeros, self.varshape, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.V.value = init.param(jnp.zeros, self.varshape, batch_size)

def get_spike(self, V=None):
V = self.V.value if V is None else V
v_scaled = (V - self.V_th) / self.V_th
@@ -160,6 +163,9 @@ def dv(self, v, t, x):
def init_state(self, batch_size: int = None, **kwargs):
self.V = ShortTermState(init.param(init.Constant(self.V_reset), self.varshape, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.V.value = init.param(init.Constant(self.V_reset), self.varshape, batch_size)

def get_spike(self, V=None):
V = self.V.value if V is None else V
v_scaled = (V - self.V_th) / self.V_th
@@ -214,6 +220,10 @@ def init_state(self, batch_size: int = None, **kwargs):
self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
self.a = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
self.a.value = init.param(init.Constant(0.), self.varshape, batch_size)

def get_spike(self, V=None, a=None):
V = self.V.value if V is None else V
a = self.a.value if a is None else a
@@ -275,6 +285,9 @@ def dg(self, g, t):
def init_state(self, batch_size: int = None, **kwargs):
self.g = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.g.value = init.param(init.Constant(0.), self.varshape, batch_size)

def update(self, x=None):
self.g.value = exp_euler_step(self.dg, self.g.value, environ.get('t'))
if x is not None:
@@ -325,6 +338,10 @@ def init_state(self, batch_size: int = None, **kwargs):
self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
self.u = ShortTermState(init.param(init.Constant(self.U), self.varshape, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

def du(self, u, t):
return self.U - u / self.tau_f

@@ -390,6 +407,9 @@ def dx(self, x, t):
def init_state(self, batch_size: int = None, **kwargs):
self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)

def update(self, pre_spike):
t = environ.get('t')
x = exp_euler_step(self.dx, self.x.value, t)
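Every dynamics model in this file follows the same two-method pattern: `init_state` wraps the initial value in a `ShortTermState`, and `reset_state` writes that same initial value back through `.value`, so trial loops create no new `State` objects. A hedged sketch of the intended use (the model class and driver function are hypothetical; the diff shows only the method bodies):

```python
import brainstate

neuron = SomeNeuron(100)                        # hypothetical model from this file
brainstate.init_states(neuron, batch_size=16)   # V (etc.) allocated as ShortTermState

for trial in range(10):
    brainstate.reset_states(neuron, batch_size=16)  # V.value re-initialized in place
    run_one_trial(neuron)                           # hypothetical simulation driver
```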
66 changes: 66 additions & 0 deletions brainstate/nn/_embedding.py
@@ -0,0 +1,66 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Optional, Callable, Union

from ._base import DnnLayer
from .. import init
from .._state import ParamState
from ..mixin import Mode, Training
from ..typing import ArrayLike

__all__ = [
'Embedding',
]


class Embedding(DnnLayer):
r"""
A simple lookup table that stores embeddings of a fixed size.

Args:
  num_embeddings: Size of embedding dictionary. Must be non-negative.
  embedding_size: Size of each embedding vector. Must be non-negative.
  embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
"""

def __init__(
self,
num_embeddings: int,
embedding_size: int,
embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
name: Optional[str] = None,
mode: Optional[Mode] = None,
):
super().__init__(name=name, mode=mode)
if num_embeddings < 0:
raise ValueError("num_embeddings must not be negative.")
if embedding_size < 0:
raise ValueError("embedding_size must not be negative.")
self.num_embeddings = num_embeddings
self.embedding_size = embedding_size
self.out_size = (embedding_size,)

weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
if self.mode.has(Training):
self.weight = ParamState(weight)
else:
self.weight = weight

def update(self, indices: ArrayLike):
if self.mode.has(Training):
return self.weight.value[indices]
return self.weight[indices]
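A usage sketch for the new layer. `update` is the forward method defined above; indexing the table with an integer array returns one embedding vector per index, so the output shape is the input shape plus `(embedding_size,)`:

```python
import jax.numpy as jnp
import brainstate

emb = brainstate.nn.Embedding(num_embeddings=1000, embedding_size=64)

tokens = jnp.array([[3, 14, 15], [9, 26, 5]])  # integer ids, shape (2, 3)
vectors = emb.update(tokens)                   # shape (2, 3, 64)
```

The table is stored as a trainable `ParamState` only when the layer's mode has `Training`; `update` reads through `.value` in that case, as the branch above shows.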
17 changes: 17 additions & 0 deletions brainstate/nn/_rate_rnns.py
@@ -90,6 +90,9 @@ def __init__(
def init_state(self, batch_size: int = None, **kwargs):
self.h = ShortTermState(init.param(self._state_initializer, self.num_out, batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.h.value = init.param(self._state_initializer, self.num_out, batch_size)

def update(self, x):
xh = jnp.concatenate([x, self.h.value], axis=-1)
h = self.W(xh)
@@ -147,6 +150,9 @@ def __init__(
def init_state(self, batch_size: int = None, **kwargs):
self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

def update(self, x):
old_h = self.h.value
xh = jnp.concatenate([x, old_h], axis=-1)
@@ -224,6 +230,9 @@ def __init__(
def init_state(self, batch_size: int = None, **kwargs):
self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

def update(self, x):
old_h = self.h.value
xh = jnp.concatenate([x, old_h], axis=-1)
@@ -327,6 +336,10 @@ def init_state(self, batch_size: int = None, **kwargs):
self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

def update(self, x):
h, c = self.h.value, self.c.value
xh = jnp.concat([x, h], axis=-1)
@@ -379,6 +392,10 @@ def init_state(self, batch_size: int = None, **kwargs):
self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))

def reset_state(self, batch_size: int = None, **kwargs):
self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

def update(self, x: ArrayLike) -> ArrayLike:
h, c = self.h.value, self.c.value
xh = jnp.concat([x, h], axis=-1)
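For the recurrent cells, the same hook clears the hidden state `h` (and the cell state `c` for the LSTM-style cells) back to its initializer, which is what you want between independent sequences. A sketch (the cell class and dataset are hypothetical; each cell's `update` consumes one time step, as the hunks above show):

```python
import brainstate

cell = SomeRNNCell(num_in=100, num_out=64)       # hypothetical cell from this file
brainstate.init_states(cell, batch_size=32)      # allocates h (and c) once

for seq in dataset:                              # hypothetical iterable of (T, 32, 100) arrays
    brainstate.reset_states(cell, batch_size=32)   # h.value (and c.value) back to the initializer
    for x_t in seq:
        y_t = cell.update(x_t)
```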
6 changes: 6 additions & 0 deletions brainstate/nn/_readout.py
@@ -66,6 +66,9 @@ def __init__(
def init_state(self, batch_size=None, **kwargs):
self.r = ShortTermState(init.param(init.Constant(0.), self.out_size, batch_size))

def reset_state(self, batch_size=None, **kwargs):
self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)

def update(self, x):
r = self.decay * self.r.value + x @ self.weight.value
self.r.value = r
@@ -109,6 +112,9 @@ def dv(self, v, t, x):
def init_state(self, batch_size, **kwargs):
self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))

def reset_state(self, batch_size, **kwargs):
self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)

@property
def spike(self):
return self.get_spike(self.V.value)
11 changes: 11 additions & 0 deletions docs/apis/brainstate.nn.rst
@@ -39,6 +39,17 @@ Synaptic Projections
VanillaProj


Embedding Layers
-----------------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

Embedding


Connection Layers
-----------------
