Commit

Merge branch 'main' into tlwu/optimize_sd_3

tianleiwu committed Feb 12, 2023
2 parents e1540c3 + 12d9117 commit 89d0dd0

Showing 66 changed files with 882 additions and 353 deletions.
2 changes: 2 additions & 0 deletions CONTRIBUTING.md
@@ -54,6 +54,8 @@ The html docs are generated from markdown using Jekyll and published using GitHu

To update the docs, create a Pull Request against the [gh-pages](https://github.com/microsoft/onnxruntime/tree/gh-pages) branch of the [ONNX Runtime repo](https://github.com/microsoft/onnxruntime).

To preview your changes, you can push to the gh-pages branch in your fork; this publishes a staged version of your changes to <github user name>.github.io/onnxruntime/docs.

Once your PR is approved and merged, your changes will be automatically published to https://onnxruntime.ai/docs.

Note: technical reference docs for developers of ONNX Runtime source code can be found [here](https://github.com/microsoft/onnxruntime/docs)

1 change: 1 addition & 0 deletions cmake/onnxruntime_kernel_explorer.cmake
@@ -51,6 +51,7 @@ if (onnxruntime_USE_CUDA)
"${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cuh"
)
target_sources(kernel_explorer PRIVATE ${kernel_explorer_cuda_kernel_srcs})
target_include_directories(kernel_explorer PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
elseif (onnxruntime_USE_ROCM)
file(GLOB kernel_explorer_rocm_kernel_srcs CONFIGURE_DEPENDS
"${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cc"

8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -30,6 +30,7 @@ class Node;
#include "core/framework/func_api.h"
#include "core/framework/provider_options.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

namespace onnxruntime {

@@ -300,6 +301,13 @@ class IExecutionProvider {
*/
virtual bool ConcurrentRunSupported() const { return true; }

/**
* Return the tuning context which holds all TunableOp state.
*/
virtual ITuningContext* GetTuningContext() const {
return nullptr;
}

private:
const std::string type_;
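
The commit only adds this hook to the base interface. As a rough sketch of how a provider that owns tuning state might satisfy it (illustrative only; the class name, the member, and the ownership model below are assumptions and are not part of this diff):

// Illustrative sketch, not part of this commit.
class MyExecutionProvider : public IExecutionProvider {
 public:
  // Constructor and the rest of the provider interface are omitted from this sketch.
  ITuningContext* GetTuningContext() const override {
    return tuning_context_.get();  // expose the provider-owned TunableOp state
  }

 private:
  std::unique_ptr<ITuningContext> tuning_context_;  // created and owned by the provider
};

The ROCm call sites later in this commit retrieve the context through GetTuningContext() instead of querying a standalone IsTunableOpEnabled() flag.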

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -28,6 +28,7 @@ limitations under the License.

#include <cassert>
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -363,7 +363,7 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
return LaunchDecoderAttentionKernel(
device_prop,
#ifdef USE_ROCM
IsTunableOpEnabled(),
GetTuningContext(),
#endif
stream,
cublas,

6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/rocm/bert/attention.cc
@@ -88,7 +88,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
// TODO: use custom kernel of expand to improve the performance.
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
IsTunableOpEnabled(), Stream(context), rocblas,
GetTuningContext(), Stream(context), rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
n, m, 1,
/*alpha=*/1.0f,
@@ -99,7 +99,7 @@

// result(N, M) = 1 * weights x input + 1 x B.
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
IsTunableOpEnabled(), Stream(context), rocblas,
GetTuningContext(), Stream(context), rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
n, m, k,
/*alpha=*/1.0f,
@@ -114,7 +114,7 @@
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
return LaunchAttentionKernel(
device_prop,
IsTunableOpEnabled(),
GetTuningContext(),
Stream(context),
rocblas,
element_size,

28 changes: 14 additions & 14 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -72,7 +72,7 @@ size_t GetAttentionWorkspaceSize(
template <typename T>
Status QkvToContext(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
rocblas_handle& rocblas,
hipStream_t stream,
const int batch_size,
@@ -139,7 +139,7 @@ Status QkvToContext(
const int temp_matrix_size = sequence_length * all_sequence_length;

ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
all_sequence_length, sequence_length, head_size,
// For raw attention mask, the scalar of 1/sqrt(H) is moved to the softmax computation.
@@ -174,7 +174,7 @@ Status QkvToContext(

// compute P*V (as V*P), and store in scratch3: BxNxSxH
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, all_sequence_length,
/*alpha=*/1.0f,
@@ -191,7 +191,7 @@ Status QkvToContext(

Status LaunchAttentionKernel(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -215,7 +215,7 @@ Status LaunchAttentionKernel(
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
if (element_size == 2) {
return QkvToContext(
prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const __half*>(input),
reinterpret_cast<__half*>(output),
reinterpret_cast<__half*>(workspace),
@@ -230,7 +230,7 @@ Status LaunchAttentionKernel(
use_persistent_softmax);
} else {
return QkvToContext(
prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const float*>(input),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(workspace),
@@ -249,7 +249,7 @@ Status LaunchAttentionKernel(
template <typename T>
Status DecoderQkvToContext(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -352,7 +352,7 @@ Status DecoderQkvToContext(
const int strideB = sequence_length * head_size;
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -363,7 +363,7 @@ Status DecoderQkvToContext(
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -386,7 +386,7 @@ Status DecoderQkvToContext(
// compute P*V (as V*P), and store in scratch3: BxNxSxH
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -397,7 +397,7 @@ Status DecoderQkvToContext(
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -415,7 +415,7 @@ Status DecoderQkvToContext(

Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -442,7 +442,7 @@ Status LaunchDecoderAttentionKernel(
if (element_size == 2) {
return DecoderQkvToContext(
prop,
tuning,
tuning_ctx,
stream,
rocblas,
element_size,
@@ -469,7 +469,7 @@ Status LaunchDecoderAttentionKernel(
} else {
return DecoderQkvToContext(
prop,
tuning,
tuning_ctx,
stream,
rocblas,
element_size,

5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -6,6 +6,7 @@
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace contrib {
@@ -27,7 +28,7 @@ size_t GetAttentionWorkspaceSize(

Status LaunchAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
bool tuning, // Whether to enable tuning
RocmTuningContext* tuning_ctx, // context for tuning
hipStream_t stream, // Hip stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
@@ -50,7 +51,7 @@ Status LaunchAttentionKernel(

Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
bool tuning, // Whether to enable tuning
RocmTuningContext* tuning_ctx, // context for tuning
hipStream_t stream, // Hip stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
@@ -48,7 +48,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToHipType<T>::MappedType HipT;

return LaunchFastGeluKernel<HipT>(IsTunableOpEnabled(),
return LaunchFastGeluKernel<HipT>(GetTuningContext(),
Stream(context),
static_cast<int>(input_length),
static_cast<int>(bias_length),

16 changes: 9 additions & 7 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
@@ -40,25 +40,27 @@ namespace contrib {
namespace rocm {

template <typename T>
Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length,
Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output) {
FastGeluParams<T> params(stream, input, bias, output, input_length, bias_length);
if (tuning) {
FastGeluParams<T> params(tuning_ctx, stream, input, bias, output, input_length, bias_length);
if (tuning_ctx->IsTunableOpEnabled()) {
static FastGeluTunableOp<T> op;
op.EnableTuning();
return op(&params);
}

return FastGeluStaticSelection<T>(&params);
}

template Status LaunchFastGeluKernel<float>(bool tuning, hipStream_t stream, int input_length, int bias_length,
template Status LaunchFastGeluKernel<float>(RocmTuningContext* tuning_ctx, hipStream_t stream,
int input_length, int bias_length,
const float* input, const float* bias, float* output);

template Status LaunchFastGeluKernel<BFloat16>(bool tuning, hipStream_t stream, int input_length, int bias_length,
template Status LaunchFastGeluKernel<BFloat16>(RocmTuningContext* tuning_ctx, hipStream_t stream,
int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output);

template Status LaunchFastGeluKernel<half>(bool tuning, hipStream_t stream, int input_length, int bias_length,
template Status LaunchFastGeluKernel<half>(RocmTuningContext* tuning_ctx, hipStream_t stream,
int input_length, int bias_length,
const half* input, const half* bias, half* output);

} // namespace rocm

4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h
@@ -6,14 +6,16 @@
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length,
Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output);

} // namespace rocm

8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h
@@ -19,9 +19,9 @@ namespace contrib {
namespace rocm {

template <typename T>
struct FastGeluParams : onnxruntime::rocm::tunable::OpParams {
FastGeluParams(hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
OpParams(stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}
struct FastGeluParams : OpParams {
FastGeluParams(RocmTuningContext* tuning_ctx, hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
OpParams(tuning_ctx, stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}

std::string Signature() const override {
std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length);
@@ -119,7 +119,7 @@ Status FastGeluStaticSelection(const FastGeluParams<half>* params) {
this->RegisterOp(FastGeluOp<T, threads_per_block, 16>{});

template <typename T>
class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<FastGeluParams<T>> {
class FastGeluTunableOp : public TunableOp<FastGeluParams<T>> {
public:
FastGeluTunableOp() {
this->RegisterOp(FastGeluStaticSelection<T>);
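
Taken together, the FastGelu changes show the new pattern: the tuning context travels inside the params object through the OpParams base, and whether tuning runs is a property of that context rather than a per-op EnableTuning() toggle. Below is a minimal sketch of the same pattern applied to a hypothetical element-wise kernel; every MyEltwise* name is invented for illustration, and the sketch assumes the same includes and namespace context as fast_gelu_tunable_op.h above.

// Illustrative sketch only; not part of this commit.
template <typename T>
struct MyEltwiseParams : OpParams {
  MyEltwiseParams(RocmTuningContext* tuning_ctx, hipStream_t stream, const T* x, T* y, int n)
      : OpParams(tuning_ctx, stream), x(x), y(y), n(n) {}

  // The signature identifies the problem shape so tuning results can be reused per shape.
  std::string Signature() const override { return std::to_string(n); }

  const T* x;
  T* y;
  int n;
};

// One candidate implementation; a real tunable op would register several and let tuning pick.
template <typename T>
Status MyEltwiseNaive(const MyEltwiseParams<T>* params) {
  // ... launch a plain HIP kernel on params->stream here ...
  return Status::OK();
}

template <typename T>
class MyEltwiseTunableOp : public TunableOp<MyEltwiseParams<T>> {
 public:
  MyEltwiseTunableOp() {
    this->RegisterOp(MyEltwiseNaive<T>);
  }
};

// Launcher: tuning on/off is read from the context the caller passes in,
// replacing the old pair of a bool argument and an EnableTuning() call.
template <typename T>
Status LaunchMyEltwise(RocmTuningContext* tuning_ctx, hipStream_t stream,
                       const T* x, T* y, int n) {
  MyEltwiseParams<T> params(tuning_ctx, stream, x, y, n);
  if (tuning_ctx != nullptr && tuning_ctx->IsTunableOpEnabled()) {
    static MyEltwiseTunableOp<T> op;
    return op(&params);
  }
  return MyEltwiseNaive<T>(&params);
}

A kernel's ComputeInternal would then call LaunchMyEltwise(GetTuningContext(), Stream(context), ...), mirroring what fast_gelu.cc does above with LaunchFastGeluKernel.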

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc
@@ -58,7 +58,7 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
using onnxruntime::rocm::tunable::blas::BlasOp;

return blas::row_major::GemmFastGelu(
IsTunableOpEnabled(),
GetTuningContext(),
Stream(ctx), GetRocblasHandle(ctx),
transa ? BlasOp::Trans : BlasOp::NonTrans,
transb ? BlasOp::Trans : BlasOp::NonTrans,

1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh
@@ -19,7 +19,6 @@
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"

using onnxruntime::rocm::ToHipType;
using onnxruntime::rocm::tunable::Op;

namespace onnxruntime {
namespace contrib {

3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h
@@ -19,7 +19,7 @@ namespace rocm {
namespace blas {

template <typename T>
struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
struct GemmFastGeluParams : OpParams {
std::string Signature() const override {
bool has_bias = (nullptr != bias) ? 0 : 1;
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias);
@@ -39,7 +39,6 @@ struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
T beta;
T* c;
int64_t ldc;
bool tuning{false};
};

} // namespace blas

10 changes: 3 additions & 7 deletions onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu
@@ -20,6 +20,7 @@ namespace row_major {
template <typename T, typename ScalarT>
inline GEMMFASTGELU(T, ScalarT) {
GemmFastGeluParams<T> params;
params.tuning_ctx = tuning_ctx;
params.stream = stream;
params.handle = handle;

@@ -46,23 +47,18 @@ inline GEMMFASTGELU(T, ScalarT) {
params.c = c;
params.ldc = ldc;

if (tunable) {
params.tuning = true;
if (tuning_ctx->IsTunableOpEnabled()) {
if (opa == BlasOp::N && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Row> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::T && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Row> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::N && opb == BlasOp::T) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Col> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Col> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
}
}
@@ -71,7 +67,7 @@ inline GEMMFASTGELU(T, ScalarT) {
}

#define CALL_GEMMFASTGELU(T, ScalarT) \
GemmFastGelu<T, ScalarT>(tunable, stream, handle, \
GemmFastGelu<T, ScalarT>(tuning_ctx, stream, handle, \
opa, opb, \
m, n, k, \
alpha, a, lda, b, ldb, bias, \