Any way to avoid recompiling of jit'ed sub-functions? #4572
Comments
My understanding is that this isn't possible with JAX/XLA today. I agree it would be really nice to have. In theory it would be possible if we created an XLA "CustomCall" for the inner Arguably this is the main reason why use-cases like |
@shoyer Got it. Another related question about compilation is this: I have a JIT'ed function that I specify with static_argnums. The static arguments are really just constant arrays. Yet I find that the check for whether a recompilation is needed is super strict. Even if I create the exact same array (same in terms of its memory content), it will also trigger a recompilation. Is there a way to work around this, perhaps telling JAX that there's really no need to check a certain static argument because the values in the array are guaranteed to be same? |
@gnool that's definitely a foot-gun that we want to revise. See #2813 and #3712. Basically, for array types like numpy.ndarrays or JAX's DeviceArrays, Example:
|
@mattjj If I understand your explanation correctly, there is no way to make static_argnums work with an unhashable type (arrays), right? Which is why it has to at least return an error. Pardon my poor understanding of how JIT works; I have to admit I'm still a little puzzled as to why recompilation is needed when the array content (and its metadata, except its identity and memory address) is exactly the same. I'm guessing there's a reason why a static argument has to be hashable too. My "workaround" for now would be to avoid declaring the arrays as static. Although the subsequent function call is still fast, it is noticeably slower (perhaps by an order of magnitude) than declaring the arrays as static (and avoiding the recompilation). |
The issue is that we'd have to check that the array content is the same. That is, we could define Plus, if you want that behavior, you can simulate it yourself without needing

class HashableArrayWrapper:
    def __init__(self, val):
        self.val = val
    def __hash__(self):
        return some_hash_function(self.val)  # maybe implement this in jax.numpy to save on transfers
    def __eq__(self, other):
        return isinstance(other, HashableArrayWrapper) and jnp.all(jnp.equal(self.val, other.val))

You could write a wrapper on

def gnool_jit(fun, static_array_argnums=()):
    @jit  # EDIT: forgot to use static_argnums here! see comment below
    def callee(*args):
        args = list(args)
        for i in static_array_argnums:
            args[i] = args[i].val
        return fun(*args)
    def caller(*args):
        args = list(args)
        for i in static_array_argnums:
            args[i] = HashableArrayWrapper(args[i])
        return callee(*args)
    return caller

WDYT? |
@mattjj If implementing my own hash and eq helps bypass the current object identity check, that's definitely something I can explore. I'm running into this error below, any idea?
|
Sorry, I forgot to use

from functools import partial
from jax import jit
import jax.numpy as jnp

def some_hash_function(x):
    # Cheap placeholder hash: arrays with the same sum collide, which is
    # allowed because __eq__ makes the final decision.
    return int(jnp.sum(x))

class HashableArrayWrapper:
    def __init__(self, val):
        self.val = val
    def __hash__(self):
        return some_hash_function(self.val)
    def __eq__(self, other):
        # bool() turns the 0-d array from jnp.all into a plain Python bool.
        return (isinstance(other, HashableArrayWrapper) and
                bool(jnp.all(jnp.equal(self.val, other.val))))

def gnool_jit(fun, static_array_argnums=()):
    @partial(jit, static_argnums=static_array_argnums)
    def callee(*args):
        # Unwrap the static arrays before calling the user's function.
        args = list(args)
        for i in static_array_argnums:
            args[i] = args[i].val
        return fun(*args)
    def caller(*args):
        # Wrap the static arrays so jit sees hashable, value-compared arguments.
        args = list(args)
        for i in static_array_argnums:
            args[i] = HashableArrayWrapper(args[i])
        return callee(*args)
    return caller

###

@partial(gnool_jit, static_array_argnums=(0,))
def f(x):
    print('re-tracing!')
    return x ** 2

x = jnp.array([1., 2., 3.])
f(x)  # first call: traces and compiles
f(x)  # same contents: cache hit
x = jnp.array([1., 2., 3.])
f(x)  # new array object with equal contents: still a cache hit

All we're doing here is making a hashable type, that is, one with __hash__ and __eq__ methods that implement whatever behavior you want. The object identity behavior we were talking about with |
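The sum-based some_hash_function above is only a placeholder: arrays such as [1, 2, 3] and [3, 2, 1] hash to the same value, and __eq__ has to tell them apart. As a rough sketch of a stricter alternative (not something proposed in this thread), the wrapper below hashes and compares the raw bytes together with the shape and dtype. The helper name content_hash is hypothetical, and the sketch assumes the wrapped value can be converted with np.asarray, which copies device arrays to the host.

```python
import hashlib

import numpy as np


def content_hash(x):
    """Hypothetical helper: hash an array by shape, dtype, and raw bytes."""
    a = np.asarray(x)
    digest = hashlib.sha1(a.tobytes()).hexdigest()
    return hash((a.shape, a.dtype.str, digest))


class HashableArrayWrapper:
    """Same idea as the wrapper in the thread, with bit-for-bit comparison."""

    def __init__(self, val):
        self.val = val

    def __hash__(self):
        return content_hash(self.val)

    def __eq__(self, other):
        if not isinstance(other, HashableArrayWrapper):
            return False
        a, b = np.asarray(self.val), np.asarray(other.val)
        # Comparing bytes keeps __eq__ consistent with __hash__
        # (for example, 0.0 and -0.0 are treated as different).
        return (a.shape == b.shape and a.dtype == b.dtype
                and a.tobytes() == b.tobytes())


a = HashableArrayWrapper(np.array([1., 2., 3.]))
b = HashableArrayWrapper(np.array([1., 2., 3.]))  # equal contents, new object
c = HashableArrayWrapper(np.array([3., 2., 1.]))  # same sum, different contents

cache = {a: "compiled"}
```

Here b finds the entry stored under a, while c does not. The trade-off is that every call pays for a host transfer and a hash over the full array, which may matter for large static arrays.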
@mattjj Thanks for the super quick reply, it works now! For my own benefit, in our discussion above we touched on things like cache. For a static argument, where exactly is the cache for the array's data stored? In the past few days I've been reading what others here have discussed regarding pickling a JIT compiled function (to save compilation time), and so far the answers have been indicating that this is unsupported by JAX. Out of curiosity I have used cloudpickle to pickle a JIT compiled function (it pickled without error), dump it, load it back again, and unsurprisingly it behaves like an uncompiled function (i.e. needs to warm up again). Is this because cache was not being properly stored during the pickling, or perhaps some other more complicated reasons? |
I believe these lines are where you can find JAX's backend compilation cache: It's stored in an LRU cache on |
Glad to hear it works!
The compilation cache is really just this one memoization decorator, defined here and applied here. Argument values that correspond to Here's a little spelunking to show where the array lives in the above example:
However, those implementation details are all subject to change, even in the very near future, so I wouldn't build anything against them. |
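To make the memoization idea concrete, here is a toy sketch using only the standard library. It is not JAX's actual cache: fake_compile, ByIdentity, and ByValue are hypothetical names, and functools.lru_cache stands in for the memoization decorator. It shows why a key that hashes by object identity misses on every fresh object, while a key that hashes by value hits.

```python
from functools import lru_cache

compile_count = 0  # how many times the "compiler" actually ran


@lru_cache(maxsize=None)
def fake_compile(static_arg):
    """Stand-in for an expensive compilation keyed on a static argument."""
    global compile_count
    compile_count += 1
    return f"executable-{compile_count}"


class ByIdentity:
    """Default Python behavior: hash and equality follow object identity."""

    def __init__(self, data):
        self.data = tuple(data)


class ByValue(ByIdentity):
    """Hash and equality follow the stored contents instead."""

    def __hash__(self):
        return hash(self.data)

    def __eq__(self, other):
        return isinstance(other, ByValue) and self.data == other.data


fake_compile(ByIdentity([1, 2, 3]))
fake_compile(ByIdentity([1, 2, 3]))  # new object, so the cache misses
misses_identity = compile_count

fake_compile(ByValue([1, 2, 3]))
fake_compile(ByValue([1, 2, 3]))     # equal value, so the cache hits
misses_value = compile_count - misses_identity
```

The cache also keeps a reference to every key, which is one way a static array can stay alive after the caller has dropped it.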
@shoyer that's the line for the "op-by-op" cache, which is indeed one compilation cache, though the cache for |
@mattjj @shoyer sounds to me like I'm playing with fire then. I'll keep my hands off this and perhaps just stick to pre-compilation at the beginning of the program. Really looking forward to the day when JAX allows us to store pre-compiled functions. Thanks again for all the support and for building this awesome tool! |
Thanks for the words of encouragement! We hear you on the need for pre-compiled executables. We should track that on #476. |
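For the "pre-compile at the beginning of the program" approach, a minimal sketch might look like the following. It assumes a JAX version recent enough to provide the ahead-of-time lower/compile methods on jitted functions, which postdate this thread; the function step and the example shapes are hypothetical. The compiled object lives only in the current process, so this does not address storing executables on disk.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in for the real computation.
    return jnp.tanh(x) * 2.0


# Compile once at startup for a known input shape and dtype.
x_example = jnp.zeros((4,), dtype=jnp.float32)
compiled_step = jax.jit(step).lower(x_example).compile()

# Later calls with matching shape and dtype reuse the compiled executable.
y = compiled_step(jnp.ones((4,), dtype=jnp.float32))
```

On older versions, calling the jitted function once with dummy inputs of the right shape and dtype has a similar warm-up effect.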
This wrapper acts as a child class for the original jax array that implements
My usecase for this was to use a jax array as static field in https://github.com/brentyi/jax_dataclasses. I could've written a wrapper similar to the gnool_jit but this seems easier. Hope this might also help you :) |
@mattjj Out of curiosity, is there a positive reason to not define some
|
The main reason not to define |
Suppose I have two functions
Theoretically speaking, is there any way to avoid recompiling main_function_jit when I do jit(wrapper)?
In my work I have a big function that can be broken down into sub-functions, and some of these sub-functions can be reused in other big functions. The problem I encounter is that compiling these sub-functions beforehand doesn't speed up compilation of the big function. Any ideas?
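The original code for the two functions is not preserved here, so the following is only a hypothetical reconstruction of the setup being described. The names main_function_jit and wrapper come from the question; the function bodies are invented for illustration. Per the first reply in the thread, the compiled inner function cannot be reused when the outer function is jitted.

```python
import jax
import jax.numpy as jnp


def main_function(x):
    # Hypothetical stand-in for an expensive sub-function.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


main_function_jit = jax.jit(main_function)


def wrapper(x):
    # Calls the already-jitted sub-function and adds a little extra work.
    return main_function_jit(x) + 1.0


x = jnp.arange(4.0)

inner = main_function_jit(x)   # compiles main_function on its own
outer = jax.jit(wrapper)(x)    # compiles wrapper, inner call included

# The jaxpr of wrapper contains the inner function's operations as a nested call.
print(jax.make_jaxpr(wrapper)(x))
```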