Skip to content

Commit

Permalink
PR #19927: [cuda] Warn about ptxas versions before CUDA 12.6.3
Browse files Browse the repository at this point in the history
Imported from GitHub PR #19927

This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (jax-ml/jax#25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade.

CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer.

Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)):

- nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154))
- nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84))
- ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263))

(As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.)

**Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning.

The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default.

---

Example:
```
# A JAX-Toolbox image affected
$ docker run -it --gpus=all jax:jax-2024-11-25

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
[0 0 0 0 3 6]

$ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85"
(...)

$ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))"
[9 6 3 0 3 6]
```

---

On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant.
Copybara import of the project:

--
eeae01c by Georg Stefan Schmid <[email protected]>:

[cuda] Warn about ptxas versions before CUDA 12.6.3

Merging this change closes #19927

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19927 from gspschmid:gschmid/ptxax-version-warn eeae01c
PiperOrigin-RevId: 701246876
  • Loading branch information
gspschmid authored and Google-ML-Automation committed Nov 29, 2024
1 parent bc54b9c commit 60bce4b
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 0 deletions.
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,9 @@ cc_library(
srcs = ["ptx_compiler_helpers.cc"],
hdrs = ["ptx_compiler_helpers.h"],
deps = [
"//xla/stream_executor:device_description",
"//xla/stream_executor:semantic_version",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/nvjitlink_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ absl::StatusOr<std::vector<uint8_t>> CompileAndLinkUsingLibNvJitLink(
return std::vector<uint8_t>();
}

TF_ASSIGN_OR_RETURN(auto [major, minor], GetNvJitLinkVersion());
WarnIfBadPtxasVersion("nvJitLink", cc, {major, minor, 0});

std::vector<std::string> cli_args;
// On Hopper, default to sm_90a so that all instructions can be used. But
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
Expand Down
28 changes: 28 additions & 0 deletions xla/stream_executor/cuda/ptx_compiler_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,32 @@ absl::Status CreateErrorFromPTXASLog(std::string_view log,
return absl::OkStatus();
}

// Warns if the ptxas version should be upgraded.
// Only prints the warning upon the first invocation.
void WarnIfBadPtxasVersion(std::string_view method,
const CudaComputeCapability& cc,
SemanticVersion compiler_version) {
static absl::once_flag run_once;
absl::call_once(run_once, [&] {
// nvbug 4826023: Occurs on Hopper+ in CUDA versions 12.x up to and
// including CUDA 12.6.2; the earliest ptxas release that corresponds to
// CUDA 12.6.3 is 12.6.85.
if (cc.major >= 9 && compiler_version >= SemanticVersion{12, 0, 0} &&
compiler_version < SemanticVersion{12, 6, 85}) {
LOG(ERROR)
<< "*** WARNING *** Invoking " << method << " with version "
<< compiler_version
<< ", which corresponds to a CUDA version <=12.6.2. CUDA versions "
"12.x.y up to and including 12.6.2 miscompile certain edge "
"cases around clamping.\nPlease upgrade to CUDA 12.6.3 or newer.";
if (method != "ptxas" && compiler_version.major() == 12 &&
compiler_version.minor() == 6) {
LOG(ERROR) << "(Note that this warning may be shown spuriously for "
"CUDA 12.6.y, since "
<< method << " does not report patch versions.)";
}
}
});
}

} // namespace stream_executor
7 changes: 7 additions & 0 deletions xla/stream_executor/cuda/ptx_compiler_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#include <string_view>

#include "absl/status/status.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/semantic_version.h"

namespace stream_executor {
// Checks whether ptxas log contains errors related to register allocation.
Expand All @@ -30,6 +32,11 @@ bool IsPtxRegisterAllocationError(std::string_view);
absl::Status CreateErrorFromPTXASLog(std::string_view log,
std::string_view architecture,
bool cancel_if_reg_spill);

// Warns if the ptxas version should be upgraded.
void WarnIfBadPtxasVersion(std::string_view method,
const CudaComputeCapability& cc,
SemanticVersion compiler_version);
} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/ptx_compiler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ static std::string_view ToString(nvPTXCompileResult status) {
absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingLibNvPtxCompiler(
const CudaComputeCapability& cc, const std::string& ptx_contents,
GpuAsmOpts options, bool cancel_if_reg_spill) {
TF_ASSIGN_OR_RETURN(auto version, GetLibNvPtxCompilerVersion());
WarnIfBadPtxasVersion("nvPTXCompiler", cc, version);

nvPTXCompilerHandle compiler_handle{};
RETURN_IF_NVPTXCOMPILER_ERROR(nvPTXCompilerCreate(
&compiler_handle, ptx_contents.size(), ptx_contents.data()));
Expand Down
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/subprocess_compilation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingPtxAs(
absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingPtxAs(
std::string_view ptxas_path, const CudaComputeCapability& cc,
std::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill) {
TF_ASSIGN_OR_RETURN(auto version, GetToolVersion(ptxas_path));
WarnIfBadPtxasVersion("ptxas", cc, version);

// Write ptx into a temporary file.
std::string ptx_path;
auto env = tsl::Env::Default();
Expand Down

0 comments on commit 60bce4b

Please sign in to comment.