Skip to content

Commit

Permalink
[cuda] Warn about ptxas versions before CUDA 12.6.3
Browse files Browse the repository at this point in the history
  • Loading branch information
gspschmid committed Nov 29, 2024
1 parent 2f79665 commit eeae01c
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 0 deletions.
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,9 @@ cc_library(
srcs = ["ptx_compiler_helpers.cc"],
hdrs = ["ptx_compiler_helpers.h"],
deps = [
"//xla/stream_executor:device_description",
"//xla/stream_executor:semantic_version",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/nvjitlink_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ absl::StatusOr<std::vector<uint8_t>> CompileAndLinkUsingLibNvJitLink(
return std::vector<uint8_t>();
}

TF_ASSIGN_OR_RETURN(auto [major, minor], GetNvJitLinkVersion());
WarnIfBadPtxasVersion("nvJitLink", cc, {major, minor, 0});

std::vector<std::string> cli_args;
// On Hopper, default to sm_90a so that all instructions can be used. But
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
Expand Down
28 changes: 28 additions & 0 deletions xla/stream_executor/cuda/ptx_compiler_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,32 @@ absl::Status CreateErrorFromPTXASLog(std::string_view log,
return absl::OkStatus();
}

// Warns if the ptxas version should be upgraded.
// Only prints the warning upon the first invocation.
void WarnIfBadPtxasVersion(std::string_view method,
const CudaComputeCapability& cc,
SemanticVersion compiler_version) {
static absl::once_flag run_once;
absl::call_once(run_once, [&] {
// nvbug 4826023: Occurs on Hopper+ in CUDA versions 12.x up to and
// including CUDA 12.6.2; the earliest ptxas release that corresponds to
// CUDA 12.6.3 is 12.6.85.
if (cc.major >= 9 && compiler_version >= SemanticVersion{12, 0, 0} &&
compiler_version < SemanticVersion{12, 6, 85}) {
LOG(ERROR)
<< "*** WARNING *** Invoking " << method << " with version "
<< compiler_version
<< ", which corresponds to a CUDA version <=12.6.2. CUDA versions "
"12.x.y up to and including 12.6.2 miscompile certain edge "
"cases around clamping.\nPlease upgrade to CUDA 12.6.3 or newer.";
if (method != "ptxas" && compiler_version.major() == 12 &&
compiler_version.minor() == 6) {
LOG(ERROR) << "(Note that this warning may be shown spuriously for "
"CUDA 12.6.y, since "
<< method << " does not report patch versions.)";
}
}
});
}

} // namespace stream_executor
7 changes: 7 additions & 0 deletions xla/stream_executor/cuda/ptx_compiler_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#include <string_view>

#include "absl/status/status.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/semantic_version.h"

namespace stream_executor {
// Checks whether ptxas log contains errors related to register allocation.
Expand All @@ -30,6 +32,11 @@ bool IsPtxRegisterAllocationError(std::string_view);
absl::Status CreateErrorFromPTXASLog(std::string_view log,
std::string_view architecture,
bool cancel_if_reg_spill);

// Warns if the ptxas version should be upgraded.
void WarnIfBadPtxasVersion(std::string_view method,
const CudaComputeCapability& cc,
SemanticVersion compiler_version);
} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/ptx_compiler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ static std::string_view ToString(nvPTXCompileResult status) {
absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingLibNvPtxCompiler(
const CudaComputeCapability& cc, const std::string& ptx_contents,
GpuAsmOpts options, bool cancel_if_reg_spill) {
TF_ASSIGN_OR_RETURN(auto version, GetLibNvPtxCompilerVersion());
WarnIfBadPtxasVersion("nvPTXCompiler", cc, version);

nvPTXCompilerHandle compiler_handle{};
RETURN_IF_NVPTXCOMPILER_ERROR(nvPTXCompilerCreate(
&compiler_handle, ptx_contents.size(), ptx_contents.data()));
Expand Down
3 changes: 3 additions & 0 deletions xla/stream_executor/cuda/subprocess_compilation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingPtxAs(
absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingPtxAs(
std::string_view ptxas_path, const CudaComputeCapability& cc,
std::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill) {
TF_ASSIGN_OR_RETURN(auto version, GetToolVersion(ptxas_path));
WarnIfBadPtxasVersion("ptxas", cc, version);

// Write ptx into a temporary file.
std::string ptx_path;
auto env = tsl::Env::Default();
Expand Down

0 comments on commit eeae01c

Please sign in to comment.