NumPy constants device_put multiple times during tracing #5308

Closed
trevorcai opened this issue Jan 4, 2021 · 10 comments · Fixed by #6014
Labels: better_errors (Improve the error reporting)

trevorcai (Contributor) commented Jan 4, 2021

In one of my networks, I have a large (64 MB) NumPy constant that I matmul against 160 different weight matrices.
I am noticing that JAX repeatedly calls _device_put_raw on it, consuming >10 GB of accelerator HBM during tracing (resulting in an OOM at trace time).

A repro is below:

  1. Apply the diff below, which adds a print statement where convert_element_type calls _device_put_raw.
  2. Run the script that follows the diff.
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index eaf94ced..0f3b5483 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -436,6 +436,7 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
     if isinstance(operand, (core.Tracer, xla.DeviceArray)):
       return operand
     else:
+      print('_device_put_raw inside convert_element_type.')
       return _device_put_raw(np.asarray(operand))
   if (dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):

The script:

import functools
import jax
import jax.numpy as jnp
import numpy as np
 
@jax.jit
def make_ws(key):
  return [k for k in jax.random.normal(key, (6, 4, 4))]
 
@functools.partial(jax.jit, static_argnums=1)
def run_10x(ws, wrap_array):
  # Create some NumPy constant.
  x = np.arange(8).reshape((2, 4)).astype(np.float32)
  if wrap_array:
    x = jnp.array(x)
  print('x type: ', type(x))
 
  # Use the same constant repeatedly.
  # If not wrap_array, it gets _device_put_raw repeatedly.
  # Real case: 160 layers, x is 64 MB -> 10 GB used from this!
  print('Running jnp.dot(x, w) 6x.')
  ys = [jnp.dot(x, w) for w in ws]
  print('Done.')
  return sum(ys)
 
ws = make_ws(jax.random.PRNGKey(428))
print('w types: ', [type(w) for w in ws])
run_10x(ws, False)
run_10x(ws, True)

Output:

tycai:~/jax$ python test.py 
...
w types:  [<class 'jax.interpreters.xla._DeviceArray'>, <class 'jax.interpreters.xla._DeviceArray'>, ..., <class 'jax.interpreters.xla._DeviceArray'>]
x type:  <class 'numpy.ndarray'>
Running jnp.dot(x, w) 6x.
_device_put_raw inside convert_element_type.
_device_put_raw inside convert_element_type.
_device_put_raw inside convert_element_type.
_device_put_raw inside convert_element_type.
_device_put_raw inside convert_element_type.
_device_put_raw inside convert_element_type.
Done.
_device_put_raw inside convert_element_type.
x type:  <class 'jax.interpreters.xla._DeviceArray'>
Running jnp.dot(x, w) 6x.
Done.
_device_put_raw inside convert_element_type.

This seems like a sharp edge that might trip others up; it took me a couple of hours to track down (in the context of a larger model).
I did not observe this prior to omnistaging, as I had tie_in-ed the NumPy constant, but tie_in was turned into a no-op.
In the good wrap_array case, jnp.array is essentially functioning as a tie_in replacement.
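
For reference, a minimal sketch of that workaround pattern (illustrative names only; apply_layers is made up here). Wrapping the NumPy constant in jnp.array once means every subsequent op reuses the same DeviceArray instead of triggering a fresh transfer:

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def apply_layers(ws):
  # Wrap the NumPy constant once; the resulting DeviceArray is reused by
  # every jnp.dot below rather than being device_put per layer.
  x = jnp.array(np.arange(8).reshape((2, 4)).astype(np.float32))
  return sum(jnp.dot(x, w) for w in ws)

ws = [jnp.ones((4, 4)) for _ in range(3)]
print(apply_layers(ws).shape)  # (2, 4)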

trevorcai (Contributor, Author) commented:

@mattjj You're probably the relevant point of contact for this issue.

mattjj added the labels bug (Something isn't working) and better_errors (Improve the error reporting) on Jan 4, 2021
mattjj (Collaborator) commented Jan 4, 2021

Thanks for the clean explanation! This does seem like a foot-gun that we can figure out a way to avoid, or at least warn about so it's not silent. I'm not quite sure which yet...

jakevdp (Collaborator) commented Jan 5, 2021

This is a tricky one, but it's not clear to me that JAX should be expected to avoid this. Rewriting your list comprehension as a loop looks like this:

ys = []
for w in ws:
  ys.append(jnp.dot(x, w))

When you call jnp.dot, the inputs are pushed to device if necessary, but since x is the same object each time, could JAX elide this and just push it once? Maybe. But what if the code looks like this:

ys = []
for w in ws:
  ys.append(jnp.dot(x, w))
  x[0, 0] = np.random.rand()

This is still the same Python object being pushed to device each time (in the sense that id(x) has not changed), but it would not be correct for JAX to device_put once and re-use the same array each time.

How might JAX know at trace-time whether to re-use an array that it encounters? It would have to do a full equality check of the new array to be pushed with each existing array on the device: if the new array matches an existing array, we use it rather than pushing a new array.

Is that possible to implement? Sure, some sort of global hash table might be a good solution. But do we want this kind of global state in play for every JAX program, in order to silently handle this kind of corner case? It's not clear to me that the added complexity would be worth the benefit.
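
For concreteness, here is a rough sketch of the kind of global value-keyed cache being discussed (illustrative only, not JAX internals; cached_device_put and _put_cache are invented names):

import numpy as np
import jax

_put_cache = {}  # (dtype, shape, raw bytes) -> on-device buffer

def cached_device_put(x: np.ndarray):
  # Key on the array's *contents*, not its id(), so a mutated array
  # misses the cache. Note the cost: tobytes() copies the whole array,
  # and every lookup is O(n) in the array size.
  key = (x.dtype.str, x.shape, x.tobytes())
  if key not in _put_cache:
    _put_cache[key] = jax.device_put(x)
  return _put_cache[key]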

I think the best answer here would be "use wrap_array=True".

mattjj (Collaborator) commented Jan 5, 2021

Good point about mutable numpy arrays. I always forget mutability :)

I think if numpy arrays were immutable then we could do something reasonable, e.g. by binding the convert_element_type primitive in more cases and then de-duping constants when we lower to XLA (which we should do anyway), but mutability makes it hopeless. You're totally right.

So I agree the current behavior should probably stay. Maybe we could make this kind of thing easier to catch with some memory-debugger tooling?

mattjj removed the bug (Something isn't working) label on Jan 5, 2021
trevorcai (Contributor, Author) commented Jan 5, 2021

That's a good point, Jake! I think one possible action item might be re-reifying the notion of "baking in" a constant into the computation graph, like tie_in did, as the concept still exists post-omnistaging despite our hopes to the contrary!

Semi-unrelatedly: if I'm not mistaken, isn't the mutable loop buggy anyway? Let x be big and have the mutation occur at [-1, -1]; my understanding is that mutating this NumPy array while it is being device_put may result in the mutation showing up in the DeviceArray, defying program order.

This is particularly nasty since the device_put is not visible from user code.

I'm not sure how to fix this without literally blocking on all device_puts and destroying performance, but it feels like there's some remaining room to define the safe interactions of JAX and mutable-NumPy!
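
One defensive pattern that sidesteps the hazard described above, sketched here under the assumption that the transfer may be asynchronous, is to snapshot the array before handing it to JAX:

import jax
import numpy as np

x = np.zeros((1000, 1000))

# Snapshot the value as of this point in program order; the copy is what
# gets transferred, so a later in-place mutation of x cannot leak into it.
buf = jax.device_put(x.copy())
x[-1, -1] = 1.0  # safe: buf was taken from an independent copy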

mattjj (Collaborator) commented Jan 5, 2021

> I think one possible action item might be re-reifying the notion of "baking in" a constant into the computation graph, like tie_in did, as the concept still exists post-omnistaging despite our hopes to the contrary!

I think just calling jnp.array is the solution here, plus maybe some de-duplication of immutable constants based on object id if we don't already do that in the lowering (something we should do anyway).
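
A toy sketch of the id-based de-duplication idea (illustrative; dedup_constant and _const_cache are invented names, and this is only sound if the constant is treated as immutable for the lifetime of the trace):

import jax
import numpy as np

_const_cache = {}  # id(array) -> (array, buffer built for it)

def dedup_constant(x, build):
  # Keying on id() is cheap but assumes x is never mutated. Holding x in
  # the cache also keeps it alive, so its id cannot be recycled by a
  # different object while the entry exists.
  key = id(x)
  if key not in _const_cache:
    _const_cache[key] = (x, build(x))
  return _const_cache[key][1]

x = np.arange(4.0)
a = dedup_constant(x, jax.device_put)
b = dedup_constant(x, jax.device_put)  # cache hit: no second transfer
assert a is b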

jakevdp (Collaborator) commented Jan 5, 2021

Edit: this reflects floating point issues, not race conditions...

Good point! That's pretty subtle... here's a repro on GPU:

import numpy as np
import jax.numpy as jnp

x = np.zeros((100, 100))

y = []
for i in range(10):
  x[:, :] = i
  y.append(jnp.array(x))

print([float(yi.mean()) for yi in y])
# [0.0, 1.0, 2.0, 3.000000238418579, 4.0, 5.0, 6.000000476837158, 7.000000476837158, 8.0, 9.0]

trevorcai (Contributor, Author) commented Jan 5, 2021

Happy with jnp.array being the interface/solution :) and indeed I've already submitted that fix to my codebase.

I'm pushing more on "baking constants in is part of the jnp.array API contract", however we choose to document that contract.

jakevdp (Collaborator) commented Mar 19, 2021

I believe this issue will be fixed (or at least partially remedied) by #6014

mattjj (Collaborator) commented Mar 19, 2021

I think you're right, Jake! #6014 will fix the repro because convert_element_type doesn't call _device_put_raw anymore. The total number of _device_put_raw calls (from any caller) goes from 19 (on master) to 2 (on the #6014 branch).

mattjj self-assigned this on Mar 19, 2021