Skip to content

Commit

Permalink
update functions and tests of brainstate.transform.jit_error`
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
1 parent b479ab5 commit 6c0e311
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 59 deletions.
7 changes: 3 additions & 4 deletions brainstate/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ def retrieve_at_step(self, delay_step, *indices) -> PyTree:
if environ.get(environ.JIT_ERROR_CHECK, False):
def _check_delay(delay_len):
raise ValueError(f'The request delay length should be less than the '
f'maximum delay {self.max_length}. But we got {delay_len}')
f'maximum delay {self.max_length - 1}. But we got {delay_len}')

jit_error(delay_step >= self.max_length, _check_delay, delay_step)

Expand Down Expand Up @@ -1264,16 +1264,15 @@ def retrieve_at_time(self, delay_time, *indices) -> PyTree:
dt = environ.get_dt()

if environ.get(environ.JIT_ERROR_CHECK, False):
def _check_delay(args):
t_now, t_delay = args
def _check_delay(t_now, t_delay):
raise ValueError(f'The request delay time should be within '
f'[{t_now - self.max_time - dt}, {t_now}], '
f'but we got {t_delay}')

jit_error(jnp.logical_or(delay_time > current_time,
delay_time < current_time - self.max_time - dt),
_check_delay,
(current_time, delay_time))
current_time, delay_time)

diff = current_time - delay_time
float_time_step = diff / dt
Expand Down
17 changes: 9 additions & 8 deletions brainstate/_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,17 @@ def test_jit_erro(self):
rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
rotation_delay.init_state()

with bst.environ.context(i=0, t=0):
rotation_delay.retrieve_at_time(-2.0)
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
rotation_delay.retrieve_at_time(-2.1)
rotation_delay.retrieve_at_time(-2.01)
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):

with bst.environ.context(i=0, t=0, jit_error_check=True):
# rotation_delay.retrieve_at_time(-2.0)
# with self.assertRaises(ValueError):
# rotation_delay.retrieve_at_time(-2.1)
# rotation_delay.retrieve_at_time(-2.01)
# with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
rotation_delay.retrieve_at_time(-2.09)
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
# with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
rotation_delay.retrieve_at_time(0.1)
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
# with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
rotation_delay.retrieve_at_time(0.01)

def test_round_interp(self):
Expand Down
103 changes: 56 additions & 47 deletions brainstate/transform/_jit_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@

from __future__ import annotations

from functools import wraps, partial
import functools
from functools import partial
from typing import Callable, Union

import jax
from jax import numpy as jnp
from jax.core import Primitive, ShapedArray
from jax.interpreters import batching, mlir, xla
from jax.lax import cond
from jax.interpreters import batching, mlir

from brainstate._utils import set_module_as

Expand All @@ -32,16 +32,39 @@


@set_module_as('brainstate.transform')
def remove_vmap(x, op='any'):
def remove_vmap(x, op: str = 'any'):
if op == 'any':
return _any_without_vmap(x)
elif op == 'all':
return _all_without_vmap(x)
elif op == 'none':
return _without_vmap(x)
else:
raise ValueError(f'Do not support type: {op}')


_any_no_vmap_prim = Primitive('any_no_vmap')
def _without_vmap(x):
return _no_vmap_prim.bind(x)


def _without_vmap_imp(x):
return x


def _without_vmap_abs(x):
return x


def _without_vmap_batch(x, batch_axes):
(x,) = x
return _without_vmap(x), batching.not_mapped


_no_vmap_prim = Primitive('no_vmap')
_no_vmap_prim.def_impl(_without_vmap_imp)
_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))


def _any_without_vmap(x):
Expand All @@ -61,13 +84,12 @@ def _any_without_vmap_batch(x, batch_axes):
return _any_without_vmap(x), batching.not_mapped


_any_no_vmap_prim = Primitive('any_no_vmap')
_any_no_vmap_prim.def_impl(_any_without_vmap_imp)
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))

_all_no_vmap_prim = Primitive('all_no_vmap')


def _all_without_vmap(x):
return _all_no_vmap_prim.bind(x)
Expand All @@ -86,47 +108,35 @@ def _all_without_vmap_batch(x, batch_axes):
return _all_without_vmap(x), batching.not_mapped


_all_no_vmap_prim = Primitive('all_no_vmap')
_all_no_vmap_prim.def_impl(_all_without_vmap_imp)
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
if hasattr(xla, "lower_fun"):
xla.register_translation(_all_no_vmap_prim,
xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))


def _err_jit_true_branch(err_fun, x):
jax.pure_callback(err_fun, None, x)
def _err_jit_true_branch(err_fun, args, kwargs):
jax.debug.callback(err_fun, *args, **kwargs)


def _err_jit_false_branch(x):
def _err_jit_false_branch(args, kwargs):
pass


def _cond(err_fun, pred, err_arg):
@wraps(err_fun)
def true_err_fun(*arg):
err_fun(*arg)

cond(pred,
partial(_err_jit_true_branch, true_err_fun),
_err_jit_false_branch,
err_arg)


def _error_msg(msg, *arg):
if len(arg) == 0:
raise ValueError(msg)
else:
raise ValueError(msg.format(arg))
def _error_msg(msg, *arg, **kwargs):
if len(arg):
msg = msg % arg
if len(kwargs):
msg = msg.format(**kwargs)
raise ValueError(msg)


@set_module_as('brainstate.transform')
def jit_error(
pred,
err_fun: Union[Callable, str],
err_arg=None,
scope: str = 'any'
*err_args,
**err_kwargs,
):
"""
Check errors in a jit function.
Expand All @@ -136,15 +146,15 @@ def jit_error(
It can give a function which receive arguments that passed from the JIT variables and raise errors.
>>> def error(arg):
>>> raise ValueError(f'error {arg}')
>>> def error(x):
>>> raise ValueError(f'error {x}')
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
>>> jit_error(x.sum() < 5., error, err_arg=x)
>>> jit_error(x.sum() < 5., error, x)
Or, it can be a simple string message.
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
>>> jit_error(x.sum() < 5., "Error: the sum is less than 5.")
>>> jit_error(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
Parameters
Expand All @@ -153,19 +163,18 @@ def jit_error(
The boolean prediction.
err_fun: callable
The error function, which raise errors.
err_arg: any
err_args:
The arguments which passed into `err_f`.
scope: str
The scope of the error message. Can be None, 'all' or 'any'.
err_kwargs:
The keywords which passed into `err_f`.
"""
if isinstance(err_fun, str):
err_fun = partial(_error_msg, err_fun)
if scope is None:
pred = pred
elif scope == 'all':
pred = remove_vmap(pred, 'all')
elif scope == 'any':
pred = remove_vmap(pred, 'any')
else:
raise ValueError(f"Unknown scope: {scope}")
_cond(err_fun, pred, err_arg)

jax.lax.cond(
remove_vmap(pred, op='any'),
partial(_err_jit_true_branch, err_fun),
_err_jit_false_branch,
jax.tree.map(functools.partial(remove_vmap, op='none'), err_args),
jax.tree.map(functools.partial(remove_vmap, op='none'), err_kwargs),
)
55 changes: 55 additions & 0 deletions brainstate/transform/_jit_error_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest

import jax
import jaxlib.xla_extension
import jax.numpy as jnp

import brainstate as bst


class TestJitError(unittest.TestCase):
def test1(self):
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
bst.transform.jit_error(True, 'error')

def err_f(x):
raise ValueError(f'error: {x}')

with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
bst.transform.jit_error(True, err_f, 1.)

def test_vmap(self):

def f(x):
bst.transform.jit_error(x, 'error: {x}', x=x)

jax.vmap(f)(jnp.array([False, False, False]))
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
jax.vmap(f)(jnp.array([True, False, False]))

def test_vmap_vmap(self):

def f(x):
bst.transform.jit_error(x, 'error: {x}', x=x)

jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
[False, False, False]]))
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
[True, False, False]]))

0 comments on commit 6c0e311

Please sign in to comment.