Any way to avoid recompiling of jit'ed sub-functions? #4572

Closed · gnool opened this issue Oct 14, 2020 · 16 comments
Labels: question (Questions for the JAX team)

@gnool commented Oct 14, 2020

Suppose I have two functions:

from jax import jit

def main_function(x):
    return x + 1

main_function_jit = jit(main_function)

def wrapper(x):
    return main_function_jit(x)

Theoretically speaking, is there any way to avoid recompiling main_function_jit when I do jit(wrapper)?

In my work I have a big function that can be broken down into sub-functions, some of which can be reused in other big functions. The problem I run into is that pre-compiling these sub-functions doesn't speed up compilation of the big function. Any ideas?

@shoyer (Collaborator) commented Oct 14, 2020

My understanding is that this isn't possible with JAX/XLA today. I agree it would be really nice to have.

In theory it would be possible if we created an XLA "CustomCall" for the inner jit-decorated function (or with some sort of lower-level change in XLA itself).

Arguably this is the main reason why use cases like lbfgs_optimize(odeint(fun, ...)) (#3847) are so slow: XLA recompiles inner functions many redundant times.
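
Here's a small sketch to see the redundant work concretely (the print fires once per trace of main_function):

from jax import jit
import jax.numpy as jnp

def main_function(x):
    print('tracing main_function')  # fires once per (re)trace
    return x + 1

main_function_jit = jit(main_function)

def wrapper(x):
    return main_function_jit(x)

main_function_jit(jnp.ones(3))  # prints: traces and compiles the inner function
jit(wrapper)(jnp.ones(3))       # prints again: the inner function is traced
                                # and compiled anew as part of wrapper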

@gnool (Author) commented Oct 15, 2020

@shoyer Got it. Another related question about compilation: I have a jitted function with some arguments declared via static_argnums. The static arguments are really just constant arrays. Yet I find that the check for whether a recompilation is needed is extremely strict: even if I create the exact same array (identical memory content), it still triggers a recompilation. Is there a way to work around this, perhaps by telling JAX that there's no need to check a certain static argument because the values in the array are guaranteed to be the same?

@mattjj (Collaborator) commented Oct 15, 2020

@gnool that's definitely a foot-gun that we want to revise. See #2813 and #3712.

Basically, for array types like numpy.ndarrays or JAX's DeviceArrays, jax.jit will recompile whenever the object identity of a static_argnums argument is new. Because array types are not hashable, static_argnums currently falls back to handling arrays (and other unhashable objects) silently by object id. So using array types with static_argnums is a recipe for recompilation and slowness. In the near future, we plan to make static_argnums + unhashable type an error. [EDIT: edited to improve phrasing, which was previously pasted from chat comments and wasn't very clear]

Example:

In [1]: import jax.numpy as jnp

In [2]: x = jnp.array([1., 2., 3.])
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')

In [3]: def f(x):
   ...:     print('re-tracing!')
   ...:     return x ** 2
   ...:

In [4]: from jax import jit

In [5]: jit_f = jit(f, static_argnums=(0,))

In [6]: jit_f(x)
re-tracing!
Out[6]: DeviceArray([1., 4., 9.], dtype=float32)

In [7]: x = jnp.array([1., 2., 3.])

In [8]: jit_f(x)
re-tracing!
Out[8]: DeviceArray([1., 4., 9.], dtype=float32)

@gnool (Author) commented Oct 16, 2020

@mattjj If I understand your explanation correctly, there is no way to make static_argnums + unhashable types (arrays) work, right? Which is why it has to at least return an error. Pardon my poor understanding of how JIT works; I have to admit I'm still a little puzzled as to why recompilation is needed when the array content (and its metadata except its identity and memory address) is exactly the same. I'm guessing there's a reason why a static argument has to be hashable, too.

My "workaround" now would be to avoid declaring the arrays as static then. Although the subsequent function call is still fast, it is noticeably slower (perhaps by an order of magnitude) than declaring the arrays as static (and avoiding the recompilation).

@mattjj (Collaborator) commented Oct 16, 2020

why recompilation is needed when the array content (and its metadata except its identity and memory address) is exactly the same

The issue is that we'd have to check that the array content is the same. That is, we could define __eq__ and __hash__ on the DeviceArray class, but it'd be expensive: on every dispatch of the jitted function we'd have to compute a hash of all the static array arguments and possibly also compare them all elementwise to the cache key entries.

Plus, if you want that behavior, you can simulate it yourself without needing jax.jit to change at all. Just wrap the array objects you want to be static and cached on value in an instance of something like this:

class HashableArrayWrapper:
  def __init__(self, val):
    self.val = val
  def __hash__(self):
    return some_hash_function(self.val)  # maybe implement this in jax.numpy to save on transfers
  def __eq__(self, other):
    return isinstance(other, HashableArrayWrapper) and jnp.all(jnp.equal(self.val, other.val))

You could write a wrapper on jit like this:

def gnool_jit(fun, static_array_argnums=()):
  @jit  # EDIT: forgot to use static_argnums here! see comment below
  def callee(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = args[i].val
    return fun(*args)

  def caller(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = HashableArrayWrapper(args[i])
    return callee(*args)

  return caller

WDYT?

mattjj self-assigned this Oct 16, 2020
mattjj added the question label Oct 16, 2020
@gnool (Author) commented Oct 16, 2020

@mattjj If implementing my own __hash__ and __eq__ helps bypass the current object-identity check, that's definitely something I can explore. However, I'm running into the error below; any idea?

jax.traceback_util.FilteredStackTrace: TypeError: Argument '<__main__.HashableArrayWrapper object at 0x7fb39bd44ad0>' of type <class '__main__.HashableArrayWrapper'> is not a valid JAX type

@mattjj (Collaborator) commented Oct 16, 2020

Sorry, I forgot to use static_argnums in my example code above. Oops! Here's a working example:

from functools import partial

from jax import jit
import jax.numpy as jnp

def some_hash_function(x):
  return int(jnp.sum(x))

class HashableArrayWrapper:
  def __init__(self, val):
    self.val = val
  def __hash__(self):
    return some_hash_function(self.val)
  def __eq__(self, other):
    return (isinstance(other, HashableArrayWrapper) and
            jnp.all(jnp.equal(self.val, other.val)))

def gnool_jit(fun, static_array_argnums=()):
  @partial(jit, static_argnums=static_array_argnums)
  def callee(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = args[i].val
    return fun(*args)

  def caller(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = HashableArrayWrapper(args[i])
    return callee(*args)

  return caller


###


@partial(gnool_jit, static_array_argnums=(0,))
def f(x):
  print('re-tracing!')
  return x ** 2


x = jnp.array([1., 2., 3.])
f(x)
f(x)

x = jnp.array([1., 2., 3.])
f(x)

All we're doing here is making a hashable type, i.e. one that has __hash__ and __eq__ methods implementing whatever behavior you want. The object-identity behavior we were talking about with static_argnums only kicks in when the type isn't hashable.
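
More generally, any hashable value works as a static argument. Here's a small sketch of the same principle using a tuple in place of a constant array:

from functools import partial

from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(0,))
def scale(consts, x):
  print('re-tracing!')
  return jnp.array(consts) * x   # the tuple is baked in as a compile-time constant

scale((1., 2., 3.), jnp.ones(3))   # re-tracing!
scale((1., 2., 3.), jnp.ones(3))   # equal tuples hash equal: cache hit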

@gnool (Author) commented Oct 16, 2020

@mattjj Thanks for the super quick reply, it works now! For my own benefit: in our discussion above we touched on caching. For a static argument, where exactly is the cache for the array's data stored? In the past few days I've been reading what others here have discussed about pickling a JIT-compiled function (to save compilation time), and so far the answers indicate that this is unsupported by JAX. Out of curiosity I used cloudpickle to pickle a JIT-compiled function (it pickled without error), dumped it, loaded it back again, and unsurprisingly it behaves like an uncompiled function (i.e. it needs to warm up again). Is this because the cache was not stored during pickling, or for some other, more complicated reason?
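
For reference, my experiment was roughly this sketch (assuming cloudpickle can serialize the jitted wrapper, which it did for me):

import cloudpickle
from jax import jit
import jax.numpy as jnp

def f(x):
    return x + 1

f_jit = jit(f)
f_jit(jnp.ones(3))   # warm-up call: traces and compiles

g = cloudpickle.loads(cloudpickle.dumps(f_jit))
g(jnp.ones(3))       # behaves like an uncompiled function again:
                     # the compilation cache does not survive pickling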

@shoyer (Collaborator) commented Oct 16, 2020

I believe these lines are where you can find JAX's backend compilation cache:
https://github.com/google/jax/blob/32010968992ff88c9c065ff1fa5ba6cbbfd21641/jax/interpreters/xla.py#L245-L247

It's stored in an LRU cache on xla_primitive_callable, not on the function objects themselves. Caching on the function objects might be feasible, but memory management for the compilation cache is already a little tricky...

@mattjj (Collaborator) commented Oct 16, 2020

Glad to hear it works!

For a static argument, where exactly is the cache for the array's data stored?

The compilation cache is really just this one memoization decorator, defined here and applied here. Argument values that correspond to static_argnums positions are actually part of the wrapper that makes up a WrappedFun, in particular they're part of the fun argument on this line. (The compiled executable is part of the cache value, rather than the key.)
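
In spirit it's just ordinary memoization. A highly simplified sketch (not the real implementation; compile_to_executable is a hypothetical stand-in for the trace-and-XLA-compile step):

compilation_cache = {}   # cache key -> compiled executable

def compile_to_executable(fun, abstract_args):
  # hypothetical stand-in for the real compilation step
  return lambda *args: fun(*args)

def memoized_compile(fun, transforms, params, abstract_args):
  # static_argnums values are folded into `fun` itself (the WrappedFun),
  # so they participate in the cache key; the executable is only the value
  key = (fun, transforms, params, abstract_args)
  if key not in compilation_cache:
    compilation_cache[key] = compile_to_executable(fun, abstract_args)
  return compilation_cache[key]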

Here's a little spelunking to show where the array lives in the above example:

In [1]: run gnool.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
> /usr/local/google/home/mattjj/packages/jax/jax/linear_util.py(241)memoized_fun()
-> cache = fun_caches.setdefault(fun.f, {})
(Pdb) l
236       fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
237       thread_local: threading.local = _CacheLocalContext()
238
239       def memoized_fun(fun: WrappedFun, *args):
240         breakpoint()
241  ->     cache = fun_caches.setdefault(fun.f, {})
242         key = (fun.transforms, fun.params, args)
243         result = cache.get(key, None)
244         if result is not None:
245           ans, stores = result
246           fun.populate_stores(stores)
(Pdb) fun
Wrapped function:
0   : process_env_traces   (xla_call, 0, (('device', None), ('backend', None), ('name', 'callee'), ('donated_invars', ())))
1   : flatten_fun   (PyTreeDef(tuple, [PyTreeDef(tuple, []),PyTreeDef(dict[[]], [])]),)
2   : _argnums_partial   ((), (<jax.util.Hashable object at 0x7f96c89bc690>,))
Core: callee

(Pdb) fun.transforms[2][1][1][0].val.val
DeviceArray([1., 2., 3.], dtype=float32)

However, those implementation details are all subject to change, even in the very near future, so I wouldn't build anything against them.

@mattjj (Collaborator) commented Oct 16, 2020

@shoyer that's the line for the "op-by-op" cache, which is indeed one compilation cache, though the cache for jit is lower down in the same file.

@gnool (Author) commented Oct 16, 2020

@mattjj @shoyer Sounds to me like I'm playing with fire, then. I'll keep my hands off this and perhaps just stick to pre-compilation at the beginning of the program. Really looking forward to the day when JAX lets us store pre-compiled functions. Thanks again for all the support and for building this awesome tool!

@mattjj (Collaborator) commented Oct 16, 2020

Thanks for the words of encouragement! We hear you on the need for pre-compiled executables. We should track that on #476.

@DavidDevoogdt commented
This wrapper acts as a transparent proxy for the original jax array and implements __eq__ and __hash__, so no special jit wrapper is needed:

from typing import Generic, TypeVar

T = TypeVar('T')      # Declare type variable

class HashableArrayWrapper(Generic[T]):
    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        if prop == 'val' or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val  # note: in-place assignment only works for mutable arrays like NumPy's

    def __hash__(self):
        return hash(self.val.tobytes())

    def __eq__(self, other):
        if isinstance(other, HashableArrayWrapper):
            return self.__hash__() == other.__hash__()

        # fall back to the wrapped array's own (elementwise) equality
        f = getattr(self.val, "__eq__")
        return f(other)

My use case for this was using a jax array as a static field in https://github.com/brentyi/jax_dataclasses. I could have written a wrapper similar to gnool_jit, but this seems easier. Hope this might also help you :)
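
For example, a possible usage sketch with plain jit and the class above (using NumPy arrays, since tobytes and item assignment need a concrete mutable array):

from functools import partial

import numpy as np
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(0,))
def f(w, x):
    print('re-tracing!')
    return jnp.asarray(w.val) @ x        # unwrap the array via .val

x = np.ones(3)
f(HashableArrayWrapper(np.eye(3)), x)    # re-tracing!
f(HashableArrayWrapper(np.eye(3)), x)    # equal hashes: cache hit, no retrace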

@carlosgmartin (Contributor) commented

@mattjj Out of curiosity, is there a positive reason not to define some __hash__ method for Array? This popped up in the context of default values for dataclass fields, e.g. sotetsuk/pgx#1062:

TypeError: unhashable type: 'ArrayImpl'

@jakevdp (Collaborator) commented Oct 13, 2023

The main reason not to define __hash__ for JAX array objects is because the existence of __hash__ implies certain assumptions about the behavior of __eq__ that do not match the semantics of __eq__ for JAX & numpy arrays. See Python's __hash__ documentation for details.
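
Concretely, __eq__ on arrays is elementwise, so there is no single boolean notion of equality for __hash__ to be consistent with:

import jax.numpy as jnp

a = jnp.array([1., 2., 3.])
b = jnp.array([1., 2., 3.])
print(a == b)   # [ True  True  True]: elementwise, not a single bool,
                # so there's no value for __hash__ to be consistent with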
