
How to merge equinox pretty-printing and jax debug printing? #644

Open
LouisDesdoigts opened this issue Jan 22, 2024 · 2 comments

Labels: feature (New feature), question (User queries)

Comments

@LouisDesdoigts
Contributor

LouisDesdoigts commented Jan 22, 2024

So I'm trying to do two things here:

  1. Raise an error through the JIT boundary using eqx.error_if.
  2. Print a pytree using the jax.debug.print function.

I'm looking to do this because some NaNs are arising during training that are difficult to isolate. I can't use the usual jax_debug_nans flag, as some of my data naturally contains NaNs.

At the moment I create a boolean pytree recording whether each leaf contains any NaNs. Ideally I would print the actual boolean values rather than the usual Traced<ShapedArray(bool[])>, which is what jax.debug.print would let me do.

Extending your error_if example:

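```python
import equinox as eqx
import jax
import jax.numpy as jnp


@eqx.filter_jit
def f(x):
    # One boolean per leaf: does that leaf contain a NaN?
    bool_tree = jax.tree_util.tree_map(lambda leaf: jnp.isnan(leaf).any(), x)
    vals = jnp.array(jax.tree_util.tree_leaves(bool_tree))
    # The message is built at trace time, so the leaves appear as tracers.
    msg = "NaN found in tree:\n" + eqx.tree_pformat(bool_tree, short_arrays=False)
    x = eqx.error_if(x, vals.any(), msg)
    return x


pytree = (jnp.zeros(3), jnp.zeros(5))
_ = f(pytree)  # no NaNs: returns normally

nan_pytree = eqx.tree_at(lambda t: t[0], pytree, jnp.zeros(3).at[2].set(jnp.nan))
_ = f(nan_pytree)  # raises, but the message shows tracers rather than values
```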

However, this doesn't print the actual array values, hence the desire for some way to combine it with jax.debug.print. Is there a simple way to achieve this?
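Why the message shows tracers: `msg` is an ordinary Python string assembled while the function is being traced, before any concrete data exists, so only abstract placeholders are available to format. A minimal sketch of this using plain `jax.jit` (the names `g` and `seen` are illustrative, not from the thread):

```python
import jax
import jax.numpy as jnp

seen = []  # appended to while tracing, not on every call


@jax.jit
def g(x):
    flag = jnp.isnan(x).any()
    # During tracing `flag` has no concrete value, so bool() fails.
    try:
        bool(flag)
        seen.append("concrete")
    except jax.errors.ConcretizationTypeError:
        seen.append("abstract")
    # A string formatted here can only contain the tracer's repr.
    seen.append(f"NaN flag: {flag}")
    return x


g(jnp.zeros(3))
g(jnp.zeros(3))  # same shape and dtype: cached, so the body is not re-traced
print(seen[0])  # abstract
print(len(seen))  # 2
```

Anything that needs runtime values therefore has to go through a runtime mechanism such as `jax.debug.print`, rather than through a string built in the function body.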

Thanks in advance!

@patrick-kidger
Owner

So I think we have a couple of options.

First of all, if you just want to debug something and want to use jax.debug.print with pretty printing, then this can be done using the following trick:

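```python
import equinox as eqx
import jax


class PrettyFormat(eqx.Module):
    # Subclassing eqx.Module makes this a pytree, so jax.debug.print swaps the
    # traced leaves of `obj` for their runtime values before calling __repr__.
    obj: object

    def __repr__(self):
        return eqx.tree_pformat(self.obj, short_arrays=False)


# Inside a jitted function, e.g. `f` from the question:
jax.debug.print("{}", PrettyFormat(bool_tree))
```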

This defers the pretty-formatting until the repr is called inside jax.debug.print.

If you actually want to include this information inside an eqx.error_if call, however, then Equinox doesn't currently support that. It should be possible to add, though: we could imagine adding an eqx.error_if(..., format_with=...) argument that is used to format the string before printing. I'd be happy to take a pull request on this.

patrick-kidger added the feature (New feature) and question (User queries) labels on Jan 23, 2024
@LouisDesdoigts
Contributor Author

Ah, that is a very nice solution, awesome! Yeah, I think having that extra functionality would be great. I'm happy to build this in when I can find some time in the next few weeks.
