Add TuningContext for TunableOp (#14557)
This makes TunableOp free of tuning-results state, which will allow us to dump and load offline tuning results.
cloudhan authored Feb 10, 2023
1 parent 9a9d45f commit 9bd022b
Showing 57 changed files with 760 additions and 294 deletions.
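
Before the per-file diff, here is a minimal sketch of the shape of the new interface, inferred only from the call sites visible in this commit (an ITuningContext* accessor on the execution provider, and tuning_ctx->IsTunableOpEnabled() in the ROCm kernels). The commented-out methods illustrate the "dump and load offline tuning results" goal from the commit message; their names are assumptions, not the actual tuning_context.h API.

// Hedged sketch inferred from the call sites below; not the real header.
namespace onnxruntime {

class ITuningContext {
 public:
  virtual ~ITuningContext() = default;

  // Kernels gate tuning by querying the context instead of a bool parameter.
  virtual bool IsTunableOpEnabled() const = 0;

  // Hypothetical names illustrating the offline dump/load capability:
  // virtual std::string SerializeTuningResults() const;
  // virtual Status LoadTuningResults(const std::string& blob);
};

}  // namespace onnxruntime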
1 change: 1 addition & 0 deletions cmake/onnxruntime_kernel_explorer.cmake
@@ -51,6 +51,7 @@ if (onnxruntime_USE_CUDA)
"${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cuh"
)
target_sources(kernel_explorer PRIVATE ${kernel_explorer_cuda_kernel_srcs})
+target_include_directories(kernel_explorer PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
elseif (onnxruntime_USE_ROCM)
file(GLOB kernel_explorer_rocm_kernel_srcs CONFIGURE_DEPENDS
"${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cc"
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -30,6 +30,7 @@ class Node;
#include "core/framework/func_api.h"
#include "core/framework/provider_options.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

namespace onnxruntime {

@@ -300,6 +301,13 @@ class IExecutionProvider {
*/
virtual bool ConcurrentRunSupported() const { return true; }

+/**
+ * Return the tuning context which holds all TunableOp state.
+ */
+virtual ITuningContext* GetTuningContext() const {
+  return nullptr;
+}
+
private:
const std::string type_;

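The base-class default above returns nullptr, i.e. "no tuning support". A self-contained sketch of how a provider with TunableOp support would plug in; every type here is a mock stand-in, not the real ONNX Runtime class:

// Mocked pattern: a provider owns one context that holds all tuning state.
class ITuningContextMock {
 public:
  virtual ~ITuningContextMock() = default;
  virtual bool IsTunableOpEnabled() const = 0;
};

class RocmTuningContextMock : public ITuningContextMock {
 public:
  bool IsTunableOpEnabled() const override { return enable_; }
  void Enable(bool v) { enable_ = v; }

 private:
  bool enable_{false};
};

class ExecutionProviderMock {
 public:
  virtual ~ExecutionProviderMock() = default;
  // Default from the diff above: no tuning context available.
  virtual ITuningContextMock* GetTuningContext() const { return nullptr; }
};

class RocmExecutionProviderMock : public ExecutionProviderMock {
 public:
  ITuningContextMock* GetTuningContext() const override {
    return &tuning_context_;  // single owner of all TunableOp state
  }

 private:
  mutable RocmTuningContextMock tuning_context_;
};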
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -363,7 +363,7 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
return LaunchDecoderAttentionKernel(
device_prop,
#ifdef USE_ROCM
-IsTunableOpEnabled(),
+GetTuningContext(),
#endif
stream,
cublas,
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/rocm/bert/attention.cc
@@ -88,7 +88,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
// TODO: use custom kernel of expand to improve the performance.
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
-IsTunableOpEnabled(), Stream(context), rocblas,
+GetTuningContext(), Stream(context), rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
n, m, 1,
/*alpha=*/1.0f,
@@ -99,7 +99,7 @@

// result(N, M) = 1 * weights x input + 1 x B.
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
-IsTunableOpEnabled(), Stream(context), rocblas,
+GetTuningContext(), Stream(context), rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
n, m, k,
/*alpha=*/1.0f,
@@ -114,7 +114,7 @@
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
return LaunchAttentionKernel(
device_prop,
-IsTunableOpEnabled(),
+GetTuningContext(),
Stream(context),
rocblas,
element_size,
28 changes: 14 additions & 14 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -72,7 +72,7 @@ size_t GetAttentionWorkspaceSize(
template <typename T>
Status QkvToContext(
const hipDeviceProp_t& prop,
-bool tuning,
+RocmTuningContext* tuning_ctx,
rocblas_handle& rocblas,
hipStream_t stream,
const int batch_size,
@@ -139,7 +139,7 @@ Status QkvToContext(
const int temp_matrix_size = sequence_length * all_sequence_length;

ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-tuning, stream, rocblas,
+tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
all_sequence_length, sequence_length, head_size,
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
@@ -174,7 +174,7 @@ Status QkvToContext(

// compute P*V (as V*P), and store in scratch3: BxNxSxH
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-tuning, stream, rocblas,
+tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, all_sequence_length,
/*alpha=*/1.0f,
@@ -191,7 +191,7 @@

Status LaunchAttentionKernel(
const hipDeviceProp_t& prop,
-bool tuning,
+RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -215,7 +215,7 @@
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
if (element_size == 2) {
return QkvToContext(
-prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
+prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const __half*>(input),
reinterpret_cast<__half*>(output),
reinterpret_cast<__half*>(workspace),
@@ -230,7 +230,7 @@
use_persistent_softmax);
} else {
return QkvToContext(
-prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
+prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const float*>(input),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(workspace),
@@ -249,7 +249,7 @@
template <typename T>
Status DecoderQkvToContext(
const hipDeviceProp_t& prop,
-bool tuning,
+RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -352,7 +352,7 @@ Status DecoderQkvToContext(
const int strideB = sequence_length * head_size;
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-tuning, stream, rocblas,
+tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -363,7 +363,7 @@
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-tuning, stream, rocblas,
+tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -386,7 +386,7 @@
// compute P*V (as V*P), and store in scratch3: BxNxSxH
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-tuning, stream, rocblas,
+tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -397,7 +397,7 @@
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-tuning, stream, rocblas,
+tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -415,7 +415,7 @@

Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop,
-bool tuning,
+RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -442,7 +442,7 @@
if (element_size == 2) {
return DecoderQkvToContext(
prop,
-tuning,
+tuning_ctx,
stream,
rocblas,
element_size,
@@ -469,7 +469,7 @@
} else {
return DecoderQkvToContext(
prop,
-tuning,
+tuning_ctx,
stream,
rocblas,
element_size,
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -6,6 +6,7 @@
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace contrib {
@@ -27,7 +28,7 @@ size_t GetAttentionWorkspaceSize(

Status LaunchAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
-bool tuning, // Whether to enable tuning
+RocmTuningContext* tuning_ctx, // context for tuning
hipStream_t stream, // Hip stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
@@ -50,7 +51,7 @@

Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
-bool tuning, // Whether to enable tuning
+RocmTuningContext* tuning_ctx, // context for tuning
hipStream_t stream, // Hip stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
@@ -48,7 +48,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToHipType<T>::MappedType HipT;

-return LaunchFastGeluKernel<HipT>(IsTunableOpEnabled(),
+return LaunchFastGeluKernel<HipT>(GetTuningContext(),
Stream(context),
static_cast<int>(input_length),
static_cast<int>(bias_length),
16 changes: 9 additions & 7 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
@@ -40,25 +40,27 @@ namespace contrib {
namespace rocm {

template <typename T>
-Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length,
+Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output) {
-FastGeluParams<T> params(stream, input, bias, output, input_length, bias_length);
-if (tuning) {
+FastGeluParams<T> params(tuning_ctx, stream, input, bias, output, input_length, bias_length);
+if (tuning_ctx->IsTunableOpEnabled()) {
static FastGeluTunableOp<T> op;
-op.EnableTuning();
return op(&params);
}

return FastGeluStaticSelection<T>(&params);
}

-template Status LaunchFastGeluKernel<float>(bool tuning, hipStream_t stream, int input_length, int bias_length,
+template Status LaunchFastGeluKernel<float>(RocmTuningContext* tuning_ctx, hipStream_t stream,
+int input_length, int bias_length,
const float* input, const float* bias, float* output);

-template Status LaunchFastGeluKernel<BFloat16>(bool tuning, hipStream_t stream, int input_length, int bias_length,
+template Status LaunchFastGeluKernel<BFloat16>(RocmTuningContext* tuning_ctx, hipStream_t stream,
+int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output);

-template Status LaunchFastGeluKernel<half>(bool tuning, hipStream_t stream, int input_length, int bias_length,
+template Status LaunchFastGeluKernel<half>(RocmTuningContext* tuning_ctx, hipStream_t stream,
+int input_length, int bias_length,
const half* input, const half* bias, half* output);

} // namespace rocm
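The same shape repeats across the kernels in this commit: build params that carry the context, then branch on IsTunableOpEnabled(). A runnable, self-contained distillation of that control flow; all names are mocks of the real FastGelu types above:

#include <iostream>

struct TuningContextMock {
  bool enabled = false;
  bool IsTunableOpEnabled() const { return enabled; }
};

struct FastGeluParamsMock {
  TuningContextMock* tuning_ctx;  // tensors omitted; only dispatch is modeled
};

// Stand-ins for FastGeluTunableOp and FastGeluStaticSelection.
int RunTunable(const FastGeluParamsMock*) { std::cout << "tunable path\n"; return 0; }
int RunStatic(const FastGeluParamsMock*) { std::cout << "static path\n"; return 0; }

int LaunchFastGeluMock(TuningContextMock* tuning_ctx, FastGeluParamsMock* params) {
  // Mirrors the new LaunchFastGeluKernel: the context, not a bool argument
  // and not a per-op EnableTuning() call, decides whether tuning runs.
  if (tuning_ctx->IsTunableOpEnabled()) {
    return RunTunable(params);
  }
  return RunStatic(params);
}

int main() {
  TuningContextMock ctx;
  ctx.enabled = true;  // in ONNX Runtime this would be driven by provider options
  FastGeluParamsMock params{&ctx};
  return LaunchFastGeluMock(&ctx, &params);
}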
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h
@@ -6,14 +6,16 @@
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
-Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length,
+Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output);

} // namespace rocm
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h
@@ -19,9 +19,9 @@ namespace contrib {
namespace rocm {

template <typename T>
-struct FastGeluParams : onnxruntime::rocm::tunable::OpParams {
-FastGeluParams(hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
-OpParams(stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}
+struct FastGeluParams : OpParams {
+FastGeluParams(RocmTuningContext* tuning_ctx, hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
+OpParams(tuning_ctx, stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}

std::string Signature() const override {
std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length);
@@ -119,7 +119,7 @@ Status FastGeluStaticSelection(const FastGeluParams<half>* params) {
this->RegisterOp(FastGeluOp<T, threads_per_block, 16>{});

template <typename T>
-class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<FastGeluParams<T>> {
+class FastGeluTunableOp : public TunableOp<FastGeluParams<T>> {
public:
FastGeluTunableOp() {
this->RegisterOp(FastGeluStaticSelection<T>);
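Constructing params now means threading the context through as the first constructor argument. A compilable sketch of the new FastGeluParams shape; OpParamsMock stands in for onnxruntime::rocm::tunable::OpParams, whose real definition is outside this excerpt, and the stream is reduced to a void*:

#include <string>

struct TuningContextMock {
  bool IsTunableOpEnabled() const { return true; }
};

// Stand-in for OpParams after this commit: the context rides with the stream.
struct OpParamsMock {
  OpParamsMock(TuningContextMock* tuning_ctx, void* stream)
      : tuning_ctx(tuning_ctx), stream(stream) {}
  TuningContextMock* tuning_ctx;
  void* stream;
};

template <typename T>
struct FastGeluParamsMock : OpParamsMock {
  FastGeluParamsMock(TuningContextMock* tuning_ctx, void* stream,
                     const T* input, const T* bias, T* output,
                     int input_length, int bias_length)
      : OpParamsMock(tuning_ctx, stream),
        input(input), bias(bias), output(output),
        input_length(input_length), bias_length(bias_length) {}

  std::string Signature() const {
    return std::to_string(input_length) + "_" + std::to_string(bias_length);
  }

  const T* input;
  const T* bias;
  T* output;
  int input_length;
  int bias_length;
};

int main() {
  TuningContextMock ctx;
  float in[4] = {}, out[4] = {};
  FastGeluParamsMock<float> params(&ctx, /*stream=*/nullptr, in, /*bias=*/nullptr, out, 4, 0);
  return params.Signature() == "4_0" ? 0 : 1;
}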
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc
@@ -58,7 +58,7 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
using onnxruntime::rocm::tunable::blas::BlasOp;

return blas::row_major::GemmFastGelu(
-IsTunableOpEnabled(),
+GetTuningContext(),
Stream(ctx), GetRocblasHandle(ctx),
transa ? BlasOp::Trans : BlasOp::NonTrans,
transb ? BlasOp::Trans : BlasOp::NonTrans,
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh
@@ -19,7 +19,6 @@
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"

using onnxruntime::rocm::ToHipType;
-using onnxruntime::rocm::tunable::Op;

namespace onnxruntime {
namespace contrib {
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h
@@ -19,7 +19,7 @@ namespace rocm {
namespace blas {

template <typename T>
-struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
+struct GemmFastGeluParams : OpParams {
std::string Signature() const override {
bool has_bias = (nullptr != bias) ? 0 : 1;
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias);
@@ -39,7 +39,6 @@ struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
T beta;
T* c;
int64_t ldc;
-bool tuning{false};
};

} // namespace blas
10 changes: 3 additions & 7 deletions onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu
@@ -20,6 +20,7 @@ namespace row_major {
template <typename T, typename ScalarT>
inline GEMMFASTGELU(T, ScalarT) {
GemmFastGeluParams<T> params;
+params.tuning_ctx = tuning_ctx;
params.stream = stream;
params.handle = handle;

@@ -46,23 +47,18 @@
params.c = c;
params.ldc = ldc;

-if (tunable) {
-params.tuning = true;
+if (tuning_ctx->IsTunableOpEnabled()) {
if (opa == BlasOp::N && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Row> gemm_fast_gelu{};
-gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::T && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Row> gemm_fast_gelu{};
-gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::N && opb == BlasOp::T) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Col> gemm_fast_gelu{};
-gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Col> gemm_fast_gelu{};
-gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
}
}
@@ -71,7 +67,7 @@ inline GEMMFASTGELU(T, ScalarT) {
}

#define CALL_GEMMFASTGELU(T, ScalarT) \
-GemmFastGelu<T, ScalarT>(tunable, stream, handle, \
+GemmFastGelu<T, ScalarT>(tuning_ctx, stream, handle, \
opa, opb, \
m, n, k, \
alpha, a, lda, b, ldb, bias, \
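Worth noting in the hunk above: all four EnableTuning() calls disappear without a replacement call site. The likely reading is that the TunableOp base class, changed in files beyond this excerpt, now consults params->tuning_ctx itself, inverting control from per-instance switches to the shared context. The sketch below makes that assumption concrete; it is not the real TunableOp:

// Hypothetical inversion of control; assumes TunableOp reads the context
// carried by its params (the actual change is not shown in this excerpt).
template <typename ParamsT>
class TunableOpSketch {
 public:
  int operator()(const ParamsT* params) {
    // Before: callers flipped per-instance state via EnableTuning().
    // After: the context attached to the params decides.
    if (params->tuning_ctx != nullptr && params->tuning_ctx->IsTunableOpEnabled()) {
      return FindFastestAndRun(params);  // benchmark candidates, cache the winner
    }
    return RunDefault(params);  // static/default kernel selection
  }

 private:
  int FindFastestAndRun(const ParamsT*) { return 0; }
  int RunDefault(const ParamsT*) { return 0; }
};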
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h
@@ -14,7 +14,7 @@ namespace blas {

#define GEMMFASTGELU(T, ScalarT) \
common::Status GemmFastGelu( \
-bool tunable, hipStream_t stream, rocblas_handle handle, \
+RocmTuningContext* tuning_ctx, hipStream_t stream, rocblas_handle handle, \
BlasOp opa, BlasOp opb, \
std::int64_t m, std::int64_t n, std::int64_t k, \
ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \
(The remaining 42 changed files are not shown in this excerpt.)
