Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #19927: [cuda] Warn about ptxas versions before CUDA 12.6.3
Imported from GitHub PR #19927 This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade. CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer. Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)): - nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154)) - nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84)) - ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263)) (As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.) **Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning. The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default. --- Example: ``` # A JAX-Toolbox image affected $ docker run -it --gpus=all jax:jax-2024-11-25 $ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))" E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping. Please upgrade to CUDA 12.6.3 or newer. [0 0 0 0 3 6] $ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85" (...) $ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))" [9 6 3 0 3 6] ``` --- On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant. Copybara import of the project: -- eeae01c by Georg Stefan Schmid <[email protected]>: [cuda] Warn about ptxas versions before CUDA 12.6.3 Merging this change closes #19927 FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c PiperOrigin-RevId: 701246876
- Loading branch information