Skip to content

Commit

Permalink
upgrade delay to support interoperation
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 22, 2024
1 parent d22cc8a commit 90508b0
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 59 deletions.
178 changes: 141 additions & 37 deletions brainstate/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,18 @@
from ._state import State, StateDictManager, visible_state_dict
from ._utils import set_module_as
from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
from .transform._jit_error import jit_error
from .transform import jit_error
from .util import unique_name, DictManager, get_unique_name

Shape = Union[int, Sequence[int]]
PyTree = Any
ArrayLike = jax.typing.ArrayLike

delay_identifier = '_*_delay_of_'
ROTATE_UPDATE = 'rotation'
CONCAT_UPDATE = 'concat'
_DELAY_ROTATE = 'rotation'
_DELAY_CONCAT = 'concat'
_INTERP_LINEAR = 'linear_interp'
_INTERP_ROUND = 'round'

StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

Expand Down Expand Up @@ -1036,24 +1038,25 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
delay = length data ]
entries: optional, dict. The delay access entries.
name: str. The delay name.
method: str. The method used for updating delay. Default None.
delay_method: str. The method used for updating delay. Default None.
mode: Mode. The computing mode. Default None.
"""

__module__ = 'brainstate'

non_hash_params = ('time', 'entries', 'name')
max_time: float
non_hashable_params = ('time', 'entries', 'name')
max_time: float #
max_length: int
history: Optional[State]

def __init__(
self,
target_info: PyTree,
time: Optional[Union[int, float]] = None, # delay time
init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
entries: Optional[Dict] = None, # delay access entry
method: Optional[str] = ROTATE_UPDATE, # delay method
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
interp_method: str = _INTERP_LINEAR, # interpolation method
# others
name: Optional[str] = None,
mode: Optional[Mode] = None,
Expand All @@ -1063,8 +1066,14 @@ def __init__(
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)

# delay method
assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
self.method = method
assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
self.delay_method = delay_method

# interp method
assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
self.interp_method = interp_method

# delay length and time
self.max_time, delay_length = _get_delay(time, None)
Expand All @@ -1074,7 +1083,8 @@ def __init__(

# delay data
if init is not None:
assert isinstance(init, (numbers.Number, jax.Array, Callable))
if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
raise TypeError(f'init should be Array, Callable, or None. But got {init}')
self._init = init
self._history = None

Expand All @@ -1088,7 +1098,11 @@ def __init__(

def __repr__(self):
name = self.__class__.__name__
return f'{name}(delay_length={self.max_length}, target_info={self.target_info}, method="{self.method}")'
return (f'{name}('
f'delay_length={self.max_length}, '
f'target_info={self.target_info}, '
f'delay_method="{self.delay_method}", '
f'interp_method="{self.interp_method}")')

@property
def history(self):
Expand All @@ -1103,7 +1117,7 @@ def _f_to_init(self, a, batch_size, length):
if batch_size is not None:
shape.insert(self.mode.batch_axis, batch_size)
shape.insert(0, length)
if isinstance(self._init, (jax.Array, numbers.Number)):
if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
elif callable(self._init):
data = self._init(shape, dtype=a.dtype)
Expand All @@ -1130,7 +1144,8 @@ def register_entry(
delay_time: Optional[Union[int, float]] = None,
delay_step: Optional[int] = None,
) -> 'Delay':
"""Register an entry to access the data.
"""
Register an entry to access the delay data.
Args:
entry: str. The entry to access the delay data.
Expand Down Expand Up @@ -1160,7 +1175,8 @@ def register_entry(
return self

def at(self, entry: str, *indices) -> ArrayLike:
"""Get the data at the given entry.
"""
Get the data at the given entry.
Args:
entry: str. The entry to access the data.
Expand All @@ -1176,38 +1192,46 @@ def at(self, entry: str, *indices) -> ArrayLike:
delay_step = self._registered_entries[entry]
if delay_step is None:
delay_step = 0
return self.retrieve(delay_step, *indices)
return self.retrieve_at_step(delay_step, *indices)

def retrieve(self, delay_step, *indices):
"""Retrieve the delay data according to the delay length.
def retrieve_at_step(self, delay_step, *indices) -> PyTree:
"""
Retrieve the delay data at the given delay time step (the integer to indicate the time step).
Parameters
----------
delay_step: int
The delay length used to retrieve the data.
delay_step: int_like
Retrieve the data at the given time step.
indices: tuple
The indices to slice the data.
Returns
-------
delay_data: The delay data at the given delay step.
"""
assert self.history is not None, 'The delay history is not initialized.'
assert delay_step is not None, 'The delay step should be given.'

if environ.get(environ.JIT_ERROR_CHECK, False):
if environ.get(environ.JIT_ERROR_CHECK, True):
def _check_delay(delay_len):
raise ValueError(f'The request delay length should be less than the '
f'maximum delay {self.max_length}. But we got {delay_len}')

jit_error(delay_step >= self.max_length, _check_delay, delay_step)

# rotation method
if self.method == ROTATE_UPDATE:
i = environ.get(environ.I)
if self.delay_method == _DELAY_ROTATE:
i = environ.get(environ.I, desc='The time step index.')
di = i - delay_step
delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
delay_idx = jax.lax.stop_gradient(delay_idx)

elif self.method == CONCAT_UPDATE:
elif self.delay_method == _DELAY_CONCAT:
delay_idx = delay_step

else:
raise ValueError(f'Unknown updating method "{self.method}"')
raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

# the delay index
if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
Expand All @@ -1217,32 +1241,108 @@ def _check_delay(delay_len):
# the delay data
return jax.tree.map(lambda a: a[indices], self.history.value)

def retrieve_at_time(self, delay_time, *indices) -> PyTree:
"""
Retrieve the delay data at the given delay time step (the integer to indicate the time step).
Parameters
----------
delay_time: float
Retrieve the data at the given time.
indices: tuple
The indices to slice the data.
Returns
-------
delay_data: The delay data at the given delay step.
"""
assert self.history is not None, 'The delay history is not initialized.'
assert delay_time is not None, 'The delay time should be given.'

current_time = environ.get(environ.T, desc='The current time.')
dt = environ.get_dt()

if environ.get(environ.JIT_ERROR_CHECK, True):
def _check_delay(args):
t_now, t_delay = args
raise ValueError(f'The request delay time should be within '
f'[{t_now - self.max_time - dt}, {t_now}], '
f'but we got {t_delay}')

jit_error(jnp.logical_or(delay_time > current_time,
delay_time < current_time - self.max_time - dt),
_check_delay,
(current_time, delay_time,))

diff = current_time - delay_time
float_time_step = diff / dt

if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
def _interp(target):
if len(indices) > 0:
raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')

if self.delay_method == _DELAY_ROTATE:
i = environ.get(environ.I, desc='The time step index.')
_interp_fun = partial(jnp.interp, period=self.max_length)
for dim in range(1, target.ndim, 1):
_interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
di = i - jnp.arange(self.max_length)
delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
return _interp_fun(float_time_step, delay_idx, target)

elif self.delay_method == _DELAY_CONCAT:
_interp_fun = partial(jnp.interp, period=self.max_length)
for dim in range(1, target.ndim, 1):
_interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
return _interp_fun(float_time_step, jnp.arange(self.max_length), target)

else:
raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

return jax.tree.map(_interp, self.history.value)

elif self.interp_method == _INTERP_ROUND: # "round" interpolation
return self.retrieve_at_step(
jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
*indices
)

else: # raise error
raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

def update(self, current: PyTree) -> None:
"""
Update delay variable with the new data.
"""
assert self.history is not None, 'The delay history is not initialized.'

# update the delay data at the rotation index
if self.method == ROTATE_UPDATE:
if self.delay_method == _DELAY_ROTATE:
i = environ.get(environ.I)
idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
idx = jax.lax.stop_gradient(idx)
self.history.value = jax.tree.map(lambda hist, cur: hist.at[idx].set(cur),
self.history.value,
current)
self.history.value = jax.tree.map(
lambda hist, cur: hist.at[idx].set(cur),
self.history.value,
current
)
# update the delay data at the first position
elif self.method == CONCAT_UPDATE:
elif self.delay_method == _DELAY_CONCAT:
current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
if self.max_length > 1:
self.history.value = jax.tree.map(lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
self.history.value,
current)
self.history.value = jax.tree.map(
lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
self.history.value,
current
)
else:
self.history.value = current

else:
raise ValueError(f'Unknown updating method "{self.method}"')
raise ValueError(f'Unknown updating method "{self.delay_method}"')


class _StateDelay(Delay):
Expand All @@ -1263,14 +1363,18 @@ def __init__(
time: Optional[Union[int, float]] = None, # delay time
init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
entries: Optional[Dict] = None, # delay access entry
method: Optional[str] = ROTATE_UPDATE, # delay method
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
# others
name: Optional[str] = None,
mode: Optional[Mode] = None,
):
super().__init__(target_info=target.value,
time=time, init=init, entries=entries,
method=method, name=name, mode=mode)
time=time,
init=init,
entries=entries,
delay_method=delay_method,
name=name,
mode=mode)
self.target = target

def update(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit 90508b0

Please sign in to comment.