How to detect recompilation reason? #4274
It would also be super useful if one could count the number of times a function has been compiled, for testing purposes, or even inspect the "signature" of each compiled function.
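Until better tooling exists, one rough way to count compilations in a test is to increment a Python counter inside the jitted function. Python side effects run only while JAX traces the function, and `jit` retraces on a cache miss, so the counter approximates the number of compilations. A minimal sketch; `trace_count` and `double` are illustrative names, not JAX APIs:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative test counter, not a JAX API


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces `double`
    return x * 2


double(jnp.ones(3))  # first call: trace + compile
double(jnp.ones(3))  # same shape and dtype: cache hit, no retrace
double(jnp.ones(4))  # new shape: cache miss, retrace
print(trace_count)   # 2
```

Printing `x` inside the function body at the same point also shows the shape and dtype of the call that triggered each retrace.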
cc @froystig We absolutely need tooling like this. There isn't much now, though you can set … As for why it's recompiling, this could be #2813 (see also #3701, #3708, #3712, the latter being a first step in fixing it), which is that if you pass unhashable values to … In any case, we need to make recompilations (and, as you say, the reason for them) transparent.
Can you say more about this?
Thanks a lot, @mattjj!
I was thinking of putting all the arguments into a tuple on each call and maybe using something like … That said, I do have some unhashable static arguments (lists and dicts), which might be the culprits.
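If the lists and dicts passed as static arguments are the culprits, one workaround is to convert them into hashable, value-compared equivalents before passing them to a function jitted with `static_argnums`. A minimal sketch with a hypothetical helper `to_hashable` (not part of JAX), assuming dict keys are mutually sortable, e.g. all strings:

```python
from typing import Any, Hashable


def to_hashable(value: Any) -> Hashable:
    """Recursively turn lists/tuples into tuples and dicts into sorted item tuples."""
    if isinstance(value, dict):
        # Sort by key so {"a": 1, "b": 2} and {"b": 2, "a": 1} give the same result.
        return tuple(sorted((k, to_hashable(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return tuple(to_hashable(v) for v in value)
    return value  # assumed to be hashable already (str, int, float, ...)


config = {"layers": [64, 32], "activation": "relu"}
key = to_hashable(config)
print(key)        # (('activation', 'relu'), ('layers', (64, 32)))
print(hash(key) == hash(to_hashable({"activation": "relu", "layers": [64, 32]})))  # True
```

Equal configs now produce equal, hashable keys regardless of dict insertion order. The conversion does drop the distinction between a dict and a list of key/value pairs; tag the tuples if that matters.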
I was looking into …
Just a brainstorming thought: one way to "explain why" we got a recompile might be to compute a kind of minimal edit distance to existing cache entries (maybe we could even use the standard library's difflib).
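The difflib idea might look roughly like this, assuming each cache key can be rendered as text with one line per argument. The key format and `explain_cache_miss` are hypothetical; JAX's real cache keys are internal objects:

```python
import difflib
from typing import Sequence


def explain_cache_miss(new_key: str, cached_keys: Sequence[str]) -> str:
    """Diff a new cache key against the most similar cached key (hypothetical helper)."""
    if not cached_keys:
        return "cache is empty"
    # SequenceMatcher.ratio() returns a similarity score in [0, 1].
    closest = max(
        cached_keys,
        key=lambda k: difflib.SequenceMatcher(None, k, new_key).ratio(),
    )
    # Line-by-line diff; keep only removed ("- ") and added ("+ ") lines.
    diff = difflib.ndiff(closest.splitlines(), new_key.splitlines())
    return "\n".join(line for line in diff if line.startswith(("- ", "+ ")))


# Hypothetical keys: one line per argument, e.g. "name: dtype[shape]".
cached = [
    "x: float32[3]\ny: int32[]\nmode: 'train'",
    "x: float32[3]\ny: int32[]\nmode: 'eval'",
]
print(explain_cache_miss("x: float32[4]\ny: int32[]\nmode: 'train'", cached))
# - x: float32[3]
# + x: float32[4]
```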
@mattjj, may I ask what the cache entries are and where they are stored? I think I fixed my bug, though it may have been several things. Still, users can run into this kind of issue very easily, so a tool for it would be super useful; it is very central to JAX. I can give it a try, but I don't know where to start.
BTW, part of the problem was Python scalars being promoted to JAX arrays when they become outputs of a jitted function. I guess that's the user's fault, but I wonder whether a warning, or even an error, for this behavior would be better. I believe this is already covered in one of the issues you linked at the beginning.
There are a few different compilation caches: one for … For … The … The op-by-op compilation cache is just a …
All the above implementation details are subject to change, but this is how things look today! |
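I think one of the things you were running into has to do with JAX "weak types". We don't want any use of Python scalars (where …)

```python
import jax
import numpy as np

x = 0.
f = jax.jit(lambda x: x + np.float32(0.))
y = f(x)
z = f(y)
```

will compile `f` twice: `x` is a weakly typed Python scalar, while `y` is a strongly typed float32 array, so the two calls have different signatures.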
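To confirm the extra compilation in a test, the same example can count traces with a Python side effect, which runs only when `jit` traces `f`. A sketch; `weak_traces` is an illustrative name:

```python
import jax
import numpy as np

weak_traces = 0  # illustrative counter, not a JAX API


@jax.jit
def f(x):
    global weak_traces
    weak_traces += 1  # Python side effect: runs only when JAX (re)traces f
    return x + np.float32(0.)


y = f(0.)  # Python float: traced as a weakly typed float32 scalar
z = f(y)   # y is a strongly typed float32 array, so f is traced again
print(weak_traces)  # 2
```

Passing `np.float32(0.)` instead of the Python float `0.` on the first call should give both calls the same strongly typed signature.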
Not sure if there is a main issue for this, but +1 for the gotcha @jekbradbury described. I was confused about why my method was recompiling during testing: it was because jnp.ones() returns "float32" arrays and jnp.arange() returns "int32" arrays. A warning message here would perhaps be helpful, although I'm not sure this is generalizable :)
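The dtype gotcha can be reproduced the same way. A sketch assuming the default configuration (`jax_enable_x64` off, so `jnp.ones` gives float32 and `jnp.arange` gives int32); casting with `astype` makes the third call match the first. `total` and `total_traces` are illustrative names:

```python
import jax
import jax.numpy as jnp

total_traces = 0  # illustrative counter, not a JAX API


@jax.jit
def total(x):
    global total_traces
    total_traces += 1  # runs only when JAX (re)traces `total`
    return x.sum()


total(jnp.ones(3))                        # float32[3]: first trace
total(jnp.arange(3))                      # int32[3]: different dtype, traced again
total(jnp.arange(3).astype(jnp.float32))  # float32[3]: matches the first call
print(total_traces)  # 2
```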
FYI, …
Hey! Is there a way to make `jit` explain why it's recompiling? I have a complex function with multiple static arguments and deeply structured pytrees, and I can't figure out what is forcing recompilation :s

If not, is there a way I could use something in `jax.tree_util` to mimic JAX's notion of equality between arguments?
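On the `jax.tree_util` question, a rough approximation is to flatten both argument pytrees and compare their structure plus each leaf's shape and dtype. This sketch assumes a recent JAX with `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`; `abstract_signature` and `explain_retrace` are hypothetical helpers, and they ignore weak types, static arguments and device placement, so they can miss some recompilations:

```python
import jax
import jax.numpy as jnp
import numpy as np


def abstract_signature(args):
    """Return the tree structure and a {key path: (shape, dtype)} map of the leaves."""
    path_leaves, treedef = jax.tree_util.tree_flatten_with_path(args)
    specs = {
        jax.tree_util.keystr(path): (np.shape(leaf), jnp.result_type(leaf))
        for path, leaf in path_leaves
    }
    return treedef, specs


def explain_retrace(old_args, new_args):
    """List human-readable differences between two argument pytrees."""
    old_tree, old_specs = abstract_signature(old_args)
    new_tree, new_specs = abstract_signature(new_args)
    if old_tree != new_tree:
        return [f"tree structure changed: {old_tree} -> {new_tree}"]
    # Same structure, so both maps have the same key paths.
    return [
        f"{path}: {old_specs[path]} -> {new_specs[path]}"
        for path in old_specs
        if old_specs[path] != new_specs[path]
    ]


old = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
new = {"w": jnp.ones((2, 4)), "b": jnp.zeros(3)}
for line in explain_retrace(old, new):
    print(line)  # reports that leaf ['w'] changed shape from (2, 3) to (2, 4)
```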