So I'm trying to do two things here:
1. Raise an error through the jit boundary using eqx.error_if.
2. Print a pytree using the jax.debug.print function.
I'm looking to do this because some NaNs are arising during my training that are difficult to isolate. I can't use the usual JAX debug_nans flag because some of my data naturally contains NaNs.
At the moment I'm creating a boolean pytree that flags whether each leaf contains any NaNs. Ideally I would print the actual boolean values rather than the usual Traced<ShapedArray(bool[])>, which is what jax.debug.print would be for.
However, this doesn't print the actual array values, hence the desire for some interface with jax.debug.print. Is there a simple way to achieve this?
Thanks in advance!
First of all, if you just want to debug something and use jax.debug.print with pretty-printing, this can be done with a small trick.
This defers the pretty-formatting until the repr is called inside jax.debug.print.
If you actually want to include this information inside an eqx.error_if call, however, Equinox doesn't support that right now. It should be possible to add, though: we could imagine adding an eqx.error_if(..., format_with=...) argument that is used to format the string before printing. I'd be happy to take a pull request for this.
Ah, that is a very nice solution, awesome! Yes, I think having that extra functionality would be great. I'm happy to build this in when I can find some time in the next few weeks.