Hi,

I have a function where a list of JAX arrays needs to be vertically stacked. The problem is that if I use `jax.numpy.vstack` to do the stacking, I run out of GPU memory. On the other hand, if I use ordinary `numpy`, the stacking takes place on the CPU (where I have enough memory) but raises a tracer error if I try to apply any transformation.

I tried wrapping `onp.vstack` in a `pure_callback` plus a `custom_jvp` rule, roughly as in the sketch below. However, the result of the forward pass seems to end up in GPU memory anyway.
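For reference, the `pure_callback` wrapper mentioned above might look roughly like this; `host_vstack`, the 2-D shapes, and the omission of the `custom_jvp` rule are illustrative assumptions rather than code from the thread:

```python
import numpy as onp
import jax

def host_vstack(arrays):
    # Declare the output shape/dtype up front so JAX can trace the call
    # (assumes a list of 2-D arrays with matching trailing dimensions).
    rows = sum(a.shape[0] for a in arrays)
    out = jax.ShapeDtypeStruct((rows,) + arrays[0].shape[1:], arrays[0].dtype)
    # onp.vstack runs on the host, but pure_callback transfers the result
    # back to the default device afterwards -- consistent with the
    # observation that the forward pass still lands in GPU memory.
    return jax.pure_callback(lambda *xs: onp.vstack(xs), out, *arrays)
```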
Finally, I tried putting the list of arrays on the CPU with `device_put()` and then using `jax.numpy.vstack`, as in the second sketch below. But I am not sure whether that large matrix would always stay on the CPU if I apply transformations to the function (such as reverse-mode autodiff).
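That `device_put()` variant might be sketched as follows (`stack_on_cpu` and `arrays` are placeholder names):

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

def stack_on_cpu(arrays):
    # Committing every input to the CPU makes the vstack itself execute
    # there, because JAX runs an operation where its committed inputs live.
    arrays_cpu = [jax.device_put(x, cpu) for x in arrays]
    # Whether this placement survives transformations such as jacrev is
    # exactly the open question.
    return jnp.vstack(arrays_cpu)
```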
Is there a way to ensure that a particular computation always takes place on the CPU while remaining JAX-transformable?
Replies: 1 comment

You may be able to address this by specifying:

```python
with jax.default_device(cpu):
    result = jax.jacrev(func)(a, cpu)
```
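For completeness, a minimal runnable version of this suggestion; `func`, `a`, and the shapes are placeholders, and the simplified `func` here takes a single argument:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

def func(x):
    # Stand-in for the real function: builds a large vertically stacked matrix.
    return jnp.vstack([x, 2.0 * x, x ** 2])

# Arrays created and computations traced under this context default to the
# CPU, so the stacked matrix and its Jacobian stay out of GPU memory.
with jax.default_device(cpu):
    a = jnp.ones((3, 4))
    result = jax.jacrev(func)(a)

print(result.devices())  # {CpuDevice(id=0)}
```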