JAX and TORCH #18032
That's correct. The current releases of PyTorch and JAX have incompatible CUDA version dependencies. I reported this issue to the PyTorch developers a while back, but there has been no interest in relaxing their CUDA version dependencies. My recommendations:
Does that resolve your problem? Hope that helps! |
This is quite annoying (and inconvenient) now that people have written torch2jax, which allows GPU-accelerated interaction between the two: https://github.com/samuela/torch2jax |
Hi @ywsslr, I've been experimenting with using Torch and JAX simultaneously for a while. I'm currently working in a Docker container in which both work on the GPU. JAX was installed according to the official documentation as:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Here is the Conda YAML of the environment. There are probably some extra packages, but I hope this helps:

name: base
channels:
- nvidia
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- absl-py=1.4.0=py310h06a4308_0
- alsa-lib=1.2.10=hd590300_0
- appdirs=1.4.4=pyhd3eb1b0_0
- asttokens=2.0.5=pyhd3eb1b0_0
- attr=2.5.1=h166bdaf_1
- backcall=0.2.0=pyhd3eb1b0_0
- binutils=2.40=hdd6e379_0
- binutils_impl_linux-64=2.40=hf600244_0
- binutils_linux-64=2.40=hbdbef99_2
- blas=1.0=openblas
- boltons=23.0.0=pyhd8ed1ab_0
- brotli=1.0.9=he6710b0_2
- brotli-python=1.1.0=py310hc6cd4ac_0
- bzip2=1.0.8=h7f98852_4
- c-ares=1.19.1=hd590300_0
- c-compiler=1.6.0=hd590300_0
- ca-certificates=2023.7.22=hbcca054_0
- cairo=1.16.0=hb05425b_5
- certifi=2023.7.22=pyhd8ed1ab_0
- cffi=1.15.1=py310h255011f_3
- charset-normalizer=3.2.0=pyhd8ed1ab_0
- chex=0.1.5=py310h06a4308_0
- click=8.0.4=py310h06a4308_0
- colorama=0.4.6=pyhd8ed1ab_0
- coloredlogs=15.0.1=py310h06a4308_1
- compilers=1.6.0=ha770c72_0
- conda=23.3.1=py310hff52083_0
- conda-package-handling=2.2.0=pyh38be061_0
- conda-package-streaming=0.9.0=pyhd8ed1ab_0
- contourpy=1.0.5=py310hdb19cb5_0
- cryptography=41.0.3=py310h75e40e8_0
- cuda-nvcc=11.3.58=h2467b9f_0
- cuda-version=11.8=h70ddcb2_2
- cudatoolkit=11.8.0=h4ba93d1_12
- cudnn=8.9.2.26=cuda11_0
- cupti=11.8.0=he078b1a_0
- cxx-compiler=1.6.0=h00ab1b0_0
- cycler=0.11.0=pyhd3eb1b0_0
- dbus=1.13.18=hb2f20db_0
- deap=1.4.1=py310h7cbd5c2_0
- decorator=5.1.1=pyhd3eb1b0_0
- dm-tree=0.1.7=py310h6a678d5_1
- docker-pycreds=0.4.0=pyhd3eb1b0_0
- docstring_parser=0.15=pyhd8ed1ab_0
- exceptiongroup=1.0.4=py310h06a4308_0
- executing=0.8.3=pyhd3eb1b0_0
- expat=2.5.0=h6a678d5_0
- filelock=3.9.0=py310h06a4308_0
- flax=0.6.1=pyhd8ed1ab_1
- fmt=9.1.0=h924138e_0
- font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
- font-ttf-inconsolata=2.001=hcb22688_0
- font-ttf-source-code-pro=2.030=hd3eb1b0_0
- font-ttf-ubuntu=0.83=h8b1ccd4_0
- fontconfig=2.14.2=h14ed4e7_0
- fonts-anaconda=1=h8fa9717_0
- fonts-conda-ecosystem=1=hd3eb1b0_0
- fonttools=4.25.0=pyhd3eb1b0_0
- fortran-compiler=1.6.0=heb67821_0
- freetype=2.12.1=h4a9f257_0
- frozendict=2.3.8=py310h2372a71_0
- gcc=12.3.0=h8d2909c_2
- gcc_impl_linux-64=12.3.0=he2b93b0_1
- gcc_linux-64=12.3.0=h76fc315_2
- gettext=0.21.1=h27087fc_0
- gfortran=12.3.0=h499e0f7_2
- gfortran_impl_linux-64=12.3.0=hfcedea8_1
- gfortran_linux-64=12.3.0=h7fe76b4_2
- gitdb=4.0.7=pyhd3eb1b0_0
- gitpython=3.1.30=py310h06a4308_0
- glib=2.78.0=hfc55251_0
- glib-tools=2.78.0=hfc55251_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py310heeb90bb_0
- graphite2=1.3.14=h295c915_1
- gst-plugins-base=1.22.5=h8e1006c_1
- gstreamer=1.22.5=h98fc4e7_1
- gxx=12.3.0=h8d2909c_2
- gxx_impl_linux-64=12.3.0=he2b93b0_1
- gxx_linux-64=12.3.0=h8a814eb_2
- harfbuzz=8.2.0=h3d44ed6_0
- humanfriendly=10.0=py310h06a4308_1
- icu=73.2=h59595ed_0
- idna=3.4=pyhd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46305
- ipython=8.15.0=py310h06a4308_0
- jax-dataclasses=1.5.1=pyhd8ed1ab_0
- jaxlie=1.3.3=pyhd8ed1ab_0
- jedi=0.18.1=py310h06a4308_1
- jinja2=3.1.2=py310h06a4308_0
- jsonpatch=1.32=pyhd8ed1ab_0
- jsonpointer=2.4=py310hff52083_0
- kernel-headers_linux-64=2.6.32=he073ed8_16
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.4=py310h6a678d5_0
- krb5=1.21.2=h659d440_0
- lame=3.100=h7b6447c_0
- lcms2=2.15=h7f713cb_2
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libarchive=3.6.2=h039dbb9_1
- libcap=2.69=h0f662aa_0
- libclang=15.0.7=default_h7634d5b_3
- libclang13=15.0.7=default_h9986a30_3
- libcups=2.3.3=h4637d8d_4
- libcurl=8.3.0=hca28451_0
- libdeflate=1.19=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=h516909a_1
- libevent=2.1.12=hdbd6064_1
- libexpat=2.5.0=hcb278e6_1
- libffi=3.4.2=h7f98852_5
- libflac=1.4.3=h59595ed_0
- libgcc-devel_linux-64=12.3.0=h8bca6fd_1
- libgcc-ng=13.2.0=h807b86a_0
- libgcrypt=1.10.1=h166bdaf_0
- libgfortran-ng=13.2.0=h69a702a_1
- libgfortran5=13.2.0=ha4646dd_1
- libglib=2.78.0=hebfc3b9_0
- libgomp=13.2.0=h807b86a_0
- libgpg-error=1.47=h71f35ed_0
- libiconv=1.17=h166bdaf_0
- libjpeg-turbo=2.1.5.1=hd590300_1
- libllvm15=15.0.7=h5cf9203_3
- libmamba=1.2.0=hcea66bb_0
- libmambapy=1.2.0=py310h1428755_0
- libnghttp2=1.52.0=h61bc06f_0
- libnsl=2.0.0=h7f98852_0
- libogg=1.3.5=h27cfd23_1
- libopenblas=0.3.21=h043d6bf_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.39=h5eee18b_0
- libpq=15.4=hfc447b1_0
- libprotobuf=3.20.3=he621ea3_0
- libsanitizer=12.3.0=h0f45ef3_1
- libsndfile=1.2.2=hbc2eb40_0
- libsolv=0.7.24=hfc55251_4
- libsqlite=3.43.0=h2797004_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-devel_linux-64=12.3.0=h8bca6fd_1
- libstdcxx-ng=13.2.0=h7e041cc_0
- libsystemd0=254=h3516f8a_0
- libtiff=4.6.0=h29866fb_1
- libuuid=2.38.1=h0b41bf4_0
- libvorbis=1.3.7=h7b6447c_0
- libwebp-base=1.3.2=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.5.0=h5d7e998_3
- libxml2=2.11.5=h232c23b_1
- libzlib=1.2.13=hd590300_5
- lz4-c=1.9.4=hcb278e6_0
- lzo=2.10=h516909a_1000
- magma=2.7.1=h2c23e93_0
- mamba=1.2.0=py310h51d5547_0
- markdown-it-py=2.2.0=py310h06a4308_1
- markupsafe=2.1.1=py310h7f8727e_0
- mashumaro=3.6=py310h06a4308_0
- matplotlib=3.7.2=py310h06a4308_0
- matplotlib-base=3.7.2=py310h1128e8f_0
- matplotlib-inline=0.1.6=py310h06a4308_0
- mdurl=0.1.0=py310h06a4308_0
- mkl=2023.1.0=h213fc3f_46343
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpg123=1.31.3=hcb278e6_0
- mpmath=1.3.0=py310h06a4308_0
- msgpack-python=1.0.3=py310hd09550d_0
- munkres=1.1.4=py_0
- mysql-common=8.0.33=hf1915f5_4
- mysql-libs=8.0.33=hca2cd23_4
- ncurses=6.4=hcb278e6_0
- networkx=3.1=py310h06a4308_0
- ninja=1.10.2=h06a4308_5
- ninja-base=1.10.2=hd09550d_5
- nspr=4.35=h6a678d5_0
- nss=3.92=h1d7d5a4_0
- numpy=1.25.2=py310heeff2f4_0
- numpy-base=1.25.2=py310h8a23956_0
- openjpeg=2.5.0=h488ebb8_3
- openssl=3.1.2=hd590300_0
- opt_einsum=3.3.0=pyhd3eb1b0_1
- optax=0.1.4=py310h06a4308_0
- overrides=7.4.0=pyhd8ed1ab_0
- packaging=23.1=pyhd8ed1ab_0
- parso=0.8.3=pyhd3eb1b0_0
- pathtools=0.1.2=pyhd3eb1b0_1
- pcre2=10.40=hc3806b6_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=10.0.1=py310h29da1c1_0
- pip=23.2.1=pyhd8ed1ab_0
- pixman=0.40.0=h7f8727e_1
- pluggy=1.3.0=pyhd8ed1ab_0
- ply=3.11=py310h06a4308_0
- pptree=3.1=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=py310h06a4308_0
- protobuf=3.20.3=py310h6a678d5_0
- psutil=5.9.0=py310h5eee18b_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pulseaudio-client=16.1=hb77b528_5
- pure_eval=0.2.2=pyhd3eb1b0_0
- pybind11-abi=4=hd8ed1ab_3
- pycosat=0.6.4=py310h5764c6d_1
- pycparser=2.21=pyhd8ed1ab_0
- pygments=2.15.1=py310h06a4308_1
- pyopenssl=23.2.0=pyhd8ed1ab_1
- pyparsing=3.0.9=py310h06a4308_0
- pyqt=5.15.9=py310h04931ad_4
- pyqt5-sip=12.12.2=py310hc6cd4ac_4
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.8=h4a9ceb5_0_cpython
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python_abi=3.10=3_cp310
- pytorch=2.0.1=gpu_cuda118py310h7799f5a_0
- pyyaml=6.0=py310h5eee18b_1
- qt-main=5.15.8=hc47bfe8_16
- readline=8.2=h8228510_1
- reproc=14.2.4=h0b41bf4_0
- reproc-cpp=14.2.4=hcb278e6_0
- requests=2.31.0=pyhd8ed1ab_0
- rich=13.3.5=py310h06a4308_0
- ruamel.yaml=0.17.32=py310h2372a71_0
- ruamel.yaml.clib=0.2.7=py310h1fa729e_1
- scipy=1.11.1=py310heeff2f4_0
- sentry-sdk=1.9.0=py310h06a4308_0
- setproctitle=1.2.2=py310h7f8727e_0
- setuptools=68.2.2=pyhd8ed1ab_0
- shtab=1.6.4=pyhd8ed1ab_1
- sip=6.7.11=py310hc6cd4ac_0
- six=1.16.0=pyhd3eb1b0_1
- smmap=4.0.0=pyhd3eb1b0_0
- stack_data=0.2.0=pyhd3eb1b0_0
- sympy=1.11.1=py310h06a4308_0
- sysroot_linux-64=2.12=he073ed8_16
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.12=h27826a3_0
- toml=0.10.2=pyhd3eb1b0_0
- tomli=2.0.1=py310h06a4308_0
- toolz=0.12.0=pyhd8ed1ab_0
- tornado=6.3.2=py310h5eee18b_0
- tqdm=4.66.1=pyhd8ed1ab_0
- traitlets=5.7.1=py310h06a4308_0
- typing-extensions=4.7.1=py310h06a4308_0
- typing_extensions=4.7.1=py310h06a4308_0
- typing_utils=0.1.0=pyhd8ed1ab_0
- tyro=0.5.7=pyhd8ed1ab_0
- tzdata=2023c=h71feb2d_0
- urllib3=2.0.4=pyhd8ed1ab_0
- wandb=0.15.10=pyhd8ed1ab_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.41.2=pyhd8ed1ab_0
- xcb-util=0.4.0=hd590300_1
- xcb-util-image=0.4.0=h8ee46fc_1
- xcb-util-keysyms=0.4.0=h8ee46fc_1
- xcb-util-renderutil=0.3.9=hd590300_1
- xcb-util-wm=0.4.1=h8ee46fc_1
- xkeyboard-config=2.39=hd590300_0
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libice=1.1.1=hd590300_0
- xorg-libsm=1.2.4=h7391055_0
- xorg-libx11=1.8.6=h8ee46fc_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxext=1.3.4=h0b41bf4_2
- xorg-libxrender=0.9.11=hd590300_0
- xorg-renderproto=0.11.1=h7f98852_1002
- xorg-xextproto=7.3.0=h0b41bf4_1003
- xorg-xf86vidmodeproto=2.3.1=h7f98852_1002
- xorg-xproto=7.0.31=h27cfd23_1007
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7b6447c_0
- yaml-cpp=0.7.0=h27087fc_2
- zlib=1.2.13=hd590300_5
- zstandard=0.19.0=py310h5764c6d_0
- zstd=1.5.5=hfc55251_0
- pip:
    - jax==0.4.18
    - jaxlib==0.4.18+cuda12.cudnn89
    - ml-dtypes==0.3.1
    - nvidia-cublas-cu12==12.2.5.6
    - nvidia-cuda-cupti-cu12==12.2.142
    - nvidia-cuda-nvcc-cu12==12.2.140
    - nvidia-cuda-nvrtc-cu12==12.2.140
    - nvidia-cuda-runtime-cu12==12.2.140
    - nvidia-cudnn-cu12==8.9.4.25
    - nvidia-cufft-cu12==11.0.8.103
    - nvidia-cusolver-cu12==11.5.2.141
    - nvidia-cusparse-cu12==12.1.2.141
    - nvidia-nccl-cu12==2.18.3
    - nvidia-nvjitlink-cu12==12.2.140
prefix: /conda
|
Thank you all for your help. For some reason I can't try it right now, but I'll try it soon and reply. |
OK people, this has been a one-day nightmare, but I finally got this to work on an H100 machine with CUDA 12.2, without sudo.
Then install PyTorch from source as that post says, and voilà! |
No promises, but informally we're going to try to ensure that at least one JAX release has a variant built for a CUDA version that PyTorch also releases for. Right now, that's the CUDA 11.8 release of JAX. It's not a guarantee, though; for some JAX and PyTorch versions there might be no intersecting CUDA version. |
I hit a similar issue when installing pytorch and jax into the same conda environment: when torch is imported first, JAX's CUDA backend fails to initialize.

A short summary of the diagnosis: it turns out that torch is built against cuDNN 8.7 while jaxlib is built against cuDNN 8.8, which leads to an exception when JAX initializes its CUDA backend.

Here is a reproducer:

mamba create -n test-pytorch-jax pytorch pytorch-cuda=11.8 jaxlib=*=*cuda118* jax -c pytorch -c nvidia --no-channel-priority -y
mamba activate test-pytorch-jax

(Note: using strict channel priority would lead to a mamba solver problem.)

Import torch before checking jax.devices:

>>> import torch
>>> import jax
>>> jax.devices()
CUDA backend failed to initialize: Found cuDNN version 8700, but JAX was built against version 8800, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

Import torch after checking jax.devices:

>>> import jax
>>> jax.devices()
[cuda(id=0), cuda(id=1)]
>>> import torch
>>> jax.__version__
'0.4.23'
>>> torch.__version__
'2.1.2'
>>> from torch._C import _cudnn
>>> _cudnn.getCompileVersion()
(8, 7, 0)

Notice that the result of

>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> import torch
>>> jaxlib.cuda._versions.cudnn_get_version()
8902

vs

>>> import torch
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8700

qualifies as an incompatible linkage issue: since libcudnn is dynamically loaded, cudnnGetVersion ought to return the version of the loaded library, not the version of the library that a piece of software was built against. The behavior above suggests that torch was linked with libcudnn statically.

A possible resolution: cuDNN minor releases are backward compatible with applications built against the same or an earlier minor release. Hence, as long as jaxlib and torch are built against libcudnn with the same major version (8), the JAX version check ought to ignore cuDNN minor versions. Here is a patch:
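A minimal Python sketch of the relaxed check proposed above, shown next to the stricter "at least as new" rule quoted in the error message. The function names are hypothetical (this is neither JAX's actual check nor the patch mentioned above), and it assumes the cuDNN 8.x encoding major*1000 + minor*100 + patch, which matches the 8700, 8800 and 8902 values printed above.

```python
def decode_cudnn_version(v: int) -> tuple[int, int, int]:
    """Decode a cuDNN 8.x version integer (assumed: major*1000 + minor*100 + patch)."""
    major, rest = divmod(v, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def strict_check(loaded: int, built: int) -> bool:
    """Rule from the error message: loaded cuDNN must be at least as new as the build version."""
    return decode_cudnn_version(loaded) >= decode_cudnn_version(built)


def relaxed_check(loaded: int, built: int) -> bool:
    """Proposed rule (hypothetical helper): only the major versions must match."""
    return decode_cudnn_version(loaded)[0] == decode_cudnn_version(built)[0]


if __name__ == "__main__":
    loaded, built = 8700, 8800  # values from the error message above
    print(decode_cudnn_version(loaded), decode_cudnn_version(built))  # (8, 7, 0) (8, 8, 0)
    print("strict:", strict_check(loaded, built))    # strict: False
    print("relaxed:", relaxed_check(loaded, built))  # relaxed: True
```

Note that the relaxed rule accepts a loaded 8.7 library for a build against 8.8, which falls outside the stated guarantee (same or earlier minor release); that is the trade-off the proposal makes.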
|
The latest version pair I could find that were compatible with each other were

One way to check this would be:

cat > requirements.in <<EOF
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url=https://download.pytorch.org/whl
jax[cuda11_pip]
torch==2.2.1+cu118
EOF
pip-compile
# Check the contents of requirements.txt. |
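What to look for in the generated requirements.txt: pip-installed CUDA libraries carry a -cu11 or -cu12 suffix (for example nvidia-cublas-cu12 in the environment earlier in this thread), so pins from both CUDA majors suggest the two frameworks are likely pulling different CUDA stacks. The sketch below is a hypothetical helper, not part of pip-tools; it only pattern-matches package names.

```python
import re
from collections import defaultdict

# Matches pins such as "nvidia-cublas-cu12==12.2.5.6"; group 1 is the package
# stem, group 2 the CUDA major version taken from the -cuXX suffix.
NVIDIA_PIN = re.compile(r"^(nvidia-[a-z0-9-]+?)-cu(\d+)\b", re.IGNORECASE)


def cuda_majors_in_requirements(text: str) -> dict[str, list[str]]:
    """Group nvidia-*-cuXX packages in a requirements file by CUDA major version."""
    found: defaultdict[str, list[str]] = defaultdict(list)
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop pip-compile "# via ..." comments
        match = NVIDIA_PIN.match(line)
        if match:
            found[match.group(2)].append(match.group(1))
    return dict(found)


def main(path: str = "requirements.txt") -> None:
    try:
        with open(path) as f:
            majors = cuda_majors_in_requirements(f.read())
    except OSError:
        print(f"could not read {path}; run pip-compile first")
        return
    for major, packages in sorted(majors.items()):
        print(f"cu{major}: {', '.join(packages)}")
    if len(majors) > 1:
        print("warning: packages for more than one CUDA major version are pinned")


if __name__ == "__main__":
    main()
```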
A workaround that works better for us is to use CUDA 11 with Jax, but CUDA 12 with Torch. So basically |
How did you get this to work? I'm using conda, but after installing
|
We did not have to do anything special. Just installed the two packages in a clean env, and both worked. |
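A quick way to confirm that both packages see the GPU in a fresh environment. This is a sketch: the helper name is hypothetical, and it imports JAX and calls jax.devices() before importing torch, mirroring the ordering that worked in the cuDNN reproducer earlier in this thread. That ordering is a workaround for that specific mismatch, not a fix.

```python
import importlib.util


def check_gpu_visibility() -> dict[str, object]:
    """Report the devices JAX sees and whether PyTorch sees CUDA (None if not installed)."""
    report: dict[str, object] = {"jax": None, "torch": None}
    if importlib.util.find_spec("jax") is not None:
        try:
            import jax  # imported and initialized before torch on purpose

            report["jax"] = [str(d) for d in jax.devices()]
        except ImportError:
            pass  # e.g. jax present but jaxlib missing
    if importlib.util.find_spec("torch") is not None:
        try:
            import torch

            report["torch"] = torch.cuda.is_available()
        except ImportError:
            pass
    return report


if __name__ == "__main__":
    print(check_gpu_visibility())
```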
The only way I was able to solve the environment with both JAX and PyTorch on CUDA 12 was to install some packages from the nvidia channel:

mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch

>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)

conda list
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
_sysroot_linux-64_curr_repodata_hack 3 h69a702a_14 conda-forge
binutils_impl_linux-64 2.40 hf600244_0 conda-forge
binutils_linux-64 2.40 hdade7a5_3 conda-forge
blas 2.116 mkl conda-forge
blas-devel 3.9.0 16_linux64_mkl conda-forge
bzip2 1.0.8 hd590300_5 conda-forge
c-ares 1.27.0 hd590300_0 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
cuda-cccl_linux-64 12.1.109 ha770c72_0 conda-forge
cuda-cudart 12.1.105 hd3aeb46_0 conda-forge
cuda-cudart-dev 12.1.105 hd3aeb46_0 conda-forge
cuda-cudart-dev_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-cudart-static 12.1.105 hd3aeb46_0 conda-forge
cuda-cudart-static_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-cudart_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-cupti 12.1.105 h59595ed_0 conda-forge
cuda-driver-dev_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-libraries 12.1.0 0 nvidia
cuda-nvcc 12.1.105 hcdd1206_1 conda-forge
cuda-nvcc-dev_linux-64 12.1.105 ha770c72_0 conda-forge
cuda-nvcc-impl 12.1.105 hd3aeb46_0 conda-forge
cuda-nvcc-tools 12.1.105 hd3aeb46_0 conda-forge
cuda-nvcc_linux-64 12.1.105 h8a487aa_1 conda-forge
cuda-nvrtc 12.1.105 hd3aeb46_0 conda-forge
cuda-nvtx 12.1.105 h59595ed_0 conda-forge
cuda-opencl 12.1.105 h59595ed_0 conda-forge
cuda-runtime 12.1.0 0 nvidia
cuda-version 12.1 h1d6eff3_3 conda-forge
cudnn 8.9.7.29 h092f7fd_3 conda-forge
filelock 3.13.3 pyhd8ed1ab_0 conda-forge
gcc_impl_linux-64 12.3.0 he2b93b0_5 conda-forge
gcc_linux-64 12.3.0 h6477408_3 conda-forge
gxx_impl_linux-64 12.3.0 he2b93b0_5 conda-forge
gxx_linux-64 12.3.0 h4a1b8e8_3 conda-forge
icu 73.2 h59595ed_0 conda-forge
importlib-metadata 7.1.0 pyha770c72_0 conda-forge
importlib_metadata 7.1.0 hd8ed1ab_0 conda-forge
jax 0.4.25 pyhd8ed1ab_0 conda-forge
jaxlib 0.4.23 cuda120py312hc008a70_200 conda-forge
jinja2 3.1.3 pyhd8ed1ab_0 conda-forge
kernel-headers_linux-64 3.10.0 h4a8ded7_14 conda-forge
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
libabseil 20240116.1 cxx17_h59595ed_2 conda-forge
libblas 3.9.0 16_linux64_mkl conda-forge
libcblas 3.9.0 16_linux64_mkl conda-forge
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.6.1.9 hd3aeb46_0 conda-forge
libcurand 10.3.2.106 hd3aeb46_0 conda-forge
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libexpat 2.6.2 h59595ed_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-devel_linux-64 12.3.0 h8bca6fd_105 conda-forge
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgfortran-ng 13.2.0 h69a702a_5 conda-forge
libgfortran5 13.2.0 ha4646dd_5 conda-forge
libgomp 13.2.0 h807b86a_5 conda-forge
libgrpc 1.62.1 h15f2491_0 conda-forge
libhwloc 2.9.3 default_h554bfaf_1009 conda-forge
libiconv 1.17 hd590300_2 conda-forge
liblapack 3.9.0 16_linux64_mkl conda-forge
liblapacke 3.9.0 16_linux64_mkl conda-forge
libnpp 12.0.2.50 0 nvidia
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.1.105 hd3aeb46_0 conda-forge
libnvjpeg 12.1.1.14 0 nvidia
libprotobuf 4.25.3 h08a7969_0 conda-forge
libre2-11 2023.09.01 h5a48ba9_2 conda-forge
libsanitizer 12.3.0 h0f45ef3_5 conda-forge
libsqlite 3.45.2 h2797004_0 conda-forge
libstdcxx-devel_linux-64 12.3.0 h8bca6fd_105 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_5 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.6 h232c23b_1 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 15.0.7 h0cdce71_0 conda-forge
markupsafe 2.1.5 py312h98912ed_0 conda-forge
mkl 2022.1.0 h84fe81f_915 conda-forge
mkl-devel 2022.1.0 ha770c72_916 conda-forge
mkl-include 2022.1.0 h84fe81f_915 conda-forge
ml_dtypes 0.3.2 py312hfb8ada1_0 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
nccl 2.20.5.1 h3a97aeb_0 conda-forge
ncurses 6.4.20240210 h59595ed_0 conda-forge
networkx 3.2.1 pyhd8ed1ab_0 conda-forge
numpy 1.26.4 py312heda63a1_0 conda-forge
ocl-icd 2.3.2 hd590300_1 conda-forge
openssl 3.2.1 hd590300_1 conda-forge
opt-einsum 3.3.0 hd8ed1ab_2 conda-forge
opt_einsum 3.3.0 pyhc1e730c_2 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
python 3.12.2 hab00c5b_0_cpython conda-forge
python_abi 3.12 4_cp312 conda-forge
pytorch 2.2.1 py3.12_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pyyaml 6.0.1 py312h98912ed_1 conda-forge
re2 2023.09.01 h7f4b329_2 conda-forge
readline 8.2 h8228510_1 conda-forge
scipy 1.12.0 py312heda63a1_2 conda-forge
setuptools 69.2.0 pyhd8ed1ab_0 conda-forge
sympy 1.12 pyh04b8f61_3 conda-forge
sysroot_linux-64 2.17 h4a8ded7_14 conda-forge
tbb 2021.11.0 h00ab1b0_1 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
typing_extensions 4.10.0 pyha770c72_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
wheel 0.43.0 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge

fyi @traversaro |
FYI, at the moment it is not possible to get both jax and pytorch with CUDA 12 using only conda-forge dependencies, for this reason (I pinned several dependencies to get a clearer error):
Once a conda-forge pytorch version gets compiled with libprotobuf==4.25.3 (i.e. once conda-forge/pytorch-cpu-feedstock#228 is ready and merged; big thanks to the pytorch and jax conda-forge maintainers), it should be possible to install both jax and pytorch with CUDA 12 enabled using just conda-forge packages. |
JAX 0.4.26 relaxed our CUDA version dependencies so that the minimum CUDA version for JAX is 12.1. This is a version also supported by PyTorch. Try it out! We're going to try to make sure our supported version range overlaps with at least one PyTorch release. Note that we dropped support for CUDA 11. |
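To check whether a given PyTorch CUDA build falls inside JAX's supported range, a minimal sketch: it assumes the range is [12.1, 13), i.e. the 12.1 minimum stated above plus an upper bound at the next major version, which is an assumption rather than a documented limit. The example build versions (11.8 and 12.1) are the torch variants that appear earlier in this thread.

```python
def parse_version(v: str) -> tuple[int, ...]:
    """Turn "12.1" into (12, 1) so versions compare numerically, not as strings."""
    return tuple(int(part) for part in v.split("."))


# Assumed JAX range: the 12.1 minimum comes from the comment above; the "13"
# upper bound is an assumption, not a documented limit.
JAX_MIN, JAX_BELOW = "12.1", "13"


def jax_supports(cuda: str, minimum: str = JAX_MIN, below: str = JAX_BELOW) -> bool:
    """True if `cuda` lies in the half-open range [minimum, below)."""
    return parse_version(minimum) <= parse_version(cuda) < parse_version(below)


def overlapping(torch_builds: list[str]) -> list[str]:
    """Return the PyTorch CUDA build versions that also fall in JAX's assumed range."""
    return [v for v in torch_builds if jax_supports(v)]


if __name__ == "__main__":
    print(overlapping(["11.8", "12.1"]))  # prints ['12.1']
```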
After a bunch of fixes from both the jax and pytorch maintainers, now (late May 2024) it is possible to just install jax and pytorch from conda-forge on Linux, and out of the box they will work with GPU/CUDA support without the need for any other conda channel:
If for some reason this command does not install the CUDA-enabled jax, perhaps you are still using the classic conda solver; in that case you can force the installation of CUDA-enabled jax and pytorch with:
However, this is not necessary if you are using a recent conda install that no longer defaults to the classic solver.
|
Can someone please point out the correct versions needed to get both pytorch and jax with GPU support on CUDA 12 as of July 2024? |
@varadVaidya I follow this issue totally by chance, but in general you may have more success finding help through the official JAX help channels (see https://jax.readthedocs.io/en/latest/beginner_guide.html#finding-help) rather than by posting in closed issues. More on topic: I have no idea about pip/venv with CUDA, but for conda the procedure posted in #18032 (comment) is working fine for me (when I originally posted the message I forgot to add the

@eliseoe @bebark @shaikalthaf4 By chance I just noticed that you added a 👎🏽 reaction to my previous comment; any reason for doing so? Just FYI, authors do not get (at least by default) notifications for post reactions. |
@traversaro I found that running your command with |
Interestingly, in my system with:
the command
installs the cuda jax, but indeed:
installs CPU jax. Perhaps you are using an old conda version that uses the classic solver by default? (You can see this if you report the

However, even with the classic solver, forcing it to install the CUDA version of jaxlib and pytorch works as expected (even if the classic solver is much slower):
I edited the original comment accordingly. |
You are right, I'm using the classic solver:
|
Thanks for the solution. However, I have found a possible bug: with jax and jaxlib both at version 0.4.30, jax.numpy cannot initialize an array larger than (2, 52, 10), so I had to downgrade jax to 0.4.23, which works just fine. To be safe, the command could be something like
Python 3.12 is too new for some commonly used packages. |
Just a curiosity, are you actually getting any packages from the |
I'm not sure; maybe later I can do a test. Thanks for noticing. |
Not sure how you can end up with jax/jaxlib installed via PyPI if you just created the environment with |
In my case, the conflict comes from torch and jaxlib sticking to different cuDNN versions; previously I didn't turn to conda-forge to install a CUDA toolkit compatible with both torch and jaxlib. I used the pip command from the official JAX documentation, by the way.
Ok, but in that case it is probably a good idea not to install |
I think the only reason for the |
But conda has no idea which version of cuDNN the jaxlib installed via pip requires. If you want to install cuDNN (even a specific version) with conda, just install cudnn; to avoid problems, it is typically a good idea not to install jax or jaxlib via conda if you are installing them via pip.
You are right. I accidentally used the pip install, and it just happened to find a cuDNN version that meets the requirement, lol.
@traversaro apologies for having to revive this issue, but your solution does not work for me:
The error message suggests I need

System info: Operating System: Ubuntu 22.04.4 LTS
UPDATE: |
@lucascolley thanks for reporting the issue, can you please open an issue in https://github.com/conda-forge/jaxlib-feedstock and tag me there? Thanks! |
Thanks @lucascolley, indeed it seems a regression in the conda-forge jax package 0.4.31, I opened conda-forge/jaxlib-feedstock#277 to track the problem. |
Description
When I install only the latest jax with CUDA (pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html), I can use jax with the GPU. But when I later install torch (pip install torch), I can't use jax with the GPU: it reports that the CUDA or cuSOLVER version is older than the one jax requires. Why? Can an older jax version avoid this? If so, how can I pip install jax[cuda] with a matching version?
What jax/jaxlib version are you using?
jax-0.4.18 jaxlib-0.4.18+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
3.10.9/Linux
NVIDIA GPU info
No response