diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index f00a1399aefec3..29972f2764af8d 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -163,7 +163,7 @@ StatusOr ComputationPlacer::AssignDevices( ComputationPlacerCreationFunction creation_function) { absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); - CHECK(computation_placers->find(platform_id) == computation_placers->end()); + // CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; } diff --git a/tensorflow/compiler/xla/stream_executor/BUILD b/tensorflow/compiler/xla/stream_executor/BUILD index 0425da4aea423a..34f7034934856a 100644 --- a/tensorflow/compiler/xla/stream_executor/BUILD +++ b/tensorflow/compiler/xla/stream_executor/BUILD @@ -450,6 +450,7 @@ tsl_gpu_library( ":temporary_memory_manager", ":timer", "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt_gemm_runner", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index 1c327bb60ca64f..5b07fe0723f9e0 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -78,6 +78,7 @@ cc_library( #"//tensorflow/core/platform:env", "//tensorflow/tsl/util:env_var", "@com_google_absl//absl/types:any", + "//tensorflow/compiler/xla:debug_options_flags", ]), ) @@ -87,13 +88,14 @@ cc_library( srcs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.cc"]), hdrs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.h"]), deps = if_gpu_is_configured([ - "//tensorflow/core:autotuning_proto_cc", - "//tensorflow/core:autotune_results_proto_cc", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/stream_executor:scratch_allocator", - "//tensorflow/compiler/xla/service/gpu:autotuner_util", - "//tensorflow/compiler/xla:debug_options_flags", - ":gpu_blas_lt", + "//tensorflow/core/protobuf:autotuning_proto_cc", + "//tensorflow/compiler/xla:autotune_results_proto_cc", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/stream_executor:scratch_allocator", + "//tensorflow/compiler/xla/service/gpu:autotuner_util", + "//tensorflow/compiler/xla:debug_options_flags", + ":gpu_blas_lt", ]), ) diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc index 195e1161a3aa3a..a47a0b01a6e895 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/util/env_var.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" namespace stream_executor { @@ -31,6 +32,13 @@ using blas::ComputationType; using blas::DataType; using xla::PrimitiveType; +bool GpuBlasLtEnabled() { + static std::atomic_bool result{[] { + return xla::GetDebugOptionsFromFlags().xla_gpu_enable_cublaslt(); + }()}; + return result; +} + namespace { bool TF32_Enabled() { diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h index 9eb8d121bb44fc..90625d46c46a5e 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h @@ -34,6 +34,8 @@ namespace stream_executor { namespace gpu { +bool GpuBlasLtEnabled(); + xla::StatusOr AsBlasDataType(xla::PrimitiveType dtype); xla::StatusOr GetBlasComputationType( diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc new file mode 100644 index 00000000000000..8693b0e42300f8 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc @@ -0,0 +1,341 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" + +namespace stream_executor { +namespace gpu { + +bool BlasLtGemmRunner::autotune_enabled_ = true; + +bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +std::ostream& operator <<(std::ostream& os, const StridedGemmConfig& cfg) { + return os << "trans_a/b: " << (int)cfg.trans_a << "/" << (int)cfg.trans_b << + " m: " << cfg.m << " n: " << cfg.n << " k: " << cfg.k << + " batch_count: " << cfg.batch_count << + " lda: " << cfg.lda << " ldb: " << cfg.ldb << " ldc: " << cfg.ldc << + " stride_a: " << cfg.stride_a << " stride_b: " << cfg.stride_b << + " stride_c: " << cfg.stride_c << + " type_a: " << (int)cfg.type_a << " type_b: " << (int)cfg.type_b << + " type_c: " << (int)cfg.type_c << + " alpha: " << cfg.alpha << " beta: " << cfg.beta; +} + +BlasLtGemmRunner::BlasLtGemmRunner(StreamExecutor *parent) : + mutex_(std::make_unique< absl::Mutex >()), + autotune_config_(std::make_unique< xla::gpu::AutotuneConfig >( + xla::gpu::DeviceConfig{parent, nullptr}, + xla::GetDebugOptionsFromFlags())) + { } + +BlasLtGemmRunner::~BlasLtGemmRunner() { } + + +/*static*/ BlasLtGemmRunner& BlasLtGemmRunner::i(const Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets a different cache instance + static std::vector> meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) { + autotune_enabled_ = xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0; + res.reset(new BlasLtGemmRunner(stream->parent())); + xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce(*res->autotune_config_); + } + return *res; +} + +template < class TuneFunc > +xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > BlasLtGemmRunner::Autotune( + const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms, + TuneFunc&& benchmark_func) { + gpu::BlasLt::MatmulAlgorithm best_algo; + float best_ms = std::numeric_limits< float >::max(), total_ms = 0; + uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0; + + for (uint32_t j = 0; j < algorithms.size(); j++) { + const auto& algo = algorithms[j]; + if (!benchmark_func(algo, nullptr).ok()) continue; + + blas::ProfileResult profile; + for (i = 0, total_ms = 0; i < n_total; i++) { + auto res = benchmark_func(algo, &profile); + if (!res.ok() || !profile.is_valid()) { + VLOG(1) << j << ": gemm algorithm is not valid: " /* << res.error_message() */; + break; + } + if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms(); + } + if (i < n_total) continue; // invalid algorithm + total_ms /= n_iters; + VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took " + << total_ms << "ms, workspace: " << algo.workspace_size; + if (total_ms < 
best_ms) { + best_ms = total_ms, best_algo = algo; + } + } // for algorithms + if (!best_algo.opaque_algo.has_value()) { + return xla::InternalError("No valid gemm algorithms found!"); + } + return best_algo; +} + +xla::StatusOr< std::array< uint64_t, 3 >> BlasLtGemmRunner::ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64 batch_count) { + + uint64_t bsa = 0, bsb = 0, bsc = 0; + using CT = const uint8_t; + for(int64 i = 0; i < batch_count-1; i++) { + uint64_t da = (CT *)a[i + 1]->opaque() - (CT *)a[i]->opaque(), + db = (CT *)b[i + 1]->opaque() - (CT *)b[i]->opaque(), + dc = (CT *)c[i + 1]->opaque() - (CT *)c[i]->opaque(); + if(i == 0) { + bsa = da, bsb = db, bsc = dc; + } else if(!(bsa == da && bsb == db && bsc == dc)) { // strides mismatch + return xla::InternalError("Strides are not consistent!"); + } + } + return std::array< uint64_t, 3 >{ bsa, bsb, bsc }; +} + +xla::Status BlasLtGemmRunner::RunBatchedImpl(Stream& stream, + blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k, + const void *alpha, blas::DataType type_a, const void** a, int64 lda, + blas::DataType type_b, const void** b, int64 ldb, const void *beta, + blas::DataType type_c, void** c, int64 ldc, int64 batch_count, + ScratchAllocator* allocator) +{ + + TF_ASSIGN_OR_RETURN(auto compute_type, + gpu::GetBlasComputationType(type_a, type_c, 0)); + + GroupedGemmConfig cfg{ + .m = (int64)m, + .n = (int64)n, + .k = (int64)k, + .batch_count = (int64)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .type_d = type_c, + .lda = (int64)lda, + .ldb = (int64)ldb, + .ldc = (int64)ldc, + .ldd = (int64)ldc, + .compute_type = compute_type, + .a = a, + .b = b, + .c = const_cast< const void **>(c), + .d = c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = grouped_gemm_map_.find(cfg); + if (res == grouped_gemm_map_.end()) { + // NOTE: we assume that pointers a,b,c come from the device mem + // hence we need to block stream here + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::CreateGroupedMatmulPlan(&stream, cfg)); + res = grouped_gemm_map_.emplace(cfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms( + num_solutions, allocator == nullptr ? 
0 : 1ull << 32)); + + VLOG(1) << stream.parent() << ": new GGemm config: " << + grouped_gemm_map_.size() << " #valid algorithms: " << algorithms.size(); + + BlasLt::MatmulAlgorithm best_algo; + if (!autotune_enabled_) { + if (algorithms.empty()) return xla::InternalError("No GG algorithms found!"); + best_algo = algorithms[0]; // otherwise use default algorithm + } else { + TF_ASSIGN_OR_RETURN(auto best_algo, Autotune(algorithms, + [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){ + if (profile == nullptr) { + return res->second->SetAlgorithm(algo, allocator); + } + return res->second->ExecuteOnStream(&stream, cfg, profile); + })); + } + TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo, allocator)); + } + return res->second->ExecuteOnStream(&stream, cfg); +} + +xla::Status BlasLtGemmRunner::RunStridedBatchedImpl(Stream& stream, + blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k, + xla::complex128 alpha, + blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a, + blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b, + double beta, + blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator) +{ + StridedGemmConfig scfg{ + .m = m, + .n = n, + .k = k, + .batch_count = (int64)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .lda = lda, + .ldb = ldb, + .ldc = ldc, + .stride_a = stride_a, + .stride_b = stride_b, + .stride_c = stride_c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = strided_gemm_map_.find(scfg); + while (res == strided_gemm_map_.end()) { + int64 row_a = m, col_a = k, row_b = k, col_b = n; + if (trans_a == blas::Transpose::kTranspose) std::swap(row_a, col_a); + if (trans_b == blas::Transpose::kTranspose) std::swap(row_b, col_b); + + auto order = MatrixLayout::Order::kColumnMajor; + GemmConfig cfg = { + .lhs_layout = MatrixLayout(type_a, row_a, col_a, order, batch_count, + lda, stride_a, trans_a), + + .rhs_layout = MatrixLayout(type_b, row_b, col_b, order, batch_count, + ldb, stride_b, trans_b), + + .c_layout = MatrixLayout(type_c, m, n, order, batch_count, + ldc, stride_c), + .output_layout = MatrixLayout(type_c, m, n, order, batch_count, + ldc, stride_c), + .alpha = alpha, + .beta = beta, + .compute_precision = -1, + .epilogue = gpu::BlasLt::Epilogue::kDefault, + }; + + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::GetMatmulPlan(&stream, cfg)); + res = strided_gemm_map_.emplace(scfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms( + num_solutions, allocator == nullptr ? 
0 : 1ull << 32)); + + VLOG(1) << &stream << " dev " << stream.parent() << '(' << + stream.parent()->device_ordinal() << "): new StridedBatched config: " + << strided_gemm_map_.size() << " #algorithms: " << algorithms.size(); + + if (!autotune_enabled_) { + if (algorithms.empty()) return xla::InternalError("No algorithms found!"); + res->second->SetAlgorithm(algorithms[0]); + break; + } + + BlasLt::MatmulAlgorithm best_algo{ .id = blas::kNoAlgorithm }; + xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/false)); + auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key); + if (opt_res.has_value()) { + auto id = *opt_res; + for (const auto& algo : algorithms) { + if (algo.id == id) best_algo = algo; + } + if (best_algo.id == blas::kNoAlgorithm) { + LOG(WARNING) << "Best algorithm not valid: need to autotune.."; + } + } + + if (best_algo.id == blas::kNoAlgorithm) { + TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms, + [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile){ + if (profile == nullptr) { + return res->second->SetAlgorithm(algo); + } + return res->second->ExecuteOnStream( + &stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator, // allocator + profile); + })); + xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id; + // reread algorithm ID from cache again (in case some other thread has + // already added this config to the cache to be sure we use the same ID) + auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache(key, ares, + *autotune_config_); + + if (new_id != best_algo.id) { + for (const auto& algo : algorithms) { + if (algo.id == new_id) best_algo = algo; + } + } + } // best_algo.id == blas::kNoAlgorithm + + res->second->SetAlgorithm(best_algo); + break; + } // while + return res->second->ExecuteOnStream( + &stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator); // allocator +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h new file mode 100644 index 00000000000000..97c18a5e64ed42 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h @@ -0,0 +1,260 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/util.h" + +using tensorflow::gtl::ArraySlice; +typedef ::std::int64_t int64; + + +namespace xla { +namespace gpu { +class AutotuneConfig; +} +} + +namespace stream_executor { + +namespace gpu { + +struct StridedGemmConfig { + int64 m, n, k, batch_count; + blas::Transpose trans_a, trans_b; + xla::complex128 alpha; + double beta; + blas::DataType type_a, type_b, type_c; + int64 lda, ldb, ldc; + int64 stride_a, stride_b, stride_c; +}; + +namespace { + +auto AsTuple(const GroupedGemmConfig& p) { + // NOTE: alpha, beta and data pointers are not included in cache !! + return std::make_tuple(p.m, p.n, p.k, p.batch_count, + p.trans_a, p.trans_b, + p.type_a, p.type_b, p.type_c, p.type_d, + p.lda, p.ldb, p.ldc, p.ldd, + p.compute_type); +} + +auto AsTuple(const StridedGemmConfig& p) { + return std::make_tuple(p.m, p.n, p.k, p.batch_count, + p.trans_a, p.trans_b, p.alpha.real(), p.alpha.imag(), p.beta, + p.type_a, p.type_b, p.type_c, + p.lda, p.ldb, p.ldc, + p.stride_a, p.stride_b, p.stride_c); +} + +} // namespace + +bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs); +bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs); + +template +H AbslHashValue(H h, const GroupedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +template +H AbslHashValue(H h, const StridedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +struct BlasLtGemmRunner { + + static BlasLtGemmRunner& i(const Stream *stream); + + template < class Scalar > + xla::complex128 Convert(Scalar x) { + if constexpr(std::is_same::value || + std::is_same::value) { + return static_cast< xla::complex128 >(x); + } else { + return static_cast< double >(x); + } + } + + template < class Scalar, class TypeA, class TypeB, class TypeC > + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const DeviceMemory& a, int64 lda, + const DeviceMemory& b, int64 ldb, + Scalar beta, DeviceMemory *c, int64 ldc, + ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, a, lda, 0, type_b, b, ldb, 0, + Convert(beta).real(), // only real betas are supported!! 
+ type_c, c, ldc, 0, 1, allocator); + } + + template < class Scalar, class TypeA, class TypeB, class TypeC > + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const TypeA* a, int64 lda, + const TypeB *b, int64 ldb, + Scalar beta, TypeC *c, int64 ldc, + ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + + DeviceMemoryBase mem_c{c}; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, DeviceMemoryBase{const_cast< TypeA *>(a)}, lda, 0, + type_b, DeviceMemoryBase{const_cast< TypeB *>(b)}, ldb, 0, + Convert(beta).real(), // only real betas are supported!! + type_c, &mem_c, ldc, 0, 1, allocator); + } + + + template < class Scalar, class TypeA, class TypeB, class TypeC> + xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const TypeA* a, int64 lda, int64 stride_a, + const TypeB* b, int64 ldb, int64 stride_b, + Scalar beta, TypeC* c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + DeviceMemoryBase mem_c{c}; + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a, + DeviceMemoryBase{const_cast(a)}, lda, stride_a, type_b, + DeviceMemoryBase{const_cast(a)}, ldb, stride_b, + Convert(beta).real(), // only real betas are supported!! + type_c, &mem_c, ldc, stride_c, batch_count, allocator); + } + + template < class Scalar, class TypeA, class TypeB, class TypeC> + xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + Scalar alpha, const DeviceMemory& a, int64 lda, int64 stride_a, + const DeviceMemory& b, int64 ldb, int64 stride_b, + Scalar beta, DeviceMemory *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator) { + + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type_a, a, lda, stride_a, type_b, b, ldb, stride_b, + Convert(beta).real(), // only real betas are supported!! + type_c, c, ldc, stride_c, batch_count, allocator); + } + + template < class Scalar, class T > + xla::Status RunBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, Scalar alpha, + const ArraySlice *> &a, int64 lda, + const ArraySlice *> &b, int64 ldb, Scalar beta, + const ArraySlice *> &c, int64 ldc, + int64 batch_count, ScratchAllocator* allocator) { + + // NOTE: Scalar types shall be verified for correctness vs T!! + auto type = dnn::ToDataType::value; + auto cvt = [](auto x){ + using TT = ArraySlice; + auto ptr = reinterpret_cast(&x); + return *reinterpret_cast(ptr); + }; + + auto res = ContiguousStrides(cvt(a), cvt(b), cvt(c), batch_count); + if (res.ok()) { + auto strides = std::move(res.value()); + return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k, + Convert(alpha), + type, *a[0], lda, strides[0] / sizeof(T), + type, *b[0], ldb, strides[1] / sizeof(T), + Convert(beta).real(), // only real betas are supported!! 
+ type, c[0], ldc, strides[2] / sizeof(T), batch_count, allocator); + } + return xla::InternalError("RunBatched: port::ArraySlice NYI!"); + } + + + template < class Scalar, class T > + xla::Status RunBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, uint64 m, uint64 n, uint64 k, + Scalar alpha, const T** a, int lda, + const T** b, int ldb, Scalar beta, + T** c, int64 ldc, int64 batch_count, ScratchAllocator* allocator){ + + auto type = dnn::ToDataType::value; + return RunBatchedImpl(stream, trans_a, trans_b, m, n, k, + &alpha, type, reinterpret_cast< const void **>(a), lda, + type, reinterpret_cast< const void **>(b), ldb, &beta, + type, reinterpret_cast< void **>(c), ldc, batch_count, allocator); + } + + ~BlasLtGemmRunner(); + BlasLtGemmRunner& operator=(BlasLtGemmRunner&& rhs) noexcept = default; + BlasLtGemmRunner(BlasLtGemmRunner&& rhs) noexcept = default; + +private: + explicit BlasLtGemmRunner(StreamExecutor *parent); + + template < class TuneFunc > + xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > Autotune( + const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms, + TuneFunc&& benchmark_func); + + + xla::Status RunBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, + const void *alpha, blas::DataType type_a, const void** a, int64 lda, + blas::DataType type_b, const void** b, int64 ldb, const void *beta, + blas::DataType type_c, void** c, int64 ldc, int64 batch_count, + ScratchAllocator* allocator); + + xla::Status RunStridedBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64 m, int64 n, int64 k, xla::complex128 alpha, + blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a, + blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b, + double beta, + blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c, + int64 batch_count, ScratchAllocator* allocator); + + xla::StatusOr< std::array< uint64_t, 3 >> ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64 batch_count); + + static bool autotune_enabled_; + std::unique_ptr< absl::Mutex > mutex_; + std::unique_ptr< xla::gpu::AutotuneConfig > autotune_config_; + absl::flat_hash_map grouped_gemm_map_; + absl::flat_hash_map strided_gemm_map_; +}; + +} // namespace gpu + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc index f9d3f37a73b63c..d269590b1a7696 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.cc +++ b/tensorflow/compiler/xla/stream_executor/stream.cc @@ -33,6 +33,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/tsl/platform/stacktrace.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" namespace stream_executor { @@ -1427,6 +1429,185 @@ Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, x, incx, beta, y, incy); } +template +tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64 k, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + DeviceMemory *c, int ldc, + blas::ComputePrecision precision, + blas::CallContext context) { + InputType alpha{1.0}; + InputType beta{0.0}; + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.Run(*this, transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, + /* allocator */nullptr)); //! NOTE: allocator is not available!! + return ::tsl::OkStatus(); + } + return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, precision, context); +} + +#define INSTANTIATE_THEN_BLAS_GEMM(INPUT_TYPE) \ + template tsl::Status Stream::ThenBlasGemm( \ + blas::Transpose transa, blas::Transpose transb, \ + uint64_t m, uint64 n, uint64 k, \ + const DeviceMemory& a, int lda, \ + const DeviceMemory& b, int ldb, \ + DeviceMemory* c, int ldc, \ + blas::ComputePrecision precision, \ + blas::CallContext context); + +INSTANTIATE_THEN_BLAS_GEMM(float) +INSTANTIATE_THEN_BLAS_GEMM(double) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::half) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16) +INSTANTIATE_THEN_BLAS_GEMM(std::complex) +INSTANTIATE_THEN_BLAS_GEMM(std::complex) + +#undef INSTANTIATE_THEN_BLAS_GEMM + +template +tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64 k, ConstantType alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + ConstantType beta, DeviceMemory *c, + int ldc, blas::ComputePrecision precision, + blas::CallContext context) { + static_assert( + detail::is_any_of, std::complex>(), + "Input can be half, bf16, float, double, std::complex or " + "std::complex"); + static_assert(!std::is_same_v || + detail::is_any_of(), + "If input is Eigen::half, constant has to be either " + "Eigen::half or float"); + static_assert(!std::is_same_v || + detail::is_any_of(), + "If input is Eigen::bfloat16, constant has to be either " + "Eigen::bfloat16 or float"); + static_assert( + detail::is_any_of(), + "If input is not Eigen::half, constant and input types have to match"); + + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.Run(*this, transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, + /* allocator */nullptr)); //! NOTE: allocator is not available!! 
+ return ::tsl::OkStatus(); + } + + blas::BlasSupport *blas = parent()->AsBlas(); + if (!blas) { + return tsl::errors::Internal( + "Attempting to perform BLAS operation using " + "StreamExecutor without BLAS support"); + } + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return blas->DoBlasGemm(this, transa, transb, m, n, k, + blas::ToDataType::value, alpha_ptr, a, + lda, b, ldb, beta_ptr, c, ldc, precision, + context); +} + +#define INSTANTIATE_THEN_BLAS_GEMM(INPUT_TYPE, CONSTANT_TYPE) \ + template tsl::Status Stream::ThenBlasGemm( \ + blas::Transpose transa, blas::Transpose transb, \ + uint64_t m, uint64 n, uint64 k, CONSTANT_TYPE alpha, \ + const DeviceMemory& a, int lda, \ + const DeviceMemory& b, int ldb, \ + CONSTANT_TYPE beta, DeviceMemory* c, int ldc, \ + blas::ComputePrecision precision, \ + blas::CallContext context); + +INSTANTIATE_THEN_BLAS_GEMM(float, float) +INSTANTIATE_THEN_BLAS_GEMM(double, double) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::half, Eigen::half) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, Eigen::bfloat16) +INSTANTIATE_THEN_BLAS_GEMM(std::complex, std::complex) +INSTANTIATE_THEN_BLAS_GEMM(std::complex, std::complex) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::half, float) +INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, float) + +#undef INSTANTIATE_THEN_BLAS_GEMM + +template +tsl::Status Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, + uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, + int64_t stride_a, const DeviceMemory &b, int ldb, + int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, + int64_t stride_c, int batch_count, blas::ComputePrecision precision, + blas::CallContext context) { + static_assert( + detail::is_any_of, std::complex>(), + "Unsupported input type"); + static_assert( + std::is_same_v || + (detail::is_any_of() && + std::is_same_v), + "Mismatched input and alpha/beta types"); + + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunStridedBatched( + *this, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_count, + /* allocator */ nullptr)); //! NOTE: allocator is not available!! 
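+    // RunStridedBatched keys the call by a StridedGemmConfig (shapes,
+    // transposes, data types, leading dimensions, strides, alpha/beta). On
+    // a cache miss it builds a BlasLt matmul plan, then either takes the
+    // first available algorithm when autotuning is disabled
+    // (xla_gpu_autotune_level <= 0) or benchmarks the candidate algorithms,
+    // reusing any algorithm id already stored in the in-memory autotune
+    // cache. Subsequent calls with the same configuration reuse the cached
+    // plan and algorithm without re-tuning.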
+ return ::tsl::OkStatus(); + } + + blas::BlasSupport *blas = parent()->AsBlas(); + if (!blas) { + return tsl::errors::Internal( + "Attempting to perform BLAS operation using " + "StreamExecutor without BLAS support"); + } + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return blas->DoBlasGemmStridedBatched( + this, transa, transb, m, n, k, blas::ToDataType::value, + alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, + stride_c, batch_count, precision, context); +} + +#define INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(INPUT_TYPE, CONSTANT_TYPE) \ + template tsl::Status Stream::ThenBlasGemmStridedBatched ( \ + blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, \ + uint64_t k, CONSTANT_TYPE alpha, const DeviceMemory &a, int lda, \ + int64_t stride_a, const DeviceMemory &b, int ldb, \ + int64_t stride_b, CONSTANT_TYPE beta, DeviceMemory *c, int ldc, \ + int64_t stride_c, int batch_count, blas::ComputePrecision precision, \ + blas::CallContext context); + +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(float, float) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(double, double) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::half, Eigen::half) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::bfloat16, Eigen::bfloat16) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(std::complex, std::complex) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(std::complex, std::complex) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::half, float) +INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::bfloat16, float) + +#undef INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED + namespace { // Like ThenBlasImpl, except this expects the last argument of blas_func to be a // blas::ProfileResult*. 
This functor doesn't put the stream into an error @@ -1605,7 +1786,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, float, @@ -1625,7 +1811,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, float, @@ -1657,7 +1848,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, @@ -1689,7 +1885,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, double, @@ -1721,7 +1922,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, @@ -1756,7 +1962,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, DeviceMemorySlice>, int, DeviceMemorySlice>, int, diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h index bb3c3b511a1605..7b9175b53d854c 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.h +++ 
b/tensorflow/compiler/xla/stream_executor/stream.h @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h" + namespace stream_executor { namespace host { @@ -898,24 +899,7 @@ class Stream { const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, blas::ComputePrecision precision, - blas::CallContext context) { - InputType alpha{1.0}; - InputType beta{0.0}; - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, precision, context); - } - - // TODO(parkers): Update all callers to pass kDefaultComputePrecision. - template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - DeviceMemory *c, int ldc, - blas::CallContext context) { - return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc, - blas::kDefaultComputePrecision,context); - } + blas::CallContext context); template tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, @@ -924,53 +908,7 @@ class Stream { const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, int ldc, blas::ComputePrecision precision, - blas::CallContext context) { - static_assert( - detail::is_any_of, std::complex>(), - "Input can be half, bf16, float, double, std::complex or " - "std::complex"); - static_assert(!std::is_same_v || - detail::is_any_of(), - "If input is Eigen::half, constant has to be either " - "Eigen::half or float"); - static_assert(!std::is_same_v || - detail::is_any_of(), - "If input is Eigen::bfloat16, constant has to be either " - "Eigen::bfloat16 or float"); - static_assert( - detail::is_any_of(), - "If input is not Eigen::half, constant and input types have to match"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemm(this, transa, transb, m, n, k, - blas::ToDataType::value, alpha_ptr, a, - lda, b, ldb, beta_ptr, c, ldc, precision, - context); - } - - // TODO(parkers): Update all callers to pass kDefaultComputePrecision. 
- template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, ConstantType alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - ConstantType beta, DeviceMemory *c, - int ldc, blas::CallContext context) { - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, blas::kDefaultComputePrecision, context); - } + blas::CallContext context); template tsl::Status ThenBlasGemmWithAlgorithm( @@ -984,9 +922,9 @@ class Stream { OutputType alpha{1}; OutputType beta{0}; return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, computation_type, - algorithm, blas::kDefaultComputePrecision, - output_profile_result, context); + ldb, beta, c, ldc, computation_type, + algorithm, blas::kDefaultComputePrecision, + output_profile_result, context); } template @@ -1155,34 +1093,7 @@ class Stream { int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, int batch_count, blas::ComputePrecision precision, - blas::CallContext context) { - static_assert( - detail::is_any_of, std::complex>(), - "Unsupported input type"); - static_assert( - std::is_same_v || - (detail::is_any_of() && - std::is_same_v), - "Mismatched input and alpha/beta types"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemmStridedBatched( - this, transa, transb, m, n, k, blas::ToDataType::value, - alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, - stride_c, batch_count, precision, context); - } + blas::CallContext context); // See BlasSupport::DoBlasTrsm. Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,