diff --git a/chex/_src/asserts_chexify.py b/chex/_src/asserts_chexify.py index ff9565f..8ed8108 100644 --- a/chex/_src/asserts_chexify.py +++ b/chex/_src/asserts_chexify.py @@ -189,6 +189,16 @@ def _chexified_fn(*args, **kwargs): 'Nested @chexify wrapping is disallowed. ' 'Make sure that you only wrap the function at the outermost level.') + if _ai.has_tracers((args, kwargs)): + raise RuntimeError( + '@chexify must be applied on top of all (p)jit/pmap transformations' + ' (otherwise it will result in `UnexpectedTracerError`). If you have' + ' functions that use value assertions, do not wrap them' + ' individually -- just wrap the outermost function after' + ' applying all your JAX transformations. See the example at' + 'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions' + ) + if async_check: # Check completed calls. while async_check_futures and async_check_futures[0].done():