From 90508b013486409c3c8f693baf2ab019727c8efc Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 22 Jun 2024 17:33:09 +0800 Subject: [PATCH] upgrade delay to support interoperation --- brainstate/_module.py | 178 +++++++++++++++++++++++++++++-------- brainstate/_module_test.py | 109 ++++++++++++++++++----- brainstate/mixin.py | 4 +- 3 files changed, 232 insertions(+), 59 deletions(-) diff --git a/brainstate/_module.py b/brainstate/_module.py index 88dc51a..2e03237 100644 --- a/brainstate/_module.py +++ b/brainstate/_module.py @@ -60,7 +60,7 @@ from ._state import State, StateDictManager, visible_state_dict from ._utils import set_module_as from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn -from .transform._jit_error import jit_error +from .transform import jit_error from .util import unique_name, DictManager, get_unique_name Shape = Union[int, Sequence[int]] @@ -68,8 +68,10 @@ ArrayLike = jax.typing.ArrayLike delay_identifier = '_*_delay_of_' -ROTATE_UPDATE = 'rotation' -CONCAT_UPDATE = 'concat' +_DELAY_ROTATE = 'rotation' +_DELAY_CONCAT = 'concat' +_INTERP_LINEAR = 'linear_interp' +_INTERP_ROUND = 'round' StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) @@ -1036,14 +1038,14 @@ class Delay(ExtendedUpdateWithBA, DelayedInit): delay = length data ] entries: optional, dict. The delay access entries. name: str. The delay name. - method: str. The method used for updating delay. Default None. + delay_method: str. The method used for updating delay. Default None. mode: Mode. The computing mode. Default None. """ __module__ = 'brainstate' - non_hash_params = ('time', 'entries', 'name') - max_time: float + non_hashable_params = ('time', 'entries', 'name') + max_time: float # max_length: int history: Optional[State] @@ -1051,9 +1053,10 @@ def __init__( self, target_info: PyTree, time: Optional[Union[int, float]] = None, # delay time - init: Optional[Union[ArrayLike, Callable]] = None, # delay data init + init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0 entries: Optional[Dict] = None, # delay access entry - method: Optional[str] = ROTATE_UPDATE, # delay method + delay_method: Optional[str] = _DELAY_ROTATE, # delay method + interp_method: str = _INTERP_LINEAR, # interpolation method # others name: Optional[str] = None, mode: Optional[Mode] = None, @@ -1063,8 +1066,14 @@ def __init__( self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info) # delay method - assert method in [ROTATE_UPDATE, CONCAT_UPDATE] - self.method = method + assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. ' + f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}') + self.delay_method = delay_method + + # interp method + assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. ' + f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') + self.interp_method = interp_method # delay length and time self.max_time, delay_length = _get_delay(time, None) @@ -1074,7 +1083,8 @@ def __init__( # delay data if init is not None: - assert isinstance(init, (numbers.Number, jax.Array, Callable)) + if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)): + raise TypeError(f'init should be Array, Callable, or None. But got {init}') self._init = init self._history = None @@ -1088,7 +1098,11 @@ def __init__( def __repr__(self): name = self.__class__.__name__ - return f'{name}(delay_length={self.max_length}, target_info={self.target_info}, method="{self.method}")' + return (f'{name}(' + f'delay_length={self.max_length}, ' + f'target_info={self.target_info}, ' + f'delay_method="{self.delay_method}", ' + f'interp_method="{self.interp_method}")') @property def history(self): @@ -1103,7 +1117,7 @@ def _f_to_init(self, a, batch_size, length): if batch_size is not None: shape.insert(self.mode.batch_axis, batch_size) shape.insert(0, length) - if isinstance(self._init, (jax.Array, numbers.Number)): + if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)): data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape) elif callable(self._init): data = self._init(shape, dtype=a.dtype) @@ -1130,7 +1144,8 @@ def register_entry( delay_time: Optional[Union[int, float]] = None, delay_step: Optional[int] = None, ) -> 'Delay': - """Register an entry to access the data. + """ + Register an entry to access the delay data. Args: entry: str. The entry to access the delay data. @@ -1160,7 +1175,8 @@ def register_entry( return self def at(self, entry: str, *indices) -> ArrayLike: - """Get the data at the given entry. + """ + Get the data at the given entry. Args: entry: str. The entry to access the data. @@ -1176,20 +1192,28 @@ def at(self, entry: str, *indices) -> ArrayLike: delay_step = self._registered_entries[entry] if delay_step is None: delay_step = 0 - return self.retrieve(delay_step, *indices) + return self.retrieve_at_step(delay_step, *indices) - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. + def retrieve_at_step(self, delay_step, *indices) -> PyTree: + """ + Retrieve the delay data at the given delay time step (the integer to indicate the time step). Parameters ---------- - delay_step: int - The delay length used to retrieve the data. + delay_step: int_like + Retrieve the data at the given time step. + indices: tuple + The indices to slice the data. + + Returns + ------- + delay_data: The delay data at the given delay step. + """ assert self.history is not None, 'The delay history is not initialized.' assert delay_step is not None, 'The delay step should be given.' - if environ.get(environ.JIT_ERROR_CHECK, False): + if environ.get(environ.JIT_ERROR_CHECK, True): def _check_delay(delay_len): raise ValueError(f'The request delay length should be less than the ' f'maximum delay {self.max_length}. But we got {delay_len}') @@ -1197,17 +1221,17 @@ def _check_delay(delay_len): jit_error(delay_step >= self.max_length, _check_delay, delay_step) # rotation method - if self.method == ROTATE_UPDATE: - i = environ.get(environ.I) + if self.delay_method == _DELAY_ROTATE: + i = environ.get(environ.I, desc='The time step index.') di = i - delay_step delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32) delay_idx = jax.lax.stop_gradient(delay_idx) - elif self.method == CONCAT_UPDATE: + elif self.delay_method == _DELAY_CONCAT: delay_idx = delay_step else: - raise ValueError(f'Unknown updating method "{self.method}"') + raise ValueError(f'Unknown delay updating method "{self.delay_method}"') # the delay index if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): @@ -1217,6 +1241,78 @@ def _check_delay(delay_len): # the delay data return jax.tree.map(lambda a: a[indices], self.history.value) + def retrieve_at_time(self, delay_time, *indices) -> PyTree: + """ + Retrieve the delay data at the given delay time step (the integer to indicate the time step). + + Parameters + ---------- + delay_time: float + Retrieve the data at the given time. + indices: tuple + The indices to slice the data. + + Returns + ------- + delay_data: The delay data at the given delay step. + + """ + assert self.history is not None, 'The delay history is not initialized.' + assert delay_time is not None, 'The delay time should be given.' + + current_time = environ.get(environ.T, desc='The current time.') + dt = environ.get_dt() + + if environ.get(environ.JIT_ERROR_CHECK, True): + def _check_delay(args): + t_now, t_delay = args + raise ValueError(f'The request delay time should be within ' + f'[{t_now - self.max_time - dt}, {t_now}], ' + f'but we got {t_delay}') + + jit_error(jnp.logical_or(delay_time > current_time, + delay_time < current_time - self.max_time - dt), + _check_delay, + (current_time, delay_time,)) + + diff = current_time - delay_time + float_time_step = diff / dt + + if self.interp_method == _INTERP_LINEAR: # "linear" interpolation + def _interp(target): + if len(indices) > 0: + raise NotImplementedError('The slicing indices are not supported in the linear interpolation.') + + if self.delay_method == _DELAY_ROTATE: + i = environ.get(environ.I, desc='The time step index.') + _interp_fun = partial(jnp.interp, period=self.max_length) + for dim in range(1, target.ndim, 1): + _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) + di = i - jnp.arange(self.max_length) + delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32) + return _interp_fun(float_time_step, delay_idx, target) + + elif self.delay_method == _DELAY_CONCAT: + _interp_fun = partial(jnp.interp, period=self.max_length) + for dim in range(1, target.ndim, 1): + _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) + return _interp_fun(float_time_step, jnp.arange(self.max_length), target) + + else: + raise ValueError(f'Unknown delay updating method "{self.delay_method}"') + + return jax.tree.map(_interp, self.history.value) + + elif self.interp_method == _INTERP_ROUND: # "round" interpolation + return self.retrieve_at_step( + jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), + *indices + ) + + else: # raise error + raise ValueError(f'Un-supported interpolation method {self.interp_method}, ' + f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') + def update(self, current: PyTree) -> None: """ Update delay variable with the new data. @@ -1224,25 +1320,29 @@ def update(self, current: PyTree) -> None: assert self.history is not None, 'The delay history is not initialized.' # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: + if self.delay_method == _DELAY_ROTATE: i = environ.get(environ.I) idx = jnp.asarray(i % self.max_length, dtype=environ.dutype()) idx = jax.lax.stop_gradient(idx) - self.history.value = jax.tree.map(lambda hist, cur: hist.at[idx].set(cur), - self.history.value, - current) + self.history.value = jax.tree.map( + lambda hist, cur: hist.at[idx].set(cur), + self.history.value, + current + ) # update the delay data at the first position - elif self.method == CONCAT_UPDATE: + elif self.delay_method == _DELAY_CONCAT: current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current) if self.max_length > 1: - self.history.value = jax.tree.map(lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0), - self.history.value, - current) + self.history.value = jax.tree.map( + lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0), + self.history.value, + current + ) else: self.history.value = current else: - raise ValueError(f'Unknown updating method "{self.method}"') + raise ValueError(f'Unknown updating method "{self.delay_method}"') class _StateDelay(Delay): @@ -1263,14 +1363,18 @@ def __init__( time: Optional[Union[int, float]] = None, # delay time init: Optional[Union[ArrayLike, Callable]] = None, # delay data init entries: Optional[Dict] = None, # delay access entry - method: Optional[str] = ROTATE_UPDATE, # delay method + delay_method: Optional[str] = _DELAY_ROTATE, # delay method # others name: Optional[str] = None, mode: Optional[Mode] = None, ): super().__init__(target_info=target.value, - time=time, init=init, entries=entries, - method=method, name=name, mode=mode) + time=time, + init=init, + entries=entries, + delay_method=delay_method, + name=name, + mode=mode) self.target = target def update(self, *args, **kwargs): diff --git a/brainstate/_module_test.py b/brainstate/_module_test.py index 392da9d..553664c 100644 --- a/brainstate/_module_test.py +++ b/brainstate/_module_test.py @@ -16,14 +16,15 @@ import unittest import jax.numpy as jnp +import jaxlib.xla_extension -import brainstate as bc +import brainstate as bst -class TestVarDelay(unittest.TestCase): +class TestDelay(unittest.TestCase): def test_delay1(self): - a = bc.State(bc.random.random(10, 20)) - delay = bc.Delay(a.value) + a = bst.State(bst.random.random(10, 20)) + delay = bst.Delay(a.value) delay.register_entry('a', 1.) delay.register_entry('b', 2.) delay.register_entry('c', None) @@ -31,10 +32,10 @@ def test_delay1(self): delay.init_state() with self.assertRaises(KeyError): delay.register_entry('c', 10.) - bc.util.clear_buffer_memory() + bst.util.clear_buffer_memory() def test_rotation_delay(self): - rotation_delay = bc.Delay(jnp.ones((1,))) + rotation_delay = bst.Delay(jnp.ones((1,))) t0 = 0. t1, n1 = 1., 10 t2, n2 = 2., 20 @@ -51,16 +52,16 @@ def test_rotation_delay(self): # print(rotation_delay.max_length) for i in range(100): - bc.environ.set(i=i) + bst.environ.set(i=i) rotation_delay(jnp.ones((1,)) * i) # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c')) self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.))) self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.))) - bc.util.clear_buffer_memory() + bst.util.clear_buffer_memory() def test_concat_delay(self): - rotation_delay = bc.Delay(jnp.ones([1]), method='concat') + rotation_delay = bst.Delay(jnp.ones([1]), delay_method='concat') t0 = 0. t1, n1 = 1., 10 t2, n2 = 2., 20 @@ -73,17 +74,85 @@ def test_concat_delay(self): print() for i in range(100): - bc.environ.set(i=i) + bst.environ.set(i=i) rotation_delay(jnp.ones((1,)) * i) print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c')) self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.))) self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.))) - bc.util.clear_buffer_memory() + bst.util.clear_buffer_memory() + + def test_jit_erro(self): + rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round') + rotation_delay.init_state() + + with bst.environ.context(i=0, t=0): + rotation_delay.retrieve_at_time(-2.0) + with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError): + rotation_delay.retrieve_at_time(-2.1) + rotation_delay.retrieve_at_time(-2.01) + with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError): + rotation_delay.retrieve_at_time(-2.09) + with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError): + rotation_delay.retrieve_at_time(0.1) + with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError): + rotation_delay.retrieve_at_time(0.01) + + def test_concat_delay_round_interp(self): + for shape in [(1,), (1, 1), (1, 1, 1)]: + for delay_method in ['rotation', 'concat']: + rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='round') + t0, n1 = 0.01, 0 + t1, n1 = 1.04, 10 + t2, n2 = 1.06, 11 + rotation_delay.init_state() + + print() + for i in range(100): + t = i * bst.environ.get_dt() + with bst.environ.context(i=i, t=t): + rotation_delay(jnp.ones(shape) * i) + print(i, + rotation_delay.retrieve_at_time(t - t0), + rotation_delay.retrieve_at_time(t - t1), + rotation_delay.retrieve_at_time(t - t2)) + self.assertTrue(jnp.allclose(rotation_delay.retrieve_at_time(t - t0), + jnp.ones(shape) * i)) + self.assertTrue(jnp.allclose(rotation_delay.retrieve_at_time(t - t1), + jnp.maximum(jnp.ones(shape) * i - n1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.retrieve_at_time(t - t2), + jnp.maximum(jnp.ones(shape) * i - n2, 0.))) + bst.util.clear_buffer_memory() + + def test_concat_delay_linear_interp(self): + for shape in [(1,), (1, 1), (1, 1, 1)]: + for delay_method in ['rotation', 'concat']: + rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='linear_interp') + t0, n0 = 0.01, 0.1 + t1, n1 = 1.04, 10.4 + t2, n2 = 1.06, 10.6 + rotation_delay.init_state() + + print() + for i in range(100): + t = i * bst.environ.get_dt() + with bst.environ.context(i=i, t=t): + rotation_delay(jnp.ones(shape) * i) + print(i, + rotation_delay.retrieve_at_time(t - t0), + rotation_delay.retrieve_at_time(t - t1), + rotation_delay.retrieve_at_time(t - t2)) + self.assertTrue(jnp.allclose(rotation_delay.retrieve_at_time(t - t0), + jnp.maximum(jnp.ones(shape) * i - n0, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.retrieve_at_time(t - t1), + jnp.maximum(jnp.ones(shape) * i - n1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.retrieve_at_time(t - t2), + jnp.maximum(jnp.ones(shape) * i - n2, 0.))) + bst.util.clear_buffer_memory() def test_rotation_and_concat_delay(self): - rotation_delay = bc.Delay(jnp.ones((1,))) - concat_delay = bc.Delay(jnp.ones([1]), method='concat') + rotation_delay = bst.Delay(jnp.ones((1,))) + concat_delay = bst.Delay(jnp.ones([1]), delay_method='concat') t0 = 0. t1, n1 = 1., 10 t2, n2 = 2., 20 @@ -100,29 +169,29 @@ def test_rotation_and_concat_delay(self): print() for i in range(100): - bc.environ.set(i=i) + bst.environ.set(i=i) new = jnp.ones((1,)) * i rotation_delay(new) concat_delay(new) self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), )) self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), )) self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), )) - bc.util.clear_buffer_memory() + bst.util.clear_buffer_memory() class TestModule(unittest.TestCase): def test_states(self): - class A(bc.Module): + class A(bst.Module): def __init__(self): super().__init__() - self.a = bc.State(bc.random.random(10, 20)) - self.b = bc.State(bc.random.random(10, 20)) + self.a = bst.State(bst.random.random(10, 20)) + self.b = bst.State(bst.random.random(10, 20)) - class B(bc.Module): + class B(bst.Module): def __init__(self): super().__init__() self.a = A() - self.b = bc.State(bc.random.random(10, 20)) + self.b = bst.State(bst.random.random(10, 20)) b = B() print() diff --git a/brainstate/mixin.py b/brainstate/mixin.py index 821264f..67cef08 100644 --- a/brainstate/mixin.py +++ b/brainstate/mixin.py @@ -68,7 +68,7 @@ class DelayedInit(Mixin): Note this Mixin can be applied in any Python object. """ - non_hash_params: Optional[Sequence[str]] = None + non_hashable_params: Optional[Sequence[str]] = None @classmethod def delayed(cls, *args, **kwargs) -> 'DelayedInitializer': @@ -94,7 +94,7 @@ class DelayedInitializer(metaclass=NoSubclassMeta): """ def __init__(self, cls: T, *desc_tuple, **desc_dict): - self.cls = cls + self.cls: type = cls # arguments self.args = desc_tuple