
Commit

update codes
chaoming0625 committed Dec 5, 2024
1 parent d916c59 · commit bd2ca79
Showing 9 changed files with 211 additions and 62 deletions.
16 changes: 15 additions & 1 deletion brainstate/_state.py
@@ -99,6 +99,9 @@ def catch_new_states(tag: str = None) -> List:


class Catcher:
"""
The catcher to catch the new states.
"""
def __init__(self, tag: str):
self.tag = tag
self.state_ids = set()
@@ -231,6 +234,7 @@ def __init__(
# avoid using self._setattr to avoid the check
vars(self).update(metadata)

# record the state initialization
record_state_init(self)

if not TYPE_CHECKING:
@@ -290,7 +294,6 @@ def value(self, v) -> None:
v: The value.
"""
self.write_value(v)
self._been_writen = True

def write_value(self, v) -> None:
# value checking
@@ -301,6 +304,8 @@ def write_value(self, v) -> None:
record_state_value_write(self)
# set the value
self._value = v
# set flag
self._been_writen = True

def restore_value(self, v) -> None:
"""
@@ -511,6 +516,15 @@ class LongTermState(State):
__module__ = 'brainstate'


class BatchState(LongTermState):
"""
The batch state, which is used to store the batch data in the program.
"""

__module__ = 'brainstate'



class HiddenState(ShortTermState):
"""
The hidden state, which is used to store the hidden data in a dynamic model.
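
As a reading aid for this hunk, here is a minimal sketch of how the new pieces are meant to be used together. It assumes `catch_new_states` and the new `BatchState` are re-exported at the package top level, and that the context manager can be entered without binding its result (as `_collective_ops.py` below does); it is an illustration of intent, not confirmed API.

import jax.numpy as jnp
import brainstate as bst

# Every State created inside the context is recorded by the Catcher under the
# given tag (see `record_state_init` above).
with bst.catch_new_states(tag='init'):
    w = bst.ParamState(jnp.zeros((4, 4)))
    batch = bst.BatchState(jnp.zeros((32, 4)))  # the new LongTermState subclass

# Assignment through `.value` now funnels through `write_value`, which checks
# the value, records the write, and sets the written flag in one place.
w.value = jnp.ones((4, 4))
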
84 changes: 84 additions & 0 deletions brainstate/augment/_mapping_test.py
@@ -204,7 +204,91 @@ def mul(foo):
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))

write_state_ids = [id(st) for st in trace.get_write_states()]
read_state_ids = [id(st) for st in trace.get_read_states()]

assert id(foo.a) in read_state_ids
assert id(foo.b) in write_state_ids

print(trace.get_write_states())
print(trace.get_read_states())



def test_vmap_jit(self):
class Foo(bst.nn.Module):
def __init__(self):
super().__init__()
self.a = bst.ParamState(jnp.arange(4))
self.b = bst.ShortTermState(jnp.arange(4))

def __call__(self):
self.b.value = self.a.value * self.b.value

@bst.augment.vmap
def mul(foo):
foo()

@bst.compile.jit
def mul_jit(inp):
mul(foo)
foo.a.value += inp

foo = Foo()
with bst.StateTraceStack() as trace:
mul_jit(1.)

print(foo.a.value)
print(foo.b.value)
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))

write_state_ids = [id(st) for st in trace.get_write_states()]
read_state_ids = [id(st) for st in trace.get_read_states()]

assert id(foo.a) in write_state_ids
assert id(foo.b) in write_state_ids

print(trace.get_write_states())
print(trace.get_read_states())


def test_vmap_grad(self):
class Foo(bst.nn.Module):
def __init__(self):
super().__init__()
self.a = bst.ParamState(jnp.arange(4.))
self.b = bst.ShortTermState(jnp.arange(4.))

def __call__(self):
self.b.value = self.a.value * self.b.value

@bst.augment.vmap
def mul(foo):
foo()

def loss():
mul(foo)
return jnp.sum(foo.b.value)

foo = Foo()
with bst.StateTraceStack() as trace:
grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
print(grads)
print(loss)

# print(foo.a.value)
# print(foo.b.value)
# self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
# self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
#
# write_state_ids = [id(st) for st in trace.get_write_states()]
# read_state_ids = [id(st) for st in trace.get_read_states()]
#
# assert id(foo.a) in write_state_ids
# assert id(foo.b) in write_state_ids
#
# print(trace.get_write_states())
# print(trace.get_read_states())


5 changes: 4 additions & 1 deletion brainstate/graph/_graph_operation.py
@@ -608,7 +608,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
if isinstance(value, TreefyState):
variable.update_from_ref(value)
elif isinstance(value, State):
variable.restore_value(value.value)
if value._been_writen:
variable.write_value(value.value)
else:
variable.restore_value(value.value)
else:
raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
else: # if it doesn't, create a new variable
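
A small sketch of the distinction this branch relies on, assuming (from the `_state.py` hunk above) that `write_value` re-runs value checking, records the write, and sets the written flag, while `restore_value` puts the value back without any of that:

import jax.numpy as jnp
import brainstate as bst

st = bst.ShortTermState(jnp.zeros(3))

# Restoring puts a value back silently: no value checking, no write record,
# and the written flag stays unset, so unflattening uses `restore_value`.
st.restore_value(jnp.zeros(3))

# Assignment goes through `write_value`: the value is checked, the write is
# recorded, and the written flag is set, so `_get_children` replays it with
# `write_value` instead.
st.value = jnp.ones(3)
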
49 changes: 34 additions & 15 deletions brainstate/nn/_collective_ops.py
@@ -20,8 +20,10 @@

import jax

from brainstate._state import catch_new_states
from brainstate._utils import set_module_as
from brainstate.graph import nodes
from brainstate.util._filter import Filter
from ._module import Module

# the maximum order
@@ -74,34 +76,50 @@ def wrap(fun: Callable):


@set_module_as('brainstate.nn')
def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
def init_all_states(
target: T,
*args,
exclude: Filter = None,
tag: str = None,
**kwargs
) -> T:
"""
Collectively initialize states of all children nodes in the given target.
Args:
target: The target Module.
exclude: The filter to exclude some nodes.
tag: The tag for the new states.
args: The positional arguments for the initialization, which will be passed to the `init_state` method
of each node.
kwargs: The keyword arguments for the initialization, which will be passed to the `init_state` method
of each node.
Returns:
The target Module.
"""
nodes_with_order = []

nodes_ = nodes(target).filter(Module)
if exclude is not None:
nodes_ = nodes_ - nodes_.filter(exclude)
with catch_new_states(tag=tag):

# reset node whose `init_state` has no `call_order`
for node in list(nodes_.values()):
if hasattr(node.init_state, 'call_order'):
nodes_with_order.append(node)
else:
node.init_state(*args, **kwargs)
# node that has `call_order` decorated
nodes_with_order = []

# reset the node's states
for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
node.init_state(*args, **kwargs)
nodes_ = nodes(target).filter(Module)
if exclude is not None:
nodes_ = nodes_ - nodes_.filter(exclude)

return target
# reset node whose `init_state` has no `call_order`
for node in list(nodes_.values()):
if hasattr(node.init_state, 'call_order'):
nodes_with_order.append(node)
else:
node.init_state(*args, **kwargs)

# reset the node's states with `call_order`
for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
node.init_state(*args, **kwargs)

return target


@set_module_as('brainstate.nn')
@@ -115,6 +133,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
Returns:
The target Module.
"""

nodes_with_order = []

# reset node whose `init_state` has no `call_order`
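
A hedged usage sketch of the extended `init_all_states` signature. The `Leaf` and `Net` modules below are hypothetical stand-ins for whatever children the real target holds, and passing a Module subclass as the `exclude` filter is only assumed to be a valid filter form:

import jax.numpy as jnp
import brainstate as bst

class Leaf(bst.nn.Module):
    # hypothetical child module, for illustration only
    def __init__(self, size):
        super().__init__()
        self.size = size

    def init_state(self, *args, **kwargs):
        self.v = bst.HiddenState(jnp.zeros(self.size))

class Net(bst.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = Leaf(4)
        self.l2 = Leaf(8)

net = Net()

# Initialize the states of every child Module; everything created during the
# call is caught via `catch_new_states` and grouped under the given tag.
bst.nn.init_all_states(net, tag='init')

# `exclude` filters nodes out of the pass, e.g. by Module type:
# bst.nn.init_all_states(net, exclude=Leaf)
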
45 changes: 27 additions & 18 deletions brainstate/nn/_elementwise/_dropout.py
@@ -17,7 +17,7 @@
from __future__ import annotations

from functools import partial
from typing import Optional
from typing import Optional, Sequence

import brainunit as u
import jax.numpy as jnp
@@ -29,7 +29,6 @@

__all__ = [
'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
'AlphaDropout', 'FeatureAlphaDropout',
]


@@ -47,30 +46,38 @@ class Dropout(ElementWiseBlock):
research 15.1 (2014): 1929-1958.
Args:
prob: Probability to keep element of the tensor.
mode: Mode. The computation mode of the object.
name: str. The name of the dynamic system.
prob: Probability to keep element of the tensor.
broadcast_dims: dimensions that will share the same dropout mask.
name: str. The name of the dynamic system.
"""
__module__ = 'brainstate.nn'

def __init__(
self,
prob: float = 0.5,
broadcast_dims: Sequence[int] = (),
name: Optional[str] = None
) -> None:
super().__init__(name=name)
assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
self.prob = prob
self.broadcast_dims = broadcast_dims

def __call__(self, x):
dtype = u.math.get_dtype(x)
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
if fit_phase and self.prob < 1.:
keep_mask = random.bernoulli(self.prob, x.shape)
return jnp.where(keep_mask,
jnp.asarray(x / self.prob, dtype=dtype),
jnp.asarray(0., dtype=dtype))
broadcast_shape = list(x.shape)
for dim in self.broadcast_dims:
broadcast_shape[dim] = 1
keep_mask = random.bernoulli(self.prob, broadcast_shape)
keep_mask = jnp.broadcast_to(keep_mask, x.shape)
return jnp.where(
keep_mask,
jnp.asarray(x / self.prob, dtype=dtype),
jnp.asarray(0., dtype=dtype)
)
else:
return x

@@ -93,7 +100,6 @@ def __init__(
self.channel_axis = channel_axis

def __call__(self, x):

# check input shape
inp_dim = u.math.ndim(x)
if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
@@ -114,10 +120,13 @@ def __call__(self, x):
# generate mask
if fit_phase and self.prob < 1.:
dtype = u.math.get_dtype(x)
keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
return jnp.where(keep_mask,
jnp.asarray(x / self.prob, dtype=dtype),
jnp.asarray(0., dtype=dtype))
keep_mask = random.bernoulli(self.prob, mask_shape)
keep_mask = jnp.broadcast_to(keep_mask, x.shape)
return jnp.where(
keep_mask,
jnp.asarray(x / self.prob, dtype=dtype),
jnp.asarray(0., dtype=dtype)
)
else:
return x

@@ -296,8 +305,8 @@ class AlphaDropout(_DropoutNd):
"""
__module__ = 'brainstate.nn'

def forward(self, x):
return F.alpha_dropout(x, self.p, self.training)
def update(self, *args, **kwargs):
raise NotImplementedError("AlphaDropout is not supported in the current version.")


class FeatureAlphaDropout(_DropoutNd):
@@ -344,8 +353,8 @@ class FeatureAlphaDropout(_DropoutNd):
"""
__module__ = 'brainstate.nn'

def forward(self, x):
return F.feature_alpha_dropout(x, self.p, self.training)
def update(self, *args, **kwargs):
raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")


class DropoutFixed(ElementWiseBlock):
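
A short sketch of the new `broadcast_dims` behaviour: each listed dimension gets mask size 1, so all entries along that axis share one keep/drop decision. Using `bst.environ.context(fit=True)` to enable the fitting phase is an assumption here; adjust if the project toggles that flag differently.

import jax.numpy as jnp
import brainstate as bst

# `prob` is the probability of *keeping* an element; axis 0 shares one mask.
layer = bst.nn.Dropout(prob=0.8, broadcast_dims=(0,))
x = jnp.ones((10, 20))

# Dropout only acts when the 'fit' flag in the environment is True
# (assumed toggle shown below).
with bst.environ.context(fit=True):
    y = layer(x)

# Every row sees the same (1, 20) mask: each column is either kept everywhere
# (rescaled by 1 / prob) or zeroed everywhere.
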
