
jaxlib installed with CUDA build fails to find GPU #285

Closed
1 task done
flferretti opened this issue Oct 17, 2024 · 9 comments · Fixed by #288
Labels
bug Something isn't working

Comments

@flferretti

flferretti commented Oct 17, 2024

Solution to issue cannot be found in the documentation.

  • I checked the documentation.

Issue

When using the CUDA build of jaxlib 0.4.32, JAX fails to find the GPU.

The environment was created with:

$ conda create -n jaxtest jax jaxlib
$ python -c "import jax; jax.print_environment_info()"
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax:    0.4.32
jaxlib: 0.4.32
numpy:  2.1.2
python: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 16:05:46) [GCC 13.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='iitbmp014lw015u', release='6.11.0-8-generic', version='#8-Ubuntu SMP PREEMPT_DYNAMIC Mon Sep 16 13:41:20 UTC 2024', machine='x86_64')


$ nvidia-smi
Thu Oct 17 18:57:11 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 ...    Off |   00000000:01:00.0 Off |                  N/A |
| N/A   54C    P8              4W /   40W |      14MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3256      G   /usr/bin/gnome-shell                            2MiB |
|    0   N/A  N/A    303067      G   /usr/bin/evolution                              2MiB |
+-----------------------------------------------------------------------------------------+

Checking the versions of the installed CUDA packages suggests that some of them are missing or cannot be loaded:

$ python -c "import jax;jax._src.xla_bridge._check_cuda_versions(debug=True)"
CUDA components status (debug):
Package: CUDA
Version JAX was built against: 12000
Minimum supported: 12010
Installed version: 12060
--------------------------------------------------
Package: cuDNN
Version JAX was built against: 90300
Minimum supported: 9100
Installed version: 90300
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/fferretti-iit.local/miniforge3/envs/jaxtest/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 421, in _check_cuda_versions
    raise RuntimeError(f'Unable to use CUDA because of the '
RuntimeError: Unable to use CUDA because of the following issues with CUDA components:
Unable to load cuFFT. Is it installed?
Traceback (most recent call last):
  File "/home/fferretti-iit.local/miniforge3/envs/jaxtest/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 339, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:54: operation cufftGetVersion(&version) failed: cuFFT internal error

--------------------------------------------------
Unable to load cuSOLVER. Is it installed?
Traceback (most recent call last):
  File "/home/fferretti-iit.local/miniforge3/envs/jaxtest/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 339, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:61: operation cusolverGetVersion(&version) failed: cuSolver internal error

--------------------------------------------------
Unable to load cuPTI. Is it installed?
Traceback (most recent call last):
  File "/home/fferretti-iit.local/miniforge3/envs/jaxtest/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 339, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:47: operation cuptiGetVersion(&version) failed: Unknown CUPTI error 999. This probably means that JAX was unable to load cupti.

--------------------------------------------------
Unable to load cuBLAS. Is it installed?
Traceback (most recent call last):
  File "/home/fferretti-iit.local/miniforge3/envs/jaxtest/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 339, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:71: operation cublasGetVersion( nullptr, &version) failed: cuBlas internal error

--------------------------------------------------
Unable to load cuSPARSE. Is it installed?
Traceback (most recent call last):
  File "/home/fferretti-iit.local/miniforge3/envs/jaxtest/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 339, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:80: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.

CC @xela-95 @traversaro @diegoferigo

### Installed packages

# packages in environment at /home/fferretti-iit.local/miniforge3/envs/jaxtest:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
c-ares                    1.34.2               heb4867d_0    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
cuda-crt-tools            12.6.77              ha770c72_0    conda-forge
cuda-cudart               12.6.77              h5888daf_0    conda-forge
cuda-cudart_linux-64      12.6.77              h3f2d84a_0    conda-forge
cuda-cupti                12.6.80              hbd13f7d_0    conda-forge
cuda-nvcc-tools           12.6.77              he02047a_0    conda-forge
cuda-nvrtc                12.6.77              hbd13f7d_0    conda-forge
cuda-nvtx                 12.6.77              hbd13f7d_0    conda-forge
cuda-nvvm-tools           12.6.77              he02047a_0    conda-forge
cuda-version              12.6                 h7480c83_3    conda-forge
cudnn                     9.3.0.75             h93bb076_0    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
jax                       0.4.32             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.32          cuda120py312h9b6c45b_200    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_1    conda-forge
libabseil                 20240722.0      cxx17_h5888daf_1    conda-forge
libblas                   3.9.0           24_linux64_openblas    conda-forge
libcblas                  3.9.0           24_linux64_openblas    conda-forge
libcublas                 12.6.3.3             hbd13f7d_1    conda-forge
libcufft                  11.3.0.4             hbd13f7d_0    conda-forge
libcurand                 10.3.7.77            hbd13f7d_0    conda-forge
libcusolver               11.7.1.2             hbd13f7d_0    conda-forge
libcusparse               12.5.4.2             hbd13f7d_0    conda-forge
libexpat                  2.6.3                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran-ng            14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libgrpc                   1.65.5               hf5c653b_0    conda-forge
liblapack                 3.9.0           24_linux64_openblas    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.6.77              hbd13f7d_1    conda-forge
libopenblas               0.3.27          pthreads_hac2b453_1    conda-forge
libprotobuf               5.27.5               h5b01275_2    conda-forge
libre2-11                 2024.07.02           hbbce691_1    conda-forge
libsqlite                 3.46.1               hadc24fc_0    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
ml_dtypes                 0.5.0           py312hf9745cd_0    conda-forge
nccl                      2.23.4.1             h52f6c39_0    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
numpy                     2.1.2           py312h58c1407_0    conda-forge
openssl                   3.3.2                hb9d3cd8_0    conda-forge
opt-einsum                3.4.0                hd8ed1ab_0    conda-forge
opt_einsum                3.4.0              pyhd8ed1ab_0    conda-forge
pip                       24.2               pyh8b19718_1    conda-forge
python                    3.12.7          hc5c86c4_0_cpython    conda-forge
python_abi                3.12                    5_cp312    conda-forge
re2                       2024.07.02           h77b4e00_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.14.1          py312h7d485d2_0    conda-forge
setuptools                75.1.0             pyhd8ed1ab_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
wheel                     0.44.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge

### Environment info


     active environment : None
            shell level : 0
       user config file : /home/fferretti-iit.local/.condarc
 populated config files : /home/fferretti-iit.local/miniforge3/.condarc
                          /home/fferretti-iit.local/.condarc
          conda version : 24.9.1
    conda-build version : not installed
         python version : 3.12.6.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=skylake
                          __conda=24.9.1=0
                          __cuda=12.6=0
                          __glibc=2.40=0
                          __linux=6.11.0=0
                          __unix=0=0
       base environment : /home/fferretti-iit.local/miniforge3  (writable)
      conda av data dir : /home/fferretti-iit.local/miniforge3/etc/conda
  conda av metadata url : None
           channel URLs : https://conda.anaconda.org/conda-forge/linux-64
                          https://conda.anaconda.org/conda-forge/noarch
          package cache : /home/fferretti-iit.local/miniforge3/pkgs
                          /home/fferretti-iit.local/.conda/pkgs
       envs directories : /home/fferretti-iit.local/miniforge3/envs
                          /home/fferretti-iit.local/.conda/envs
               platform : linux-64
             user-agent : conda/24.9.1 requests/2.32.3 CPython/3.12.6 Linux/6.11.0-8-generic ubuntu/24.10 glibc/2.40 solver/libmamba conda-libmamba-solver/24.7.0 libmambapy/1.5.9
                UID:GID : 19111:50055
             netrc file : None
           offline mode : False
@flferretti flferretti added the bug Something isn't working label Oct 17, 2024
@diegoferigo

diegoferigo commented Oct 17, 2024

Since this is not the first time that libraries turn out to be missing, what about adding the following line to the recipe's test commands?

python -c "import jax; jax._src.xla_bridge._check_cuda_versions(debug=True)"

I'm not sure whether it returns a non-zero exit code on failure, or whether the output has to be filtered for raised exceptions.

This should probably be done in the jax recipe rather than in the jaxlib recipe, unless there is an easy way to extract the jaxlib logic that jax uses internally to produce that output.

Edit: https://github.com/jax-ml/jax/blob/1b5cf5a49442f206a25efef238c5623f75563c2b/jax/_src/xla_bridge.py#L289
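A minimal sketch of such a test-command wrapper in Python, assuming it runs with the environment's own interpreter. Because the check raises an uncaught RuntimeError when a component cannot be loaded (as in the traceback in the original report), the interpreter should exit with a non-zero status, so the exit code alone ought to be enough; the output is also scanned for "RuntimeError" as a defensive fallback. The helper name run_check is hypothetical, and jax._src.xla_bridge._check_cuda_versions is a private JAX function that may change between releases.

```python
import subprocess
import sys

# The command suggested above; _check_cuda_versions is private JAX API.
CHECK_CMD = [
    sys.executable,
    "-c",
    "import jax; jax._src.xla_bridge._check_cuda_versions(debug=True)",
]


def run_check(cmd=CHECK_CMD):
    """Run a check command and return (passed, combined_output).

    An uncaught Python exception makes the interpreter exit with status 1,
    so a non-zero return code already signals a failed check. As a fallback,
    the combined stdout/stderr is also searched for "RuntimeError".
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    output = result.stdout + result.stderr
    passed = result.returncode == 0 and "RuntimeError" not in output
    return passed, output
```

In a recipe test, the script would then fail whenever run_check() returns False, for example by ending with sys.exit(0 if ok else 1).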

@traversaro
Contributor

That is doable, but I guess it requires adding jaxlib or jax to https://github.com/Quansight/open-gpu-server .


@taozuoqiao

The CUDA build of jaxlib also fails to find the GPU after upgrading from 0.4.31 to 0.4.34. The environment was installed with:

mamba install jaxlib=*=*cuda120* jax=0.4.34 cuda-nvcc -c conda-forge -c nvidia

The output of python -c "import jax; jax.print_environment_info()" is:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.2
python: 3.10.15 | packaged by conda-forge | (main, Oct 16 2024, 01:24:24) [GCC 13.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='jw168', release='3.10.0-1160.118.1.el7.x86_64', version='#1 SMP Wed Apr 24 16:01:50 UTC 2024', machine='x86_64')


$ nvidia-smi
Mon Oct 21 12:01:49 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 44%   25C    P8             20W /  450W |       9MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00000000:24:00.0 Off |                  Off |
| 44%   46C    P2            107W /  450W |    2750MiB /  24564MiB |     31%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        Off |   00000000:41:00.0 Off |                  Off |
| 44%   25C    P8             15W /  450W |      13MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        Off |   00000000:61:00.0 Off |                  Off |
| 43%   25C    P8             22W /  450W |       7MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        Off |   00000000:81:00.0 Off |                  Off |
| 44%   37C    P2             99W /  450W |    8380MiB /  24564MiB |     33%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        Off |   00000000:A1:00.0 Off |                  Off |
| 44%   26C    P8             21W /  450W |       3MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA GeForce RTX 4090        Off |   00000000:C1:00.0 Off |                  Off |
| 43%   26C    P8             17W /  450W |       3MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA GeForce RTX 4090        Off |   00000000:E1:00.0 Off |                  Off |
| 44%   29C    P2             44W /  450W |    1118MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    1   N/A  N/A     59868      C   python                                       2722MiB |
|    4   N/A  N/A     63101      C   python                                        924MiB |
|    4   N/A  N/A     63123      C   python                                       1074MiB |
|    4   N/A  N/A     63372      C   python                                       1172MiB |
+-----------------------------------------------------------------------------------------+

All dependencies are installed, as shown by python -c "import jax;jax._src.xla_bridge._check_cuda_versions(debug=True)":

CUDA components status (debug):
Package: CUDA
Version JAX was built against: 12000
Minimum supported: 12010
Installed version: 12040
--------------------------------------------------
Package: cuDNN
Version JAX was built against: 90300
Minimum supported: 9100
Installed version: 90300
--------------------------------------------------
Package: cuFFT
Version JAX was built against: 11000
Minimum supported: 110
Installed version: 11200
--------------------------------------------------
Package: cuSOLVER
Version JAX was built against: 11402
Minimum supported: 11400
Installed version: 11600
--------------------------------------------------
Package: cuPTI
Version JAX was built against: 18
Minimum supported: 18
Installed version: 22
--------------------------------------------------
Package: cuBLAS
Version JAX was built against: 120001
Minimum supported: 120100
Installed version: 120402
--------------------------------------------------
Package: cuSPARSE
Version JAX was built against: 12000
Minimum supported: 12100
Installed version: 12300
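The CUDA row of this debug output reports versions as single integers. Assuming they follow the usual CUDA runtime/driver API convention of 1000 * major + 10 * minor, the installed 12040 decodes to 12.4, which matches the "CUDA Version: 12.4" printed by nvidia-smi above (and 12060 in the first report matches its 12.6). The sketch below decodes only the CUDA row; the other libraries (cuDNN, cuBLAS, cuPTI, ...) use their own encodings, so it should not be applied to them. The function names are illustrative.

```python
from typing import Tuple


def decode_cuda_version(v: int) -> Tuple[int, int]:
    """Split a CUDA version integer encoded as 1000 * major + 10 * minor."""
    major, rest = divmod(v, 1000)
    return major, rest // 10


def format_cuda_version(v: int) -> str:
    """Render an encoded CUDA version as "major.minor"."""
    major, minor = decode_cuda_version(v)
    return f"{major}.{minor}"


# CUDA row of the debug output in the second report.
for label, value in [("built against", 12000), ("minimum supported", 12010), ("installed", 12040)]:
    print(f"CUDA {label}: {format_cuda_version(value)}")
```

Read this way, the row says the wheel was built against CUDA 12.0, requires at least 12.1, and found 12.4.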

@flferretti
Author

@conda-forge/jaxlib, do you perhaps have any update on this? Thanks!

@hanbin973

hanbin973 commented Oct 31, 2024

Here is what I found: conda-forge/jax-feedstock#162 (comment)

The conda-forge jaxlib package currently lacks a shared library that connects JAX to XLA.

@njzjz
Member

njzjz commented Nov 18, 2024

The JAX documentation says:

There are two ways to build jaxlib with CUDA support: (1) use python build/build.py --enable_cuda to generate a jaxlib wheel with cuda support, or (2) use python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and jax-cuda-pjrt).

However, I guess the JAX team just uses option (2) but never tests option (1)...
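To see which of the two layouts a given environment follows, one could query the installed Python distributions. A minimal sketch, assuming the option (2) wheels for CUDA 12 are published as jax-cuda12-plugin and jax-cuda12-pjrt (the quoted documentation only says jax-cuda-plugin and jax-cuda-pjrt, so these exact names are an assumption). The helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed distribution names for the option (2) CUDA 12 wheels.
PLUGIN_NAMES = ("jax-cuda12-plugin", "jax-cuda12-pjrt")
DEFAULT_NAMES = ("jax", "jaxlib") + PLUGIN_NAMES


def installed_versions(names: Iterable[str] = DEFAULT_NAMES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def uses_plugin_layout(found: Dict[str, Optional[str]]) -> bool:
    """True only if both plugin wheels from option (2) are present."""
    return all(found.get(name) is not None for name in PLUGIN_NAMES)
```

This only sees packages that ship Python distribution metadata, which may not hold for every conda package.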

@traversaro
Contributor

Yes, I briefly tried to investigate the issue in the past weeks, and I got the same impression. Since it does not make much sense anyway to use a build method that is not aligned with upstream, we could just follow the new instructions and see if that works.

@traversaro
Contributor

For reference, the documentation mentioned by @njzjz is https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source .

6 participants