diff --git a/brainstate/_module.py b/brainstate/_module.py
index 82d5ce9..65f1ecc 100644
--- a/brainstate/_module.py
+++ b/brainstate/_module.py
@@ -1216,7 +1216,7 @@ def retrieve_at_step(self, delay_step, *indices) -> PyTree:
     if environ.get(environ.JIT_ERROR_CHECK, False):
       def _check_delay(delay_len):
         raise ValueError(f'The request delay length should be less than the '
-                         f'maximum delay {self.max_length}. But we got {delay_len}')
+                         f'maximum delay {self.max_length - 1}. But we got {delay_len}')

       jit_error(delay_step >= self.max_length, _check_delay, delay_step)

@@ -1264,8 +1264,7 @@ def retrieve_at_time(self, delay_time, *indices) -> PyTree:
     dt = environ.get_dt()

     if environ.get(environ.JIT_ERROR_CHECK, False):
-      def _check_delay(args):
-        t_now, t_delay = args
+      def _check_delay(t_now, t_delay):
         raise ValueError(f'The request delay time should be within '
                          f'[{t_now - self.max_time - dt}, {t_now}], '
                          f'but we got {t_delay}')
@@ -1273,7 +1272,7 @@ def _check_delay(args):
       jit_error(jnp.logical_or(delay_time > current_time,
                                delay_time < current_time - self.max_time - dt),
                 _check_delay,
-                (current_time, delay_time))
+                current_time, delay_time)

     diff = current_time - delay_time
     float_time_step = diff / dt
diff --git a/brainstate/_module_test.py b/brainstate/_module_test.py
index d78a53c..41c7f46 100644
--- a/brainstate/_module_test.py
+++ b/brainstate/_module_test.py
@@ -86,15 +86,15 @@ def test_jit_erro(self):
     rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
     rotation_delay.init_state()
-    with bst.environ.context(i=0, t=0):
-      rotation_delay.retrieve_at_time(-2.0)
-      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
-        rotation_delay.retrieve_at_time(-2.1)
-      rotation_delay.retrieve_at_time(-2.01)
-      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
-        rotation_delay.retrieve_at_time(-2.09)
-      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
-        rotation_delay.retrieve_at_time(0.1)
-      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
-        rotation_delay.retrieve_at_time(0.01)
+
+    with bst.environ.context(i=0, t=0, jit_error_check=True):
+      # Delay times within [t - max_time - dt, t] are valid and must not raise.
+      rotation_delay.retrieve_at_time(-2.0)
+      rotation_delay.retrieve_at_time(-2.01)
+      rotation_delay.retrieve_at_time(-2.09)
+      # Delay times after the current time are out of range and must raise.
+      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+        rotation_delay.retrieve_at_time(0.1)
+      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+        rotation_delay.retrieve_at_time(0.01)

   def test_round_interp(self):
diff --git a/brainstate/transform/_jit_error.py b/brainstate/transform/_jit_error.py
index 4cb0b43..f47db71 100644
--- a/brainstate/transform/_jit_error.py
+++ b/brainstate/transform/_jit_error.py
@@ -15,14 +15,13 @@
 from __future__ import annotations

-from functools import wraps, partial
+from functools import partial
 from typing import Callable, Union

 import jax
 from jax import numpy as jnp
 from jax.core import Primitive, ShapedArray
-from jax.interpreters import batching, mlir, xla
-from jax.lax import cond
+from jax.interpreters import batching, mlir

 from brainstate._utils import set_module_as

@@ -32,16 +31,41 @@
 @set_module_as('brainstate.transform')
-def remove_vmap(x, op='any'):
+def remove_vmap(x, op: str = 'any'):
   if op == 'any':
     return _any_without_vmap(x)
   elif op == 'all':
     return _all_without_vmap(x)
+  elif op == 'none':
+    return _without_vmap(x)
   else:
     raise ValueError(f'Do not support type: {op}')


-_any_no_vmap_prim = Primitive('any_no_vmap')
+def _without_vmap(x):
+  # Identity primitive whose batching rule marks the result as not mapped,
+  # so a value passes through `vmap` unchanged (used for error payloads).
+  return _no_vmap_prim.bind(x)
+
+
+def _without_vmap_imp(x):
+  return x
+
+
+def _without_vmap_abs(x):
+  return x
+
+
+def _without_vmap_batch(x, batch_axes):
+  (x,) = x
+  return _without_vmap(x), batching.not_mapped
+
+
+_no_vmap_prim = Primitive('no_vmap')
+_no_vmap_prim.def_impl(_without_vmap_imp)
+_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
+batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
+mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))


 def _any_without_vmap(x):
@@ -61,13 +85,12 @@ def _any_without_vmap_batch(x, batch_axes):
   return _any_without_vmap(x), batching.not_mapped


+_any_no_vmap_prim = Primitive('any_no_vmap')
 _any_no_vmap_prim.def_impl(_any_without_vmap_imp)
 _any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
 batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
 mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))

-_all_no_vmap_prim = Primitive('all_no_vmap')
-

 def _all_without_vmap(x):
   return _all_no_vmap_prim.bind(x)
@@ -86,47 +109,36 @@ def _all_without_vmap_batch(x, batch_axes):
   return _all_without_vmap(x), batching.not_mapped


+_all_no_vmap_prim = Primitive('all_no_vmap')
 _all_no_vmap_prim.def_impl(_all_without_vmap_imp)
 _all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
 batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
-if hasattr(xla, "lower_fun"):
-  xla.register_translation(_all_no_vmap_prim,
-                           xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
 mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))


-def _err_jit_true_branch(err_fun, x):
-  jax.pure_callback(err_fun, None, x)
+def _err_jit_true_branch(err_fun, args, kwargs):
+  # `jax.debug.callback` runs the (impure, raising) error function on the host.
+  jax.debug.callback(err_fun, *args, **kwargs)


-def _err_jit_false_branch(x):
+def _err_jit_false_branch(args, kwargs):
   pass


-def _cond(err_fun, pred, err_arg):
-  @wraps(err_fun)
-  def true_err_fun(*arg):
-    err_fun(*arg)
-
-  cond(pred,
-       partial(_err_jit_true_branch, true_err_fun),
-       _err_jit_false_branch,
-       err_arg)
-
-
-def _error_msg(msg, *arg):
-  if len(arg) == 0:
-    raise ValueError(msg)
-  else:
-    raise ValueError(msg.format(arg))
+def _error_msg(msg, *args, **kwargs):
+  if args:
+    msg = msg % args
+  if kwargs:
+    msg = msg.format(**kwargs)
+  raise ValueError(msg)


 @set_module_as('brainstate.transform')
 def jit_error(
     pred,
     err_fun: Union[Callable, str],
-    err_arg=None,
-    scope: str = 'any'
+    *err_args,
+    **err_kwargs,
 ):
   """
   Check errors in a jit function.
@@ -136,15 +148,15 @@
   It can give a function which receive arguments that passed from the JIT
   variables and raise errors.

-  >>> def error(arg):
-  >>>    raise ValueError(f'error {arg}')
+  >>> def error(x):
+  >>>    raise ValueError(f'error {x}')
   >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
-  >>> jit_error(x.sum() < 5., error, err_arg=x)
+  >>> jit_error(x.sum() < 5., error, x)

   Or, it can be a simple string message.

   >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
-  >>> jit_error(x.sum() < 5., "Error: the sum is less than 5.")
+  >>> jit_error(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())

@@ -153,19 +165,20 @@
   Parameters
   ----------
   pred: bool, Array
     The boolean prediction.
   err_fun: callable
     The error function, which raise errors.
-  err_arg: any
-    The arguments which passed into `err_f`.
-  scope: str
-    The scope of the error message. Can be None, 'all' or 'any'.
+  err_args:
+    The positional arguments passed into `err_fun`.
+  err_kwargs:
+    The keyword arguments passed into `err_fun`.
   """
   if isinstance(err_fun, str):
     err_fun = partial(_error_msg, err_fun)
-  if scope is None:
-    pred = pred
-  elif scope == 'all':
-    pred = remove_vmap(pred, 'all')
-  elif scope == 'any':
-    pred = remove_vmap(pred, 'any')
-  else:
-    raise ValueError(f"Unknown scope: {scope}")
-  _cond(err_fun, pred, err_arg)
+
+  jax.lax.cond(
+    # Collapse any vmap batching on the predicate: fire if ANY element is True.
+    remove_vmap(pred, op='any'),
+    partial(_err_jit_true_branch, err_fun),
+    _err_jit_false_branch,
+    # Pass payloads through vmap unbatched so the callback sees the raw values.
+    jax.tree.map(partial(remove_vmap, op='none'), err_args),
+    jax.tree.map(partial(remove_vmap, op='none'), err_kwargs),
+  )
diff --git a/brainstate/transform/_jit_error_test.py b/brainstate/transform/_jit_error_test.py
new file mode 100644
index 0000000..b554d4d
--- /dev/null
+++ b/brainstate/transform/_jit_error_test.py
@@ -0,0 +1,55 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import unittest
+
+import jax
+import jax.numpy as jnp
+import jaxlib.xla_extension
+
+import brainstate as bst
+
+
+class TestJitError(unittest.TestCase):
+  def test1(self):
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      bst.transform.jit_error(True, 'error')
+
+    def err_f(x):
+      raise ValueError(f'error: {x}')
+
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      bst.transform.jit_error(True, err_f, 1.)
+
+  def test_vmap(self):
+
+    def f(x):
+      bst.transform.jit_error(x, 'error: {x}', x=x)
+
+    jax.vmap(f)(jnp.array([False, False, False]))
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      jax.vmap(f)(jnp.array([True, False, False]))
+
+  def test_vmap_vmap(self):
+
+    def f(x):
+      bst.transform.jit_error(x, 'error: {x}', x=x)
+
+    jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
+                                     [False, False, False]]))
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
+                                       [True, False, False]]))
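
Usage sketch: with this patch, `jit_error` takes the predicate, an error callback or message string, and then arbitrary positional/keyword payloads that are forwarded to the error function through `jax.debug.callback`. A minimal sketch of the new API follows, mirroring the docstring and tests above; `safe_log` is a hypothetical helper, not part of the patch, and the values are illustrative.

    import jax
    import jax.numpy as jnp

    import brainstate as bst


    @jax.jit
    def safe_log(x):
      # String form: keyword payloads are substituted into the message
      # via str.format inside the patched `_error_msg`.
      bst.transform.jit_error(x <= 0., 'log() requires x > 0, got {v}', v=x)
      return jnp.log(x)


    safe_log(2.0)                                # passes the check
    jax.vmap(safe_log)(jnp.array([1., 2., 3.]))  # the predicate is collapsed with
                                                 # remove_vmap(pred, op='any'), so the
                                                 # check fires only if ANY element fails

Note that the error surfaces at runtime as an `XlaRuntimeError` wrapping the `ValueError`, which is why the tests above assert `jaxlib.xla_extension.XlaRuntimeError` rather than `ValueError`.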