From 13f93566bfb1bfeff51def41887b4080e7421880 Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Tue, 31 Oct 2023 05:12:27 -0700 Subject: [PATCH] Improve usability of `chexify`. PiperOrigin-RevId: 578145061 --- chex/_src/asserts_chexify.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/chex/_src/asserts_chexify.py b/chex/_src/asserts_chexify.py index ff9565f5..8ed81084 100644 --- a/chex/_src/asserts_chexify.py +++ b/chex/_src/asserts_chexify.py @@ -189,6 +189,16 @@ def _chexified_fn(*args, **kwargs): 'Nested @chexify wrapping is disallowed. ' 'Make sure that you only wrap the function at the outermost level.') + if _ai.has_tracers((args, kwargs)): + raise RuntimeError( + '@chexify must be applied on top of all (p)jit/pmap transformations' + ' (otherwise it will result in `UnexpectedTracerError`). If you have' + ' functions that use value assertions, do not wrap them' + ' individually -- just wrap the outermost function after' + ' applying all your JAX transformations. See the example at' + 'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions' + ) + if async_check: # Check completed calls. while async_check_futures and async_check_futures[0].done():