[ROCm] Add hipBLASLt workspace support #17096

Merged · 3 commits · Aug 25, 2023
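
This PR lets the ROCm execution provider's tunable GEMM path use hipBLASLt algorithms that require extra workspace: instead of rejecting any algorithm with a nonzero workspace requirement, it now allocates a scratch buffer from the tuning context and passes it to hipblasLtMatmul. Along the way it switches the compile-time datatype helpers to the hipblaslt-specific hipblasltDatatype_t enum, simplifies MatMul so the Status from MatMulImpl propagates directly, and makes rocBLAS tuning failures report the actual status string.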
13 changes: 5 additions & 8 deletions — onnxruntime/core/providers/rocm/math/matmul.cc

```diff
@@ -69,14 +69,11 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
   // Bail out early if the output is going to be empty
   if (Y->Shape().Size() == 0) return Status::OK();
 
-  if (MatMulImpl<T>(this, helper, reinterpret_cast<const T*>(left_X->Data<T>()),
-                    reinterpret_cast<const T*>(right_X->Data<T>()),
-                    reinterpret_cast<T*>(Y->MutableData<T>()),
-                    left_X->Shape(), right_X->Shape(),
-                    transa, transb, trans_batch_a_, trans_batch_b_, alpha_, ctx->GetComputeStream()) != Status::OK()) {
-    return Status(common::ONNXRUNTIME, common::FAIL, "MatMulImpl failed");
-  }
-  return Status::OK();
+  return MatMulImpl<T>(this, helper, reinterpret_cast<const T*>(left_X->Data<T>()),
+                       reinterpret_cast<const T*>(right_X->Data<T>()),
+                       reinterpret_cast<T*>(Y->MutableData<T>()),
+                       left_X->Shape(), right_X->Shape(),
+                       transa, transb, trans_batch_a_, trans_batch_b_, alpha_, ctx->GetComputeStream());
 }
 
 }  // namespace rocm
```
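The net effect: a failure inside MatMulImpl now surfaces its own Status and message to the caller instead of being flattened into a generic "MatMulImpl failed" error.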
36 changes: 19 additions & 17 deletions — onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h

```diff
@@ -33,26 +33,26 @@ enum ActivationType {
 };
 
 template <typename T>
-constexpr hipblasDatatype_t HipBlasDataTypeFor();
+constexpr hipblasltDatatype_t HipBlasDataTypeFor();
 
 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
-  return HIPBLAS_R_32F;
+constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
+  return HIPBLASLT_R_32F;
 }
 
 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<half>() {
-  return HIPBLAS_R_16F;
+constexpr hipblasltDatatype_t HipBlasDataTypeFor<half>() {
+  return HIPBLASLT_R_16F;
 }
 
 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
-  return HIPBLAS_R_16B;
+constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
+  return HIPBLASLT_R_16B;
 }
 
 template <>
-constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
-  return HIPBLAS_R_64F;
+constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
+  return HIPBLASLT_R_64F;
 }
 
 template <typename Layout>
@@ -104,7 +104,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
 
   hipblasOperation_t trans_a = MapCKLayoutToHipBlasLt<BLayout>();
   hipblasOperation_t trans_b = MapCKLayoutToHipBlasLt<ALayout>();
-  hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
+  hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
   std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
 
   HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
@@ -149,7 +149,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
   HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
   HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
   HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
-  HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
+  HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F));
 
   int batch = GetBatchCountFromParams<T>(params);
   if (batch > 1) {
@@ -213,9 +213,11 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
 
   TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
       status != HIPBLAS_STATUS_SUCCESS, "hipBLASLt find_all: algo not supported, index ", std::to_string(i));
-  // TODO: support workspace in next PR
-  TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-      workspace_size > 0, "hipBLASLt find_all: extra workspace not supported for now.");
+
+  IAllocatorUniquePtr<void> workspace_buffer;
+  if (workspace_size > 0) {
+    workspace_buffer = params->tuning_ctx->GetScratchBuffer(workspace_size, params->stream);
+  }
 
   HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(op_handle,
                                             matmul,
@@ -230,9 +232,9 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
                                             params->c,
                                             mat_c,
                                             &algo_i,
-                                            nullptr,
-                                            0,
-                                            params->stream));
+                                            workspace_buffer.get(),
+                                            workspace_size,
+                                            params->StreamHandle()));
 
   HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul));
   HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a));
```
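
For readers unfamiliar with the hipBLASLt workspace contract this diff relies on: each hipblasLtMatmulHeuristicResult_t reports the scratch size its algorithm needs, and that buffer (or nullptr when the size is zero) is handed to hipblasLtMatmul along with its size. Below is a minimal standalone sketch of that flow. The helper name MatmulWithWorkspace and the use of plain hipMalloc/hipStreamSynchronize are illustrative assumptions, not ONNX Runtime code; the PR itself routes the allocation through the tuning context's GetScratchBuffer so the memory is pooled and stream-ordered.

```cpp
// Sketch of the generic hipBLASLt workspace flow; error checks elided.
#include <hip/hip_runtime.h>
#include <hipblaslt/hipblaslt.h>

void MatmulWithWorkspace(hipblasLtHandle_t handle,
                         hipblasLtMatmulDesc_t matmul,
                         hipblasLtMatrixLayout_t mat_a,
                         hipblasLtMatrixLayout_t mat_b,
                         hipblasLtMatrixLayout_t mat_c,
                         const void* a, const void* b, void* c,
                         const float* alpha, const float* beta,
                         const hipblasLtMatmulHeuristicResult_t& result,
                         hipStream_t stream) {
  // The heuristic result carries the workspace requirement of its algorithm.
  size_t workspace_size = result.workspaceSize;

  void* workspace = nullptr;
  if (workspace_size > 0) {
    hipMalloc(&workspace, workspace_size);  // the PR uses a pooled scratch buffer instead
  }

  // Pass the buffer and its size; both may be nullptr/0 when no workspace is needed.
  hipblasLtMatmul(handle, matmul,
                  alpha, a, mat_a, b, mat_b,
                  beta, c, mat_c,
                  /*D=*/c, mat_c,
                  &result.algo,
                  workspace, workspace_size, stream);

  if (workspace != nullptr) {
    hipStreamSynchronize(stream);  // naive: wait before freeing; a real allocator is stream-aware
    hipFree(workspace);
  }
}
```

Allocating per the algorithm's reported requirement is what lets tuning consider workspace-hungry algorithms at all; previously any candidate with workspace_size > 0 was rejected as unsupported.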
6 changes: 3 additions & 3 deletions — onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h

```diff
@@ -166,7 +166,7 @@ auto GetRocBlasGemmTypeStringAndOps() {
         status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
 
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-        status != rocblas_status_success, "Solution ", solution, " failed.");
+        status != rocblas_status_success, "Solution ", solution, " failed: ", rocblas_status_to_string(status));
 
     return Status::OK();
   };
@@ -232,7 +232,7 @@ auto GetRocBlasBatchedGemmTypeStringAndOps() {
         status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
 
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-        status != rocblas_status_success, "Solution ", solution, " failed.");
+        status != rocblas_status_success, "Solution ", solution, " failed: ", rocblas_status_to_string(status));
 
     return Status::OK();
   };
@@ -299,7 +299,7 @@ auto GetRocBlasStridedBatchedGemmTypeStringAndOps() {
         status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
 
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-        status != rocblas_status_success, "Solution ", solution, " failed.");
+        status != rocblas_status_success, "Solution ", solution, " failed: ", rocblas_status_to_string(status));
 
     return Status::OK();
   };
```
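All three rocBLAS tuning wrappers (plain, batched, and strided-batched GEMM) get the same fix: the failure message now includes rocblas_status_to_string(status), rocBLAS's status-to-name helper, so a rejected solution reports which status it actually returned instead of a bare "failed".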