diff --git a/brainstate/_module.py b/brainstate/_module.py index 445d16d..82d5ce9 100644 --- a/brainstate/_module.py +++ b/brainstate/_module.py @@ -1213,7 +1213,7 @@ def retrieve_at_step(self, delay_step, *indices) -> PyTree: assert self.history is not None, 'The delay history is not initialized.' assert delay_step is not None, 'The delay step should be given.' - if environ.get(environ.JIT_ERROR_CHECK, True): + if environ.get(environ.JIT_ERROR_CHECK, False): def _check_delay(delay_len): raise ValueError(f'The request delay length should be less than the ' f'maximum delay {self.max_length}. But we got {delay_len}') @@ -1263,7 +1263,7 @@ def retrieve_at_time(self, delay_time, *indices) -> PyTree: current_time = environ.get(environ.T, desc='The current time.') dt = environ.get_dt() - if environ.get(environ.JIT_ERROR_CHECK, True): + if environ.get(environ.JIT_ERROR_CHECK, False): def _check_delay(args): t_now, t_delay = args raise ValueError(f'The request delay time should be within ' diff --git a/brainstate/_state.py b/brainstate/_state.py index 6dfdd5a..a594ad6 100644 --- a/brainstate/_state.py +++ b/brainstate/_state.py @@ -141,7 +141,7 @@ def value(self, v) -> None: """ # value checking v = v.value if isinstance(v, State) else v - self._check_value(v) + self._check_value_tree(v) # write the value by the stack (>= level) trace: StateTrace for trace in thread_local_stack.stack[self._level:]: @@ -149,9 +149,9 @@ def value(self, v) -> None: # set the value self._value = v - def _check_value(self, v): + def _check_value_tree(self, v): if self._check_tree or _global_context_to_check_state_tree[-1]: - in_tree = jax.tree_util.tree_structure(v) + in_tree = jax.tree.structure(v) if in_tree != self._tree: self._raise_error_with_source_info( ValueError(f'The given value {in_tree} does not ' @@ -370,12 +370,13 @@ def write_its_value(self, state: State) -> None: self.types[index] = 'write' self._written_ids.add(id_) - def collect_values(self, *categories: str) -> Tuple: + def collect_values(self, *categories: str, check_val_tree: bool = False) -> Tuple: """ Collect the values by the given categories. Args: *categories: The categories. + check_val_tree: Whether to check the tree structure of the value. Returns: results: The values. @@ -383,7 +384,10 @@ def collect_values(self, *categories: str) -> Tuple: results = [] for st, ty in zip(self.states, self.types): if ty in categories: - results.append(st.value) + val = st.value + if check_val_tree: + st._check_value_tree(val) + results.append(val) return tuple(results) def recovery_original_values(self) -> None: