Skip to content

Commit

Permalink
rename State._check_value() to State._check_value_tree()
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
1 parent 20a9853 commit b479ab5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
4 changes: 2 additions & 2 deletions brainstate/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ def retrieve_at_step(self, delay_step, *indices) -> PyTree:
assert self.history is not None, 'The delay history is not initialized.'
assert delay_step is not None, 'The delay step should be given.'

if environ.get(environ.JIT_ERROR_CHECK, True):
if environ.get(environ.JIT_ERROR_CHECK, False):
def _check_delay(delay_len):
raise ValueError(f'The request delay length should be less than the '
f'maximum delay {self.max_length}. But we got {delay_len}')
Expand Down Expand Up @@ -1263,7 +1263,7 @@ def retrieve_at_time(self, delay_time, *indices) -> PyTree:
current_time = environ.get(environ.T, desc='The current time.')
dt = environ.get_dt()

if environ.get(environ.JIT_ERROR_CHECK, True):
if environ.get(environ.JIT_ERROR_CHECK, False):
def _check_delay(args):
t_now, t_delay = args
raise ValueError(f'The request delay time should be within '
Expand Down
14 changes: 9 additions & 5 deletions brainstate/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,17 @@ def value(self, v) -> None:
"""
# value checking
v = v.value if isinstance(v, State) else v
self._check_value(v)
self._check_value_tree(v)
# write the value by the stack (>= level)
trace: StateTrace
for trace in thread_local_stack.stack[self._level:]:
trace.write_its_value(self)
# set the value
self._value = v

def _check_value(self, v):
def _check_value_tree(self, v):
if self._check_tree or _global_context_to_check_state_tree[-1]:
in_tree = jax.tree_util.tree_structure(v)
in_tree = jax.tree.structure(v)
if in_tree != self._tree:
self._raise_error_with_source_info(
ValueError(f'The given value {in_tree} does not '
Expand Down Expand Up @@ -370,20 +370,24 @@ def write_its_value(self, state: State) -> None:
self.types[index] = 'write'
self._written_ids.add(id_)

def collect_values(self, *categories: str) -> Tuple:
def collect_values(self, *categories: str, check_val_tree: bool = False) -> Tuple:
"""
Collect the values by the given categories.
Args:
*categories: The categories.
check_val_tree: Whether to check the tree structure of the value.
Returns:
results: The values.
"""
results = []
for st, ty in zip(self.states, self.types):
if ty in categories:
results.append(st.value)
val = st.value
if check_val_tree:
st._check_value_tree(val)
results.append(val)
return tuple(results)

def recovery_original_values(self) -> None:
Expand Down

0 comments on commit b479ab5

Please sign in to comment.