
How to detect recompilation reason? #4274

Closed
cgarciae opened this issue Sep 13, 2020 · 12 comments
Labels: enhancement (New feature or request), question (Questions for the JAX team)

Comments

@cgarciae
Collaborator

cgarciae commented Sep 13, 2020

Hey! Is there a way to make jit explain why it's recompiling? I have a complex function with multiple static arguments and deeply structured pytrees, and I can't figure out what is forcing recompilation :s

If not, is there a way to use something in tree_utils to mimic JAX's notion of equality between arguments?

@cgarciae
Collaborator Author

It would also be super useful if one could count the number of times a function has been compiled for testing purposes, or even inspect the "signature" of each compiled function.

@mattjj
Collaborator

mattjj commented Sep 13, 2020

cc @froystig

We absolutely need tooling like this. There isn't much now, though you can set JAX_LOG_COMPILES=1 to get some basic logging on every recompile.
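For example (a minimal sketch; JAX_LOG_COMPILES comes from the sentence above, and the equivalent Python-level config toggle is shown as I understand the current config API):

```python
import jax

# Equivalent to setting JAX_LOG_COMPILES=1 in the environment:
jax.config.update("jax_log_compiles", True)

@jax.jit
def f(x):
    return x * 2

f(1.0)   # logs a compile message on this first call
f(2.0)   # same shape/dtype: cache hit, no log
f(1)     # new dtype -> recompile, logged again
```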

As for why it's recompiling, this could be #2813 (see also #3701, #3708, #3712, the latter being a first step in fixing it), which is that if you pass unhashable values to static_argnums arguments then JAX will silently hash (and hence cache) on object ID. That could be as simple as passing a list (unhashable) instead of a tuple (hashable) as a static argument, with the result being that you never get compilation cache hits.

In any case, we need to make recompilations (and, as you say, the reason for them) transparent.

> If not, is there a way where I could use something in tree_utils to mimic jax's notion of equality between arguments?

Can you say more about this?

@mattjj mattjj added the question Questions for the JAX team label Sep 13, 2020
@cgarciae
Collaborator Author

Thanks a lot @mattjj !

> Can you say more about this?

I was thinking of putting all the arguments into a tuple on each call and maybe using something like tree_multimap to compare the pytree-tuples between calls; this is just a rough idea. I guess I would have to special-case the static arguments. My interest would be finding out which arguments force a recompile.

That said, I do have some non-hashable (lists and dicts) static arguments which might be the culprits.
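A rough sketch of the comparison idea above (all helper names here are hypothetical, not a JAX API): record each call's pytree structure and per-leaf (shape, dtype) signature, then diff it against the previous call's.

```python
import jax
import jax.numpy as jnp

def call_signature(args):
    """Pytree structure plus per-leaf (shape, dtype): roughly what jit keys on."""
    leaves, treedef = jax.tree_util.tree_flatten(args)
    return treedef, [(jnp.shape(l), jnp.result_type(l)) for l in leaves]

def explain_diff(prev, curr):
    """Return a human-readable hint about why two call signatures differ."""
    if prev[0] != curr[0]:
        return f"pytree structure changed: {prev[0]} -> {curr[0]}"
    changed = [i for i, (a, b) in enumerate(zip(prev[1], curr[1])) if a != b]
    return f"leaves changed at positions {changed}" if changed else "no difference"

sig1 = call_signature((jnp.ones((2, 3)), {"lr": 0.1}))
sig2 = call_signature((jnp.ones((2, 4)), {"lr": 0.1}))
print(explain_diff(sig1, sig2))  # the first leaf's shape changed
```

Note that this does not catch everything jit keys on (e.g. weak types, discussed below in this thread, don't show up in result_type), but it covers structure, shape, and dtype changes.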

@cgarciae
Collaborator Author

I was looking into make_jaxpr, which seems to give a hint of what the signature of each function looks like.

@mattjj
Collaborator

mattjj commented Sep 13, 2020

Just a brainstorming thought: one way to "explain why" we got a recompile might be to compute a kind of minimal edit distance to existing cache entries (maybe we could even use the standard library's difflib).
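A toy illustration of that brainstorm (the rendered signature strings below are made up for illustration, not real cache keys): render each cache entry as lines of text and diff a new call against it with difflib.

```python
import difflib

# Hypothetical rendered cache keys: one line per argument's abstract value.
cached = ["arg0: f32[32,128]", "arg1: f32[128]", "static lr=0.1"]
incoming = ["arg0: f32[32,128]", "arg1: f32[128]", "static lr=0.2"]

diff = list(difflib.unified_diff(cached, incoming,
                                 "cached entry", "new call", lineterm=""))
print("\n".join(diff))  # the diff pinpoints the static arg that changed
```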

@cgarciae
Collaborator Author

@mattjj may I ask, what are the cache entries? Where are they stored?

I fixed my bug (I think); it might have been multiple things. However, users can face this kind of issue very easily, so it would be super useful to have a tool for this kind of thing, which is very central to JAX. I can give it a try, but I don't know where to start.

@cgarciae
Collaborator Author

cgarciae commented Sep 14, 2020

BTW: one part of the bug was related to Python scalars being promoted to JAX arrays when they become outputs of a jitted function. I guess it's the user's fault, but a warning or even an error for this kind of behavior might be even better. I believe this is already covered in one of the issues you linked at the beginning.

@mattjj
Collaborator

mattjj commented Sep 14, 2020

There are a few different compilation caches: one for jit, one for pmap, and one for op-by-op primitives. They're all just implemented as memoization decorators.

For jit, the function _xla_callable in xla.py is decorated with the linear_util.py cache decorator. The fact that it's just memoization makes it very easy to reason about! (But it's not quite functionally pure memoization; there's a special kind of side effect related to the stores you see in that code. We manage it carefully based on the constraint that, in JAX internals, transformed functions are executed exactly once ever. That's where the "linear" in linear_util.py comes from, as in linear types or logic.) The cache is in principle just a dict. The keys are the arguments to _xla_callable (basically a transformed function and some abstract values, where the transformed function's hash reflects all the transformations that have been applied to it), and the values are the return values of _xla_callable, namely partially-applied _execute_compiled functions. These package together an XlaExecutable with argument and result handlers, which unbox argument DeviceArrays into raw buffers and package raw result buffers back into DeviceArrays, respectively.

The pmap compilation cache is the same as the jit one, just in pxla.py.

The op-by-op compilation cache is just a functools.lru_cache applied to a roughly analogous function (taking a primitive and abstract args, and returning a partially-applied _execute_compiled_primitive). It doesn't have any of that stores stuff or transformations to worry about.

@mattjj mattjj added the enhancement New feature or request label Sep 14, 2020
@mattjj
Collaborator

mattjj commented Sep 14, 2020

All the above implementation details are subject to change, but this is how things look today!

@jekbradbury
Contributor

jekbradbury commented Sep 15, 2020

I think one of the things you were running into has to do with JAX "weak types". We don't want any use of Python scalars (where 0.0 is nominally float64 and 0 is nominally an arbitrary-precision int) to promote JAX values to higher-than-expected precision, so we treat them as having "undecided precision" until they come into contact with a JAX value. Thus, e.g.,

import jax
import numpy as np

x = 0.
f = jax.jit(lambda x: x + np.float32(0.))
y = f(x)
z = f(y)

will compile f twice: once with weak type on x and once with strong type on y. This can be a bit of a gotcha for training loops, and the solution (for now) is to initialize scalars (e.g. learning rates) used as arguments to jitted functions as NumPy or jnp scalars with an explicit type (e.g. x = np.float32(0.)).
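The double compile can be observed with a trace counter (a hedged sketch: the counter is an ordinary Python side effect, so it only runs when jit actually retraces):

```python
import numpy as np
import jax

n_traces = 0

def body(x):
    global n_traces
    n_traces += 1  # executes only when jit retraces
    return x + np.float32(0.)

f = jax.jit(body)
y = f(0.)   # Python float input: weakly-typed float32 -> trace #1
z = f(y)    # output is a strongly-typed float32 -> trace #2
assert n_traces == 2
f(z)        # strong float32 again -> cache hit, no further retrace
assert n_traces == 2
```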
(CC @jakevdp, who's been thinking about weak_type recently)

@froystig froystig self-assigned this Nov 2, 2020
@cgarciae cgarciae closed this as completed Apr 8, 2022
@bycn

bycn commented Nov 3, 2022

Not sure if there is a main issue for this, but logging a +1 for the gotcha @jekbradbury described. I was confused about why my method was recompiling during testing: it was because jnp.ones() defaults to "float32" while jnp.arange() defaults to "int32". A warning message here would perhaps be helpful, although I'm not sure if this is generalizable :)

@jakevdp
Collaborator

jakevdp commented Nov 3, 2022

FYI, jnp.arange defaults to the type of the argument you pass it, so jnp.arange(10) has int dtype, while jnp.arange(10.0) has float dtype. I agree it might be somewhat surprising, but it's a behavior inherited from numpy. For what it's worth, in JAX's native version of range, lax.iota, the dtype argument must be specified explicitly.
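For instance (assuming default 32-bit mode, i.e. jax_enable_x64 off):

```python
import jax.numpy as jnp
from jax import lax

print(jnp.arange(10).dtype)    # int: dtype follows the argument's type
print(jnp.arange(10.0).dtype)  # float: dtype follows the argument's type
print(lax.iota(jnp.int32, 5))  # lax.iota requires an explicit dtype
```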
