JAX does not recognise my NVIDIA GPU when installed via conda #24604
Comments
We (JAX) only support the |
I briefly looked into this a while ago as well when working on #24139 and #24684. I don't have time or interest to fully resolve the issue as I'm not a conda user myself, but here are a couple of pointers from my notes that might help you debug this. First, in
Even if you fix these, you might have to explicitly pass in |
This should have been fixed by conda-forge/jaxlib-feedstock#288 (jaxlib conda-forge version |
Description
I previously had a working installation of JAX (installed via conda) that recognised my NVIDIA GPU without issue. However, I recently migrated to a new machine and now I cannot get JAX to recognise my GPU when I install via conda. I'm using Miniforge to manage my conda environments, as I did on my old machine, and I installed JAX according to the docs:
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
When I then try to import JAX and check my available devices using:
I get the following output:
TensorFlow, however, does recognise my GPU, so I tried the suggestion from #15268 to install using pip instead. I created a new environment and ran:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
When I then ran my JAX code above I got the output:
Voilà, my GPU has been found!
It therefore appears that the conda section of the docs might need updating.
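For anyone wanting to reproduce the comparison between the two environments, a device check along these lines should be enough. This is only a minimal sketch, not the exact snippet from the report: the helper name `describe_jax_devices` is made up for illustration, and it assumes nothing beyond `jax.devices()` and `jax.default_backend()`.

```python
def describe_jax_devices():
    """Return (backend, devices) for the active environment.

    Hypothetical helper for illustration. Returns (None, []) when JAX
    is not installed in the current environment.
    """
    try:
        import jax
    except ImportError:
        return None, []
    # Each device exposes its platform and a human-readable device kind.
    devices = [(d.platform, d.device_kind) for d in jax.devices()]
    return jax.default_backend(), devices


backend, devices = describe_jax_devices()
print("default backend:", backend)
for platform, kind in devices:
    print(f"  {platform}: {kind}")

# If the CUDA-enabled jaxlib was not picked up, JAX silently falls back to CPU.
if backend == "cpu":
    print("No accelerator backend found; JAX is running on CPU only.")
```

Running this in both the conda and the pip environment makes the difference easy to see: when the GPU is not picked up, the default backend is reported as `cpu`.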
System info (python version, jaxlib version, accelerator, etc.)
Conda installation:
pip installation: