Mark cuda builds of jaxlib 0.4.30 and 0.4.31 as broken due to jax runtime check on cudnn version #1050
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
jax==0.4.31
has a check at runtime (that is not reflected on cuda metadata) on cudnn>=9, and so trying to usecuda
onjax==0.4.31
result in an error like:See conda-forge/jaxlib-feedstock#277 and conda-forge/jax-feedstock#149 .
As the only possible combinations of
jaxlib
andjax
that can be installed together (given their mutual constrained dependencies) is:I think it make sense to mark as broken the cuda builds of
jaxlib==0.4.31
andjaxlib==0.4.30
as they have the constraint of installingcudnn 8.*
, and they can't be used withjax==0.4.31
.This will ensure that users that install
jaxlib==*=*cuda*
to get a cuda-enabled jax actually get a working jax+cuda. This will continue to install a cpu-only version of jax if one just installsconda install jax
, to actually fix that we need instead to rebuild jaxlib with cudnn==9, see conda-forge/conda-forge-pinning-feedstock#6310 .For more details, see the discussion in conda-forge/jax-feedstock#149 (comment) .
As we were discussing on this with @ngam, it would be great to have his feedback before merging.
ping @conda-forge/jaxlib @conda-forge/jax
Checklist: