Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cuda] Bump nvidia-cuda-nvcc-cu12 dependency to 12.6.85 #25091

Merged

Conversation

gspschmid
Copy link
Contributor

Bumps the minimum version for nvidia-cuda-nvcc-cu12 to 12.6.85, which is the earliest published version that of that wheel incorporating CUDA 12.6.3.

This resolves the issue underlying #24438. Here's a simplified reproducer:

$ docker run -ti --gpus all ubuntu
$ nvidia-smi
(... H100)

$ apt-get update
$ apt-get install -y python3-pip
$ export PIP_BREAK_SYSTEM_PACKAGES=1

# Install JAX nightly (picks up recent nvidia-cuda-nvcc-cu12 version, e.g. 12.6.85)
$ pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

# => GOOD
$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]

# Force-install bad version
$ pip install --force-reinstall "nvidia-cuda-nvcc-cu12==12.6.77"

# => BAD
$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[0 0 0 0 3 6]

# Upgrade nvcc dependency
$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"

# => GOOD
$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]

As a follow-up we might also want to add a warning asking users with an affected version of nvidia-cuda-nvcc-cu12 to upgrade (e.g., via pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85").

@hawkinsp Any opinions on whether to have that warning / where it should go? As to how to check for the current ptxas version (https://github.com/jax-ml/jax/pull/24438/files#r1818345222), I suppose we might expose that alongside other CUDA versions here (

NB_MODULE(_versions, m) {
)? A cruder alternative approach might be to check for the nvidia-cuda-nvcc-cu12 wheel's __version__, if present.

@hawkinsp
Copy link
Collaborator

I think the best place for the warning would be in XLA: it should warn if it detects a ptxas that is known to be buggy.

@gspschmid
Copy link
Contributor Author

Corresponding XLA PR that emits a warning: openxla/xla#19927

@copybara-service copybara-service bot merged commit 6d4278d into jax-ml:main Nov 29, 2024
16 of 17 checks passed
@gspschmid gspschmid deleted the gschmid/nvidia-cuda-nvcc-cu12_12-6-85 branch November 29, 2024 12:06
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c99b76aa177916aa2e5106475de416cc5b by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c99b76aa177916aa2e5106475de416cc5b
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c99b76aa177916aa2e5106475de416cc5b by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c99b76aa177916aa2e5106475de416cc5b
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b0 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b0
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b03e8c6c1afc957268f1eefec0e10c5df78 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b03e8c6c1afc957268f1eefec0e10c5df78
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b0 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b0
PiperOrigin-RevId: 701287778
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b03e8c6c1afc957268f1eefec0e10c5df78 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b03e8c6c1afc957268f1eefec0e10c5df78
PiperOrigin-RevId: 701287778
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b0 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b0
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b03e8c6c1afc957268f1eefec0e10c5df78 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b03e8c6c1afc957268f1eefec0e10c5df78
PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
d32e9b0 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b0
PiperOrigin-RevId: 701246876
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants