[cuda] Bump nvidia-cuda-nvcc-cu12 dependency to 12.6.85 #25091
Merged: copybara-service merged 1 commit into jax-ml:main from gspschmid:gschmid/nvidia-cuda-nvcc-cu12_12-6-85 on Nov 29, 2024
+1
−1
I think the best place for the warning would be in XLA: it should warn if it detects a ptxas that is known to be buggy.
hawkinsp approved these changes Nov 27, 2024
google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Nov 27, 2024
Corresponding XLA PR that emits a warning: openxla/xla#19927
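For users with an existing installation, a quick local check can tell whether the installed wheel already satisfies the bumped requirement. The following is a minimal sketch, not part of either PR: the helper names (`parse_version`, `installed_nvcc_wheel`, `needs_upgrade`) are hypothetical, and it only inspects the pip metadata of `nvidia-cuda-nvcc-cu12`, so it says nothing about a system-wide CUDA installation.

```python
from importlib import metadata

PACKAGE = "nvidia-cuda-nvcc-cu12"  # wheel named in the PR title
REQUIRED = (12, 6, 85)             # minimum version this PR bumps to


def parse_version(text):
    """Turn '12.6.85' into (12, 6, 85); non-numeric suffixes are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 3:
        parts.append(0)  # pad missing components, e.g. '12.6' -> (12, 6, 0)
    return tuple(parts)


def installed_nvcc_wheel():
    """Return the installed wheel version as a tuple, or None if not installed."""
    try:
        return parse_version(metadata.version(PACKAGE))
    except metadata.PackageNotFoundError:
        return None


def needs_upgrade(version):
    """True only if the wheel is installed and older than REQUIRED."""
    return version is not None and version < REQUIRED


if __name__ == "__main__":
    found = installed_nvcc_wheel()
    if found is None:
        print(f"{PACKAGE} is not installed via pip")
    elif needs_upgrade(found):
        print(f'{PACKAGE} {found} is too old; run: pip install -U "{PACKAGE}>=12.6.85"')
    else:
        print(f"{PACKAGE} {found} is new enough")
```

The version string is parsed by hand to keep the sketch free of third-party dependencies; a real script would more likely use a proper version parser.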
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:

```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25
$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping. Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]
$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)
$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.

Copybara import of the project:

--
eeae01c by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c
PiperOrigin-RevId: 701246876
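The version predicate described in this commit message, `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`, can be sketched in a few lines. This is an illustrative Python rendering under stated assumptions, not the C++ code added to XLA: the function names are hypothetical, and treating an unknown patch number as 0 is only one plausible reading of the nvJitLink caveat (major and minor available, patch not).

```python
SM90 = (9, 0)                # Hopper compute capability
FIRST_AFFECTED = (12, 0, 0)  # lower bound from the commit message
FIRST_FIXED = (12, 6, 85)    # first nvidia-cuda-nvcc wheel of CUDA 12.6.3


def is_affected(compute_capability, ptxas_version):
    """CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85."""
    return (compute_capability >= SM90
            and FIRST_AFFECTED <= ptxas_version < FIRST_FIXED)


def is_affected_without_patch(compute_capability, major, minor):
    """Variant for a backend that reports only major.minor.

    Assumption: the unknown patch number is treated as 0, so every 12.6
    release looks affected, including the fixed ones (the spurious warning
    mentioned in the caveat).
    """
    return is_affected(compute_capability, (major, minor, 0))


# Tuples compare lexicographically, which matches version ordering here.
print(is_affected((9, 0), (12, 6, 77)))          # affected ptxas on Hopper
print(is_affected((9, 0), (12, 6, 85)))          # fixed ptxas on Hopper
print(is_affected((8, 0), (12, 6, 77)))          # pre-Hopper GPU
print(is_affected_without_patch((9, 0), 12, 6))  # patch unknown
```

With these inputs the four calls print `True`, `False`, `False`, `True`; the last result is the spurious case, since a 12.6 toolkit at 12.6.3 or newer cannot be told apart from an affected one without the patch number.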
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927
Copybara import of the project: -- eeae01c99b76aa177916aa2e5106475de416cc5b by Georg Stefan Schmid <[email protected]>: [cuda] Warn about ptxas versions before CUDA 12.6.3 Merging this change closes #19927 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c99b76aa177916aa2e5106475de416cc5b PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 29, 2024
Imported from GitHub PR #19927
Copybara import of the project: -- d32e9b0 by Georg Stefan Schmid <[email protected]>: [cuda] Warn about ptxas versions before CUDA 12.6.3 Merging this change closes #19927 FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b0 PiperOrigin-RevId: 701246876
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 29, 2024
Imported from GitHub PR openxla/xla#19927
Copybara import of the project:

--
d32e9b03e8c6c1afc957268f1eefec0e10c5df78 by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#19927 from gspschmid:gschmid/ptxax-version-warn d32e9b03e8c6c1afc957268f1eefec0e10c5df78
PiperOrigin-RevId: 701246876
Bumps the minimum version for `nvidia-cuda-nvcc-cu12` to `12.6.85`, which is the earliest published version of that wheel incorporating CUDA 12.6.3.

This resolves the issue underlying #24438. Here's a simplified reproducer:

As a follow-up we might also want to add a warning asking users with an affected version of `nvidia-cuda-nvcc-cu12` to upgrade (e.g., via `pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"`).

@hawkinsp Any opinions on whether to have that warning / where it should go? As to how to check for the current ptxas version (https://github.com/jax-ml/jax/pull/24438/files#r1818345222), I suppose we might expose that alongside other CUDA versions here (
jax/jaxlib/cuda/versions.cc
Line 26 in c35f8b2
`nvidia-cuda-nvcc-cu12` wheel's `__version__`, if present.
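As a rough sketch of the wheel-based idea, the installed version of `nvidia-cuda-nvcc-cu12` can be looked up from Python. Note the swapped-in technique: this reads the installed distribution's metadata through the standard-library `importlib.metadata` rather than a `__version__` attribute, and the helper names `parse_version`, `installed_nvcc_wheel_version` and `needs_upgrade` are hypothetical. Nothing here is how JAX implements such a check.

```python
from importlib import metadata
from typing import Optional, Tuple

WHEEL_NAME = "nvidia-cuda-nvcc-cu12"
# First wheel version incorporating CUDA 12.6.3, per the PR description.
MIN_FIXED: Tuple[int, ...] = (12, 6, 85)


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse a purely numeric dotted version such as '12.6.85'."""
    return tuple(int(part) for part in text.split("."))


def installed_nvcc_wheel_version() -> Optional[Tuple[int, ...]]:
    """Return the installed wheel version, or None if it cannot be determined."""
    try:
        return parse_version(metadata.version(WHEEL_NAME))
    except (metadata.PackageNotFoundError, ValueError):
        # Wheel not installed, or a version string this sketch cannot parse.
        return None


def needs_upgrade(version: Optional[Tuple[int, ...]]) -> bool:
    """True only when a version is known and older than the first fixed wheel."""
    return version is not None and version < MIN_FIXED


if __name__ == "__main__":
    found = installed_nvcc_wheel_version()
    if needs_upgrade(found):
        print('Consider: pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"')
```

A lookup like this only sees the pip-installed wheel; a ptxas picked up from a system CUDA installation would not be covered.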