diff --git a/chex/_src/fake.py b/chex/_src/fake.py index b9c3f9b..a6fdc74 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -88,11 +88,6 @@ def convert_to_varargs(sig, *args, **kwargs): return bound_args.args -@functools.wraps(jax.jit) -def _fake_jit(fn, *unused_args, **unused_kwargs): - return fn - - def _ignore_axis_index_groups(fn): """Wrapper that forces axis_index_groups to be None. @@ -256,23 +251,7 @@ def foo(x): such as `jax.lax.scan`, etc. """ stack = FakeContext() - if enable_patching: - stack.enter_context(mock.patch('jax.jit', _fake_jit)) - - # Some functions like jax.lax.scan also internally use jit. Most respect - # the config setting `jax_disable_jit` and replace its implementation - # with a dummy, jit-free one if the setting is one. Use this mechanism too. - @contextlib.contextmanager - def _jax_disable_jit(): - original_value = jax.config.jax_disable_jit - jax.config.update('jax_disable_jit', True) - try: - yield - finally: - jax.config.update('jax_disable_jit', original_value) - - stack.enter_context(_jax_disable_jit()) - + stack.enter_context(jax.disable_jit(disable=enable_patching)) return stack