Skip to content

Commit

Permalink
gpu_blas_lt_gemm_runner (#2766)
Browse files Browse the repository at this point in the history
* gpu_blas_lt_gemm_runner

* change the location of gemm runner for Batched GEMM

* use xla flags to enable hipblaslt instead of env vars

* non-xla hipblaslt for ThenBlasGemmStridedBatched
  • Loading branch information
ScXfjiang authored Jan 2, 2025
1 parent 344ab14 commit 1667ddb
Show file tree
Hide file tree
Showing 9 changed files with 846 additions and 110 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/computation_placer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
ComputationPlacerCreationFunction creation_function) {
absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
auto* computation_placers = GetPlatformComputationPlacers();
CHECK(computation_placers->find(platform_id) == computation_placers->end());
// CHECK(computation_placers->find(platform_id) == computation_placers->end());
(*computation_placers)[platform_id].creation_function = creation_function;
}

Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ tsl_gpu_library(
":temporary_memory_manager",
":timer",
"//tensorflow/compiler/xla/stream_executor/platform",
"//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt_gemm_runner",
"//tensorflow/tsl/platform:env",
"//tensorflow/tsl/platform:errors",
"//tensorflow/tsl/platform:logging",
Expand Down
16 changes: 9 additions & 7 deletions tensorflow/compiler/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ cc_library(
#"//tensorflow/core/platform:env",
"//tensorflow/tsl/util:env_var",
"@com_google_absl//absl/types:any",
"//tensorflow/compiler/xla:debug_options_flags",
]),
)

Expand All @@ -87,13 +88,14 @@ cc_library(
srcs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.cc"]),
hdrs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.h"]),
deps = if_gpu_is_configured([
"//tensorflow/core:autotuning_proto_cc",
"//tensorflow/core:autotune_results_proto_cc",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/stream_executor:scratch_allocator",
"//tensorflow/compiler/xla/service/gpu:autotuner_util",
"//tensorflow/compiler/xla:debug_options_flags",
":gpu_blas_lt",
"//tensorflow/core/protobuf:autotuning_proto_cc",
"//tensorflow/compiler/xla:autotune_results_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/stream_executor:scratch_allocator",
"//tensorflow/compiler/xla/service/gpu:autotuner_util",
"//tensorflow/compiler/xla:debug_options_flags",
":gpu_blas_lt",
]),
)

Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/tsl/util/env_var.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"

namespace stream_executor {

Expand All @@ -31,6 +32,13 @@ using blas::ComputationType;
using blas::DataType;
using xla::PrimitiveType;

// Returns the value of the `xla_gpu_enable_cublaslt` field of the XLA debug
// options obtained from xla::GetDebugOptionsFromFlags().
//
// The flag is read only on the first call; the result is cached in a
// function-local static and returned unchanged by every later call.
bool GpuBlasLtEnabled() {
  // Function-local statics are initialized exactly once, and the language
  // guarantees that this initialization is thread-safe (C++11 and later).
  static const bool enabled =
      xla::GetDebugOptionsFromFlags().xla_gpu_enable_cublaslt();
  return enabled;
}

namespace {

bool TF32_Enabled() {
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace stream_executor {

namespace gpu {

// Returns the `xla_gpu_enable_cublaslt` setting taken from the XLA debug
// options (xla::GetDebugOptionsFromFlags()). The value is read once on the
// first call and the cached result is returned afterwards.
bool GpuBlasLtEnabled();

xla::StatusOr<blas::DataType> AsBlasDataType(xla::PrimitiveType dtype);

xla::StatusOr<blas::ComputationType> GetBlasComputationType(
Expand Down
Loading

0 comments on commit 1667ddb

Please sign in to comment.