NumPy constants device_put multiple times during tracing #5308
Comments
@mattjj You're probably the relevant point of contact for this issue.
Thanks for the clean explanation! This does seem like a foot-gun that we can figure out a way to avoid, or at least warn about so it's not silent. I'm not quite sure which yet...
This is a tricky one, but it's not clear to me that JAX should be expected to avoid this. Rewriting your list comprehension as a loop looks like this:

```python
ys = []
for w in ws:
    ys.append(jnp.dot(x, w))
```

Now consider what happens when `x` is mutated inside the loop:

```python
ys = []
for w in ws:
    ys.append(jnp.dot(x, w))
    x[0, 0] = np.random.rand()
```

This is still the same Python object being pushed to device each time (in the sense that `id(x)` is unchanged), but its contents differ from one iteration to the next.

How might JAX know at trace-time whether to re-use an array that it encounters? It would have to do a full equality check of the new array to be pushed against each existing array on the device: if the new array matches an existing array, we use it rather than pushing a new array. Is that possible to implement? Sure, some sort of global hash table might be a good solution. But do we want this kind of global state in play for every JAX program, in order to silently handle this kind of corner case? It's not clear to me that the added complexity would be worth the benefit. I think the best answer here would be "use `jnp.array` on the constant up front".
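For concreteness, a minimal sketch of that suggestion (the names and shapes here are illustrative, not from the original report): convert the NumPy constant to a device array once, so every use in the loop reuses the same on-device buffer.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.random.rand(1000, 1000).astype(np.float32)
ws = [np.random.rand(1000, 10).astype(np.float32) for _ in range(8)]

# Push the constant to the device once; each jnp.dot below reuses this
# buffer instead of re-transferring the mutable NumPy array every iteration.
x = jnp.array(x_np)

ys = [jnp.dot(x, w) for w in ws]
```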
Good point about mutable numpy arrays. I always forget mutability :) I think if numpy arrays were immutable then we could do something reasonable, e.g. by binding the `convert_element_type` primitive in more cases and then de-duping constants when we lower to XLA (which we should do anyway), but mutability makes it hopeless. You're totally right. So I agree the current behavior should probably stay. Maybe we could make this kind of thing easier to catch with some memory-debugger tooling?
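To make the trade-off discussed above concrete, here is a hypothetical content-hash cache (a sketch, not anything proposed for JAX itself): it de-duplicates correctly even with in-place mutation, but only by hashing the whole buffer on every call and by keeping global state around.

```python
import hashlib
import numpy as np
import jax.numpy as jnp

_constant_cache = {}  # (shape, dtype, content hash) -> device array

def cached_device_put(x: np.ndarray):
    # Key on the array's *contents*, not id(x): a mutable NumPy array can
    # hold different values at different call sites while keeping the same id.
    key = (x.shape, str(x.dtype), hashlib.sha1(x.tobytes()).hexdigest())
    if key not in _constant_cache:
        _constant_cache[key] = jnp.array(x)
    return _constant_cache[key]

# The dedup works, but every call pays for hashing the full buffer, and the
# cache is global state that outlives any particular trace.
a = np.zeros((100, 100))
b = cached_device_put(a)
a[0, 0] = 1.0                  # in-place mutation, same id(a)
c = cached_device_put(a)       # different contents -> new device array
assert b is not c
```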
That's a good point, Jake! I think one possible action item might be re-reifying the notion of "baking in" a constant into the computation graph, like `tie_in` used to do.

Semi-unrelatedly: if I'm not mistaken, isn't the mutable loop buggy anyways? If `x` is mutated on the host after `jnp.dot(x, w)` is called but before the transfer actually reads the buffer, the value that ends up on the device may not be the value at call time. This is particularly nasty since the transfers are asynchronous, so the behavior would depend on timing. I'm not sure how to fix this without literally blocking on all transfers.
I think just calling `jnp.array` on the constant up front should handle that.
Edit: this reflects floating point issues, not race conditions...

```python
import numpy as np
import jax.numpy as jnp

x = np.zeros((100, 100))
y = []
for i in range(10):
    x[:, :] = i
    y.append(jnp.array(x))

print([float(yi.mean()) for yi in y])
# [0.0, 1.0, 2.0, 3.000000238418579, 4.0, 5.0, 6.000000476837158, 7.000000476837158, 8.0, 9.0]
```
Happy with `jnp.array` as the fix. I'm pushing more on "baking constants in is part of the `jnp.array` contract" as something worth making explicit.
I believe this issue will be fixed (or at least partially remedied) by #6014
In one of my networks, I have a large (64 MB) NumPy constant that I matmul against 160 different weight matrices.
I am noticing that JAX is repeatedly calling `_raw_device_put` on it, seemingly taking up >10 GB of accelerator HBM during tracing (resulting in an OOM at trace time). A repro is below.
In the good case, `jnp.array` binds `convert_element_type` when `_device_put_raw` is called, and the output shows the constant being transferred once rather than once per matmul.
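A minimal sketch of a repro along these lines (the shapes, counts, and structure of the `wrap_array` toggle are assumptions for illustration, not the original script):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Large NumPy constant, matmul'ed against many weight matrices inside jit.
x_np = np.ones((4096, 4096), dtype=np.float32)            # ~64 MB
ws = [np.ones((4096, 64), dtype=np.float32) for _ in range(16)]

def run(wrap_array: bool):
    # "Good" case: wrap the constant once so tracing sees a device array
    # instead of re-transferring the raw NumPy buffer at each use.
    x = jnp.array(x_np) if wrap_array else x_np

    @jax.jit
    def f():
        return [jnp.dot(x, w) for w in ws]

    return f()

run(wrap_array=False)  # at the time of this issue: one device transfer per matmul
run(wrap_array=True)   # the constant is transferred once
```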
This seems like a sharp edge that might trip others up; it took me a couple of hours to track down (in the context of a larger model).

I did not observe this prior to omnistaging, as I had `tie_in`-ed the NumPy constant, but `tie_in` was turned into a no-op. In the good `wrap_array` case, `jnp.array` is basically functioning as a `tie_in` replacement.