From a27025f94345d873bc9e4718b4afc45651ea2db2 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 4 Dec 2024 13:37:37 -0800 Subject: [PATCH] Reverts 1c27e02524eda1031728b0a0c00fbf3c9be93248 PiperOrigin-RevId: 702843057 --- third_party/xla/xla/service/gpu/BUILD | 46 +- .../xla/xla/service/gpu/nvptx_compiler.cc | 485 +++++++++++++++--- .../xla/xla/service/gpu/nvptx_compiler.h | 99 +++- .../xla/service/gpu/nvptx_compiler_test.cc | 2 +- .../xla/service/gpu/ptx_compilation_test.cc | 6 +- .../ptx_compile_options_from_debug_options.cc | 36 -- .../ptx_compile_options_from_debug_options.h | 30 -- ...compile_options_from_debug_options_test.cc | 94 ---- 8 files changed, 522 insertions(+), 276 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.cc delete mode 100644 third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.h delete mode 100644 third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index f4c9758767f7c3..cf16e01186cf3b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1766,7 +1766,6 @@ cc_library( ], deps = [ ":nvptx_compiler_impl", - "//xla:debug_options_flags", "//xla/service:compiler", "//xla/stream_executor/cuda:cuda_platform_id", "@local_tsl//tsl/platform:path", @@ -1790,13 +1789,12 @@ cc_library( deps = [ ":buffer_sharing", ":cublas_padding_requirements", + ":gpu_asm_opts_util", ":gpu_compiler", ":ir_emission_utils", ":metrics", - ":ptx_compile_options_from_debug_options", ":target_constants", "//xla:autotune_results_proto_cc", - "//xla:debug_options_flags", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", @@ -1846,14 +1844,24 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/cuda:assemble_compilation_provider", - "//xla/stream_executor/cuda:caching_compilation_provider", - "//xla/stream_executor/cuda:compilation_options", - "//xla/stream_executor/cuda:compilation_provider", "//xla/stream_executor/cuda:cuda_diagnostics", "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/cuda:driver_compilation", + "//xla/stream_executor/cuda:nvjitlink", + "//xla/stream_executor/cuda:nvjitlink_known_issues", + "//xla/stream_executor/cuda:nvjitlink_support", + "//xla/stream_executor/cuda:ptx_compilation_method", + "//xla/stream_executor/cuda:ptx_compiler", + "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/stream_executor/cuda:ptx_linking_method", + "//xla/stream_executor/cuda:subprocess_compilation", + "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1861,6 +1869,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", @@ -1868,6 +1877,7 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:scoped_annotation", "@local_tsl//tsl/profiler/lib:traceme", @@ -3238,28 +3248,6 @@ xla_cc_test( ], ) -cc_library( - name = "ptx_compile_options_from_debug_options", - srcs = ["ptx_compile_options_from_debug_options.cc"], - hdrs = ["ptx_compile_options_from_debug_options.h"], - deps = [ - "//xla:xla_proto_cc_impl", - "//xla/stream_executor/cuda:compilation_options", - ], -) - -xla_cc_test( - name = "ptx_compile_options_from_debug_options_test", - srcs = ["ptx_compile_options_from_debug_options_test.cc"], - deps = [ - ":ptx_compile_options_from_debug_options", - "//xla/stream_executor/cuda:compilation_options", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", - ], -) - cc_library( name = "flag_utils", srcs = ["flag_utils.cc"], diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 8289cb8f547363..c154ce66293a85 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/nvptx_compiler.h" +#include #include #include #include @@ -27,18 +28,20 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/call_once.h" +#include "absl/cleanup/cleanup.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "xla/debug_options_flags.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -62,12 +65,12 @@ limitations under the License. #include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h" #include "xla/service/gpu/buffer_sharing.h" #include "xla/service/gpu/cublas_padding_requirements.h" +#include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/llvm_gpu_backend/nvptx_utils.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/ptx_compile_options_from_debug_options.h" #include "xla/service/gpu/target_constants.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/gpu/transforms/conv_padding_legalization.h" @@ -89,22 +92,30 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/stream_executor/cuda/assemble_compilation_provider.h" -#include "xla/stream_executor/cuda/caching_compilation_provider.h" -#include "xla/stream_executor/cuda/compilation_options.h" -#include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/driver_compilation.h" +#include "xla/stream_executor/cuda/nvjitlink.h" +#include "xla/stream_executor/cuda/nvjitlink_known_issues.h" +#include "xla/stream_executor/cuda/nvjitlink_support.h" +#include "xla/stream_executor/cuda/ptx_compilation_method.h" +#include "xla/stream_executor/cuda/ptx_compiler.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/cuda/ptx_linking_method.h" +#include "xla/stream_executor/cuda/subprocess_compilation.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/scoped_annotation.h" @@ -545,20 +556,9 @@ void WarnIfBadDriverJITVersion() { }); } -static absl::StatusOr> -CreateCompilationProvider(const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN(std::unique_ptr delegate, - se::cuda::AssembleCompilationProvider(debug_options)); - return std::make_unique( - std::move(delegate)); -} - -NVPTXCompiler::NVPTXCompiler(const DebugOptions& debug_options) +NVPTXCompiler::NVPTXCompiler() : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::TargetTriple(), - nvptx::DataLayout()), - compilation_provider_{CreateCompilationProvider(debug_options)} {} - -NVPTXCompiler::NVPTXCompiler() : NVPTXCompiler(GetDebugOptionsFromFlags()) {} + nvptx::DataLayout()) {} HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer( const se::DeviceDescription& device_description) const { @@ -568,6 +568,8 @@ HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer( }; } +constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '}; + absl::StatusOr NVPTXCompiler::CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, @@ -605,28 +607,248 @@ NVPTXCompiler::CompileTargetBinary( RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs); } + TF_ASSIGN_OR_RETURN( + se::PtxLinkingMethod linking_method, + ChooseLinkingMethod(module_config.debug_options(), device_description)); + + if (linking_method == se::PtxLinkingMethod::kNvJitLink && relocatable) { + VLOG(2) << "Deferring the PTX to CUBIN compilation of the relocatable " + "module to the linking step."; + std::vector binary; + if (!ptx.empty()) { + binary.reserve(sizeof(kPtxPrefix) + ptx.size() + 1); + binary.insert(binary.end(), kPtxPrefix, kPtxPrefix + sizeof(kPtxPrefix)); + binary.insert(binary.end(), ptx.begin(), ptx.end()); + binary.emplace_back('\0'); + } + return BackendCompileResult{std::move(ptx), std::move(binary)}; + } + + absl::StatusOr> maybe_cubin = + CompileGpuAsmOrGetCachedResult( + ptx, + std::get( + device_description.gpu_compute_capability()), + module_config, + (debug_module != nullptr ? debug_module->name() : "(unknown)"), + relocatable, options); + + if (!maybe_cubin.ok()) { + return maybe_cubin.status(); + } + return BackendCompileResult{std::move(ptx), std::move(maybe_cubin.value())}; +} + +using stream_executor::PtxCompilationMethod; + +// Returns the supported compilation methods in the order of priority. +std::vector GetSupportedCompilationMethods() { + std::vector methods; + if (se::IsLibNvPtxCompilerSupported()) { + methods.emplace_back(PtxCompilationMethod::kNvPtxCompiler); + } + if (se::IsLibNvJitLinkSupported()) { + methods.emplace_back(PtxCompilationMethod::kNvJitLink); + } + methods.emplace_back(PtxCompilationMethod::kPtxas); + return methods; +} + +absl::StatusOr ChooseCompilationMethod( + absl::Span available_compilation_methods, + const DebugOptions& debug_options, bool relocatable) { + std::vector compilation_methods( + available_compilation_methods.begin(), + available_compilation_methods.end()); + VLOG(2) << "Available compilation methods: " + << absl::StrJoin(compilation_methods, ", "); + + auto remove_compilation_method = [&](PtxCompilationMethod method) { + auto it = absl::c_find(compilation_methods, method); + if (it != compilation_methods.end()) { + compilation_methods.erase(it); + } + }; + + // This is true if the user explicitly requested the use of libNvJitLink + // through the command line flag. In that case we bypass all the sanity checks + // and enable its usage. It means compilation might fail which is a better + // diagnostic to the user instead of silently discarding NvJitLink. + const bool libnvjitlink_force_enabled = + debug_options.xla_gpu_libnvjitlink_mode() == + DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED; + + if (!stream_executor::IsLibNvJitLinkSupported() && + !libnvjitlink_force_enabled) { + VLOG(3) << "Discarding NvJitLink since it is not supported in this build."; + remove_compilation_method(PtxCompilationMethod::kNvJitLink); + } else if (stream_executor::LoadedNvJitLinkHasKnownIssues() && + !libnvjitlink_force_enabled) { + auto formatted_version = [&]() -> std::string { + absl::StatusOr version = + stream_executor::GetNvJitLinkVersion(); + if (version.ok()) { + return absl::StrCat(std::get<0>(*version), ".", std::get<1>(*version)); + } + return "unknown"; + }(); + + VLOG(3) << "Discarding NvJitLink since the loaded library version (" + << formatted_version << ") has known issues."; + remove_compilation_method(PtxCompilationMethod::kNvJitLink); + } else if (debug_options.xla_gpu_libnvjitlink_mode() == + DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED) { + VLOG(3) << "Discarding NvJitLink since it was explicitly disabled."; + remove_compilation_method(PtxCompilationMethod::kNvJitLink); + } + if (!debug_options.xla_gpu_enable_libnvptxcompiler()) { + VLOG(3) << "Discarding NvPtxCompiler since it is disabled."; + remove_compilation_method(PtxCompilationMethod::kNvPtxCompiler); + } + + VLOG(2) << "Supported and enabled compilation methods: " + << absl::StrJoin(compilation_methods, ", "); + + if (relocatable && absl::c_linear_search(compilation_methods, + PtxCompilationMethod::kNvJitLink)) { + // NvJitLink can't produce relocatable CUBINs. + VLOG(3) << "Discarding NvJitLink since it can't produce the requested " + "relocatable CUBIN."; + remove_compilation_method(PtxCompilationMethod::kNvJitLink); + } + + VLOG(2) << "Considered compilation methods: " + << absl::StrJoin(compilation_methods, ", "); + + if (compilation_methods.empty()) { + return absl::UnavailableError( + "No supported compilation method is available."); + } + + return compilation_methods.front(); +} + +static absl::StatusOr> AssembleOptionsAndCompile( + const std::string& ptx, se::CudaComputeCapability cc, + const HloModuleConfig& hlo_module_config, + GpuCompiler::CompileOptions options, bool relocatable) { if (ptx.empty()) { - return BackendCompileResult{}; + return std::vector(); } - TF_RETURN_IF_ERROR(compilation_provider_.status()); - se::cuda::CompilationProvider* compilation_provider = - compilation_provider_->get(); + se::GpuAsmOpts ptxas_config = + PtxOptsFromDebugOptions(hlo_module_config.debug_options()); + if (relocatable) { + ptxas_config.extra_flags.push_back("-c"); + } + uint64_t start_usecs = tsl::Env::Default()->NowMicros(); - se::cuda::CompilationOptions compilation_options = - PtxCompileOptionsFromDebugOptions( - module_config.debug_options(), - /*is_autotuning_compilation=*/options.is_autotuning_compilation); + bool cancel_if_reg_spill = + hlo_module_config.debug_options() + .xla_gpu_filter_kernels_spilling_registers_on_autotuning() && + options.is_autotuning_compilation; - se::CudaComputeCapability cc = std::get( - device_description.gpu_compute_capability()); + std::vector supported_compilation_methods = + GetSupportedCompilationMethods(); + TF_ASSIGN_OR_RETURN( + PtxCompilationMethod compilation_method, + ChooseCompilationMethod(supported_compilation_methods, + hlo_module_config.debug_options(), relocatable)); + + VLOG(2) << "Using compilation method: " << compilation_method; + + absl::StatusOr> maybe_cubin = [&] { + switch (compilation_method) { + case PtxCompilationMethod::kNvJitLink: + return se::CompileAndLinkUsingLibNvJitLink( + cc, + {se::NvJitLinkInput{ + se::NvJitLinkInput::Type::kPtx, + absl::Span{ + reinterpret_cast(ptx.c_str()), + ptx.size() + 1 /* We need the null terminator. */}}}, + ptxas_config, cancel_if_reg_spill); + + case PtxCompilationMethod::kNvPtxCompiler: + return se::CompileGpuAsmUsingLibNvPtxCompiler(cc, ptx, ptxas_config, + cancel_if_reg_spill); + case PtxCompilationMethod::kPtxas: + return se::CompileGpuAsmUsingPtxAs(cc, ptx, ptxas_config, + cancel_if_reg_spill); + } + }(); + if (maybe_cubin.ok()) { + uint64_t end_usecs = tsl::Env::Default()->NowMicros(); + // This won't record values for calls that error out (because if they + // error out we have no way of telling how far through the process we + // got). + RecordPtxToCubinDuration(end_usecs - start_usecs); + + VLOG(1) << "Compiled PTX size: " << ptx.size() + << "bytes. CUBIN size: " << maybe_cubin.value().size() << "bytes."; + + return maybe_cubin; + } + + if (maybe_cubin.status().code() == absl::StatusCode::kNotFound) { + if (!hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) { + LOG(WARNING) << nvptx::CantFindCudaMessage( + "Can't find ptxas binary in ${CUDA_DIR}/bin. Custom ptxas " + "location can be specified using $PATH.", + hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); + LOG(FATAL) << "Can't find ptxas binary. You can pass the flag " + "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found " + "to use the GPU driver for compiling ptx instead. However " + "this option is discouraged and can lead to increased " + "memory consumptions and other subtle runtime issues."; + } + + // Missing ptxas is expected in some environments where CUDA SDK + // binaries are not available. We don't want to spam logs with + // identical warnings in this case. + LOG_FIRST_N(WARNING, 1) << nvptx::CantFindCudaMessage( + "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to " + "the GPU driver for PTX -> sass compilation. This is OK so " + "long as you don't see a warning below about an out-of-date " + "driver version. Custom ptxas location can be specified " + "using $PATH.", + hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); + + // We're going to use the driver to JIT our PTX->SASS, so warn if + // the JIT in the driver has known bugs. + WarnIfBadDriverJITVersion(); + return maybe_cubin; + } + + if (maybe_cubin.status().code() == absl::StatusCode::kCancelled) { + return maybe_cubin; + } + + if (maybe_cubin.status().code() == absl::StatusCode::kResourceExhausted) { + return maybe_cubin; + } + + if (maybe_cubin.status().code() != absl::StatusCode::kUnimplemented) { + return AppendStatus( + maybe_cubin.status(), + "If the error message indicates that a file could not be written, " + "please verify that sufficient filesystem space is provided."); + } + + return maybe_cubin; +} + +absl::StatusOr> +NVPTXCompiler::CompileGpuAsmOrGetCachedResult( + const std::string& ptx, se::CudaComputeCapability cc, + const HloModuleConfig& hlo_module_config, absl::string_view module_name, + bool relocatable, const CompileOptions& options) { // This may print multiple lines per HLO compilation because of the // parallelized compilation of LLVM modules. - std::string module_name = - debug_module != nullptr ? debug_module->name() : "(unknown)"; XLA_SCOPED_LOGGING_TIMER_IF( - absl::StrCat("NVPTXCompiler::CompileTargetBinary - PtxToCubin for ", + absl::StrCat("NVPTXCompiler::CompileGpuAsmOrGetCachedResult for ", module_name), !options.is_autotuning_compilation); tsl::profiler::ScopedAnnotation annotation([&] { @@ -634,38 +856,141 @@ NVPTXCompiler::CompileTargetBinary( }); tsl::profiler::TraceMe activity("PTX->CUBIN", tsl::profiler::TraceMeLevel::kInfo); + CompilationCacheValue* cache_value = nullptr; + bool inserted = [&] { + auto flags = CompilationCacheFlags{ + hlo_module_config.debug_options() + .xla_gpu_filter_kernels_spilling_registers_on_autotuning()}; + absl::MutexLock lock(&mutex_); + auto [iter, inserted] = compilation_cache_.emplace( + std::piecewise_construct, + std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable, flags), + std::forward_as_tuple()); + // Do not move this assignment outside of the critical section. There is + // a TOCTOU if `compilation_cache_` is rehashed before the iterator is used. + cache_value = &iter->second; + return inserted; + }(); - uint64_t start_usecs = tsl::Env::Default()->NowMicros(); - const auto record_ptx_to_cubin_metric = [&]() { - uint64_t end_usecs = tsl::Env::Default()->NowMicros(); - // This won't record values for calls that error out (because if they - // error out we have no way of telling how far through the process we - // got). - RecordPtxToCubinDuration(end_usecs - start_usecs); - }; + // Compile the ptx if it wasn't in the cache before we called this function. + // Other threads asking for the same compilation key will block on + // cache_value->mutex_ until compilation is done. + absl::MutexLock lock(&cache_value->mutex); + if (inserted) { + CHECK(!cache_value->compilation_done); + absl::Cleanup mark_compilation_as_done = [cache_value] { + // Note that we will set this to true also in the error case, so that we + // don't retry this compilation. + cache_value->compilation_done = true; + cache_value->compilation_done_cv.SignalAll(); + }; + + cache_value->maybe_cubin = AssembleOptionsAndCompile( + ptx, cc, hlo_module_config, options, relocatable); + return cache_value->maybe_cubin; + } - if (relocatable) { - TF_ASSIGN_OR_RETURN(se::cuda::RelocatableModule relocatable_module, - compilation_provider->CompileToRelocatableModule( - cc, ptx, compilation_options)); - record_ptx_to_cubin_metric(); - return BackendCompileResult{std::move(ptx), - std::move(relocatable_module.cubin)}; + while (!cache_value->compilation_done) { + cache_value->compilation_done_cv.Wait(&cache_value->mutex); } - TF_ASSIGN_OR_RETURN( - se::cuda::Assembly assembly, - compilation_provider->Compile(cc, ptx, compilation_options)); - record_ptx_to_cubin_metric(); - return BackendCompileResult{std::move(ptx), std::move(assembly.cubin)}; + return cache_value->maybe_cubin; +} + +static bool IsNvlinkEnabled() { + const bool use_nvlink_by_default = +#ifdef TF_DISABLE_NVLINK_BY_DEFAULT + false; +#else + true; +#endif + bool use_nvlink; + TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_NVLINK_FOR_PARALLEL_COMPILATION", + /*default_val=*/ + use_nvlink_by_default, &use_nvlink)); + return use_nvlink; +} + +// Returns the version of the PTX compiler that will be used for the given +// debug options and preferred CUDA directory (Either libnvptxcompiler or ptxas) +static absl::StatusOr GetAsmCompilerVersion( + const DebugOptions& debug_options, const std::string& preferred_cuda_dir) { + if (debug_options.xla_gpu_enable_libnvptxcompiler() && + se::IsLibNvPtxCompilerSupported()) { + return stream_executor::GetLibNvPtxCompilerVersion(); + } + + return se::GetAsmCompilerVersion(preferred_cuda_dir); +} + +absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( + const DebugOptions& debug_options, + const stream_executor::DeviceDescription& device_description) { + se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); + std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir; + + using LinkingMethod = se::PtxLinkingMethod; + + // If the user has explicitly requested NvJitLink we will try to use it and + // fail later during linking if it is not available or has known issues. + if (debug_options.xla_gpu_libnvjitlink_mode() == + DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED) { + return LinkingMethod::kNvJitLink; + } + + if (stream_executor::IsLibNvJitLinkSupported() && + !stream_executor::LoadedNvJitLinkHasKnownIssues() && + debug_options.xla_gpu_libnvjitlink_mode() != + DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED) { + return se::PtxLinkingMethod::kNvJitLink; + } + + TF_ASSIGN_OR_RETURN(auto asm_compiler_version, + GetAsmCompilerVersion(debug_options, preferred_cuda_dir)); + + auto nvlink_version = stream_executor::GetNvLinkVersion(preferred_cuda_dir); + if (IsNvlinkEnabled() && nvlink_version.ok() && + nvlink_version.value() >= asm_compiler_version) { + return LinkingMethod::kNvLink; + } + + stream_executor::SemanticVersion driver_version = + device_description.driver_version(); + + auto greater_equal_major_minor = + [](const stream_executor::SemanticVersion& a, + const stream_executor::SemanticVersion& b) { + return std::make_tuple(a.major(), a.minor()) >= + std::make_tuple(b.major(), b.minor()); + }; + + // The patch level version has no meaning when comparing driver to ptxas + // versions. + if (greater_equal_major_minor(driver_version, asm_compiler_version)) { + return LinkingMethod::kDriver; + } + + LOG_FIRST_N(WARNING, 1) + << "The NVIDIA driver's CUDA version is " << driver_version + << " which is older than the PTX compiler version " + << asm_compiler_version + << ". Because the driver is older than the PTX compiler version, XLA is " + "disabling parallel compilation, which may slow down compilation. " + "You should update your NVIDIA driver or use the NVIDIA-provided " + "CUDA forward compatibility packages."; + + return se::PtxLinkingMethod::kNone; } absl::StatusOr NVPTXCompiler::CanUseLinkModules( const HloModuleConfig& hlo_module_config, const stream_executor::DeviceDescription& device_description) { - TF_RETURN_IF_ERROR(compilation_provider_.status()); - return compilation_provider_->get()->SupportsCompileAndLink() && - compilation_provider_->get()->SupportsCompileToRelocatableModule(); + // TODO(phawkins): rather than comparing version numbers, it might be more + // robust if we simply tried to link something the first time we compile. + TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method, + ChooseLinkingMethod(hlo_module_config.debug_options(), + device_description)); + return linking_method != se::PtxLinkingMethod::kNone; } absl::StatusOr> NVPTXCompiler::LinkModules( @@ -677,28 +1002,42 @@ absl::StatusOr> NVPTXCompiler::LinkModules( auto cc = std::get( device_description.gpu_compute_capability()); - TF_RETURN_IF_ERROR(compilation_provider_.status()); - se::cuda::CompilationProvider* compilation_provider = - compilation_provider_->get(); + TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method, + ChooseLinkingMethod(debug_options, device_description)); + VLOG(1) << "Linking " << modules.size() + << " modules with linking method: " << linking_method; + + if (linking_method == se::PtxLinkingMethod::kNvJitLink) { + const auto module_contains_ptx = + [](const std::vector& module) -> bool { + return module.size() >= sizeof(kPtxPrefix) && + std::equal(std::begin(kPtxPrefix), std::end(kPtxPrefix), + std::begin(module)); + }; + + std::vector nvjitlink_inputs; + nvjitlink_inputs.reserve(modules.size()); + for (std::vector& module : modules) { + if (module_contains_ptx(module)) { + nvjitlink_inputs.push_back( + {se::NvJitLinkInput::Type::kPtx, + absl::Span(module).subspan(sizeof(kPtxPrefix))}); + } else { + nvjitlink_inputs.push_back({se::NvJitLinkInput::Type::kCubin, module}); + } + } - std::vector inputs; - inputs.reserve(modules.size()); - for (std::vector& module : modules) { - inputs.push_back(se::cuda::RelocatableModule{std::move(module)}); + se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); + return stream_executor::CompileAndLinkUsingLibNvJitLink( + cc, nvjitlink_inputs, ptxas_config, + /*cancel_if_reg_spill=*/false); } - se::cuda::CompilationOptions compilation_options = - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/false); - - VLOG(1) << "Linking " << modules.size() - << " modules with compilation provider " - << compilation_provider->name(); - TF_ASSIGN_OR_RETURN( - se::cuda::Assembly assembly, - compilation_provider->CompileAndLink(cc, inputs, compilation_options)); + if (linking_method == se::PtxLinkingMethod::kNvLink) { + return LinkUsingNvlink(cc, debug_options.xla_gpu_cuda_data_dir(), modules); + } - return std::move(assembly.cubin); + return LinkGpuAsmUsingDriver(stream_exec, cc, modules); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index ca37748efb33ec..9d2809217ea186 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -17,11 +17,16 @@ limitations under the License. #define XLA_SERVICE_GPU_NVPTX_COMPILER_H_ #include -#include +#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "llvm/IR/Module.h" #include "xla/autotune_results.pb.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" @@ -32,7 +37,7 @@ limitations under the License. #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_module_config.h" -#include "xla/stream_executor/cuda/compilation_provider.h" +#include "xla/stream_executor/cuda/ptx_linking_method.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/semantic_version.h" @@ -48,12 +53,7 @@ void WarnIfBadDriverJITVersion(); // NVPTXCompiler generates efficient GPU executables for NVPTX target. class NVPTXCompiler : public GpuCompiler { public: - // DebugOptions are used to determine which CompilationProvider to use. - explicit NVPTXCompiler(const DebugOptions& debug_options); - - // The same as above, but uses the default DebugOptions determined by - // flags. - explicit NVPTXCompiler(); + NVPTXCompiler(); absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, @@ -104,8 +104,87 @@ class NVPTXCompiler : public GpuCompiler { std::vector> modules, const DebugOptions& debug_options) override; - absl::StatusOr> - compilation_provider_; + absl::StatusOr ChooseLinkingMethod( + const DebugOptions& debug_options, + const stream_executor::DeviceDescription& device_description); + + // Tries to compile the given ptx string to cubin. Returns a vector with the + // compiled cubin if compilation succeeded. + absl::StatusOr> CompileGpuAsmOrGetCachedResult( + const std::string& ptx, se::CudaComputeCapability cc, + const HloModuleConfig& hlo_module_config, absl::string_view module_name, + bool relocatable, const CompileOptions& options); + + struct CompilationCacheFlags { + template + friend H AbslHashValue(H h, const CompilationCacheFlags& flags) { + return H::combine(std::move(h), + flags.filter_kernels_spilling_registers_on_autotuning); + } + + friend bool operator==(const CompilationCacheFlags& a, + const CompilationCacheFlags& b) { + return a.filter_kernels_spilling_registers_on_autotuning == + b.filter_kernels_spilling_registers_on_autotuning; + } + + bool filter_kernels_spilling_registers_on_autotuning; + }; + + // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} + // -> cubin so we don't recompile the same ptx twice. This is important for + // some interactive workflows. (We also cache at the HLO level, but sometimes + // we can't realize that two modules are the same until we lower to ptx.) + // + // Compilation of distinct PTX happens in parallel. If more than one thread + // attempts to compile the same PTX, the fist thread to obtain + // cache_value_->mutex_ performs the compilation. The rest wait() on + // cache_value_->compilation_done_cv_ until the compilation is done. + // + // If compiling the ptx fails, we return an empty cubin, cross our fingers, + // and leave compilation up to the driver. + struct CompilationCacheKey { + CompilationCacheKey(std::string ptx, int cc_major, int cc_minor, + bool relocatable, CompilationCacheFlags flags) + : ptx(std::move(ptx)), + cc_major(cc_major), + cc_minor(cc_minor), + relocatable(relocatable), + flags(std::move(flags)) {} + + template + friend H AbslHashValue(H h, const CompilationCacheKey& key) { + return H::combine(std::move(h), key.ptx, key.cc_major, key.cc_minor, + key.relocatable, key.flags); + } + + friend bool operator==(const CompilationCacheKey& a, + const CompilationCacheKey& b) { + return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && + a.ptx == b.ptx && a.relocatable == b.relocatable && + a.flags == b.flags; + } + + std::string ptx; + int cc_major; + int cc_minor; + bool relocatable; + CompilationCacheFlags flags; + }; + + struct CompilationCacheValue { + bool compilation_done = false; + absl::StatusOr> maybe_cubin; + // mutex and condition variable to serialize compilation completing. + absl::Mutex mutex; + absl::CondVar compilation_done_cv; + }; + + // Don't even think about switching this to flat_hash_map; iterator stability + // is critical here. + absl::Mutex mutex_; + absl::node_hash_map + compilation_cache_ ABSL_GUARDED_BY(mutex_); NVPTXCompiler(const NVPTXCompiler&) = delete; NVPTXCompiler& operator=(const NVPTXCompiler&) = delete; diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc index 24a3ff6ab0b89a..43894e0a7ba7a3 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc @@ -235,7 +235,7 @@ ENTRY main { EXPECT_EQ(while_op->while_body()->root_instruction()->operand(1)->opcode(), HloOpcode::kCopy); - NVPTXCompiler compiler{module->config().debug_options()}; + NVPTXCompiler compiler; TF_EXPECT_OK(compiler.RunPostSchedulingPipelines( module.get(), 100000, backend().default_stream_executor()->GetDeviceDescription())); diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc index 177fd03f120dd0..c0736bee2412cc 100644 --- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc @@ -235,7 +235,7 @@ class NVPTXCompilationTests absl::StatusOr> CompileExecutable( std::unique_ptr module) { - NVPTXCompiler compiler{module->config().debug_options()}; + NVPTXCompiler compiler{}; return compiler.RunBackend(std::move(module), backend().default_stream_executor(), @@ -302,8 +302,8 @@ TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { absl::StatusOr> reference = compile(PtxCompilationMethod::kPtxas, reference_linking_method); - ASSERT_THAT(executable, tsl::testing::IsOkAndHolds(::testing::NotNull())); - ASSERT_THAT(reference, tsl::testing::IsOkAndHolds(::testing::NotNull())); + EXPECT_THAT(executable, tsl::testing::IsOkAndHolds(::testing::NotNull())); + EXPECT_THAT(reference, tsl::testing::IsOkAndHolds(::testing::NotNull())); absl::Span executable_binary = static_cast(executable.value().get())->binary(); diff --git a/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.cc b/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.cc deleted file mode 100644 index 2c183c81a2defd..00000000000000 --- a/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/ptx_compile_options_from_debug_options.h" - -#include "xla/stream_executor/cuda/compilation_options.h" - -namespace xla::gpu { - -stream_executor::cuda::CompilationOptions PtxCompileOptionsFromDebugOptions( - const DebugOptions& debug_options, bool is_autotuning_compilation) { - stream_executor::cuda::CompilationOptions compilation_options; - compilation_options.cancel_if_reg_spill = - debug_options.xla_gpu_filter_kernels_spilling_registers_on_autotuning() && - is_autotuning_compilation; - compilation_options.disable_optimizations = - debug_options.xla_gpu_disable_gpuasm_optimizations(); - compilation_options.generate_debug_info = - debug_options.xla_gpu_generate_debug_info(); - compilation_options.generate_line_info = - debug_options.xla_gpu_generate_line_info(); - return compilation_options; -} -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.h b/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.h deleted file mode 100644 index bd3bff0cffcc47..00000000000000 --- a/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_PTX_COMPILE_OPTIONS_FROM_DEBUG_OPTIONS_H_ -#define XLA_SERVICE_GPU_PTX_COMPILE_OPTIONS_FROM_DEBUG_OPTIONS_H_ - -#include "xla/stream_executor/cuda/compilation_options.h" -#include "xla/xla.pb.h" - -namespace xla::gpu { - -// Infers the compilation options from the given debug options. -stream_executor::cuda::CompilationOptions PtxCompileOptionsFromDebugOptions( - const DebugOptions& debug_options, bool is_autotuning_compilation); - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_PTX_COMPILE_OPTIONS_FROM_DEBUG_OPTIONS_H_ diff --git a/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options_test.cc b/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options_test.cc deleted file mode 100644 index 8b4a46d68ae90c..00000000000000 --- a/third_party/xla/xla/service/gpu/ptx_compile_options_from_debug_options_test.cc +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/ptx_compile_options_from_debug_options.h" - -#include -#include -#include "xla/stream_executor/cuda/compilation_options.h" -#include "tsl/platform/test.h" - -namespace xla::gpu { -namespace { -using ::stream_executor::cuda::CompilationOptions; -using ::testing::Field; - -TEST(PtxCompileOptionsFromDebugOptionsTest, - DefaultDebugOptionsResultsInDefaultCompilationOptions) { - DebugOptions debug_options; - EXPECT_EQ(PtxCompileOptionsFromDebugOptions( - debug_options, /*is_autotuning_compilation=*/false), - CompilationOptions{}); - EXPECT_EQ(PtxCompileOptionsFromDebugOptions( - debug_options, /*is_autotuning_compilation=*/true), - CompilationOptions{}); -} - -TEST(PtxCompileOptionsFromDebugOptionsTest, OptimizationsCanBeDisabled) { - DebugOptions debug_options; - debug_options.set_xla_gpu_disable_gpuasm_optimizations(true); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/false), - Field(&CompilationOptions::disable_optimizations, true)); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/true), - Field(&CompilationOptions::disable_optimizations, true)); -} - -TEST(PtxCompileOptionsFromDebugOptionsTest, LineInfoCanBeEnabled) { - DebugOptions debug_options; - debug_options.set_xla_gpu_generate_line_info(true); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/false), - Field(&CompilationOptions::generate_line_info, true)); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/true), - Field(&CompilationOptions::generate_line_info, true)); -} - -TEST(PtxCompileOptionsFromDebugOptionsTest, DebugInfoCanBeEnabled) { - DebugOptions debug_options; - debug_options.set_xla_gpu_generate_debug_info(true); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/false), - Field(&CompilationOptions::generate_debug_info, true)); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/true), - Field(&CompilationOptions::generate_debug_info, true)); -} - -TEST(PtxCompileOptionsFromDebugOptionsTest, - RegSpillAsErrorCanBeEnabledForAutotuning) { - DebugOptions debug_options; - debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( - true); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/false), - Field(&CompilationOptions::cancel_if_reg_spill, false)); - EXPECT_THAT( - PtxCompileOptionsFromDebugOptions(debug_options, - /*is_autotuning_compilation=*/true), - Field(&CompilationOptions::cancel_if_reg_spill, true)); -} - -} // namespace -} // namespace xla::gpu