Hi, I'm quite new to using GPUs. I just wanted to run a module implemented in JAX on a Mac M2 GPU (jax==0.4.11, jax-metal==0.0.4, jaxlib==0.4.11). The code runs on CPU, but when using the GPU, I can't instantiate a jnp.zeros/ones array if 64-bit mode is set to true. A minimal example is:
which throws:
Not sure if I'm making some mistake or if this is a known issue. Also, later in the code, when trying to update values in the array, I get another issue, which is known (https://developer.apple.com/forums/thread/738697).
Replies: 1 comment 4 replies
We have GPU CI tests that cover things like 64-bit array creation. Can you say more about the architecture you are running on?
Oh, I see you mention jax-metal above. jax-metal is highly experimental, and does not support the full JAX API. You can see the current list of reported metal-related issues here: https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22Apple+GPU+%28Metal%29+plugin%22

It sounds like you're running into the issue previously reported in #16435
Note that jax-metal is closed source and the JAX core team has no access to the code, so there's not much we can do about issues like this.
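
For anyone who wants to check whether they hit the same problem, here is a minimal sketch of the scenario described in the question: enable 64-bit mode, then create a zeros array. It is not the original poster's code, which is not shown in the thread. The helper name zeros_with_cpu_fallback is hypothetical, and the CPU fallback is an assumption that has not been verified on the Metal plugin; it only illustrates one way to keep the rest of a script running when the default backend rejects 64-bit arrays.

```python
import jax
import jax.numpy as jnp

# 64-bit mode must be enabled before any arrays are created.
jax.config.update("jax_enable_x64", True)


def zeros_with_cpu_fallback(shape, dtype=jnp.float64):
    """Hypothetical helper: try the default backend, fall back to CPU.

    The exception type raised by the plugin is not shown in the thread,
    so the except clause is deliberately broad (an assumption).
    """
    try:
        return jnp.zeros(shape, dtype=dtype)
    except Exception as err:
        print(f"default backend failed ({type(err).__name__}); using CPU")
        # Place the array on the first CPU device instead.
        with jax.default_device(jax.devices("cpu")[0]):
            return jnp.zeros(shape, dtype=dtype)


x = zeros_with_cpu_fallback((3,))
print("backend:", jax.default_backend(), "dtype:", x.dtype)
```

If the first branch succeeds, the array lives on the default device; if it fails, the printed message shows the error type, which is useful to include when reporting the problem on the issue tracker linked above.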