-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ROCm] add Softmax Tunable Op (#14541)
### Description Add Softmax Tunable Op, only include blockwise vec implementation and composable kernel. Related PR: #14475, #14612 --------- Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
- Loading branch information
1 parent
12d9117
commit 326cf2f
Showing
10 changed files
with
625 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#ifdef USE_COMPOSABLE_KERNEL | ||
#include "ck/ck.hpp" | ||
#include "ck/library/tensor_operation_instance/gpu/softmax.hpp" | ||
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" | ||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" | ||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" | ||
#endif // USE_COMPOSABLE_KERNEL | ||
|
||
#include "core/providers/rocm/math/softmax_common.h" | ||
|
||
namespace onnxruntime { | ||
namespace rocm { | ||
|
||
#ifdef USE_COMPOSABLE_KERNEL | ||
|
||
// Maps an onnxruntime scalar type to the corresponding composable kernel (CK)
// scalar type. The primary template is an identity mapping; the explicit
// specializations below translate the half/bfloat16 wrapper types.
template <typename T>
struct DataTypeAdaptor {
  using type = T;
};

template <>
struct DataTypeAdaptor<half> {
  using type = ck::half_t;
};

template <>
struct DataTypeAdaptor<BFloat16> {
  // NOTE(review): confirm this alias name against the CK headers in use;
  // CK's data_type.hpp commonly spells the bfloat16 alias `ck::bhalf_t`.
  using type = ck::bhalf16_t;
};

// CK element-wise "pass-through" (no-op) operator, used for both the input
// and output element-wise operations of DeviceSoftmax.
using Nop = ck::tensor_operation::element_wise::PassThrough;
// The CK softmax instances are instantiated for rank-4 tensors reducing over
// exactly one dimension (the innermost one; see reduce_dims below).
constexpr int Rank = 4;
constexpr int NumReduceDim = 1;
|
||
template <typename InputT, typename OutputT, typename AccT> | ||
auto GetCKSoftmaxTypeStringAndOps() { | ||
using InDataType = typename DataTypeAdaptor<InputT>::type; | ||
using OutDataType = typename DataTypeAdaptor<OutputT>::type; | ||
using AccDataType = typename DataTypeAdaptor<AccT>::type; | ||
using DeviceSoftmax = ck::tensor_operation::device:: | ||
DeviceSoftmax<InDataType, AccDataType, OutDataType, Nop, Nop, Rank>; | ||
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceSoftmax>; | ||
|
||
std::vector<std::pair<std::string, tunable::Op<SoftmaxParams<InputT, OutputT>>>> ret; | ||
for (auto&& impl : InstanceFactory::GetInstances()) { | ||
auto type_string = onnxruntime::MakeString(impl->GetTypeString()); | ||
auto invoker = impl->MakeInvokerPointer(); | ||
|
||
auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams<InputT, OutputT>* params) -> Status { | ||
AccDataType alpha{1.0f}; | ||
AccDataType beta{0.0f}; | ||
|
||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( | ||
params->is_log_softmax, | ||
impl->GetTypeString(), " does not support log softmax"); | ||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( | ||
impl->GetRank() != Rank || impl->GetNumReduceDim() != NumReduceDim, | ||
impl->GetTypeString(), " does not support current Rank or NumReduceDim ", params->Signature()); | ||
|
||
std::vector<ck::index_t> in_lengths{1, 1, params->batch_count, params->softmax_elements}; | ||
std::vector<ck::index_t> in_strides{params->batch_count * params->input_stride, params->batch_count * params->input_stride, params->input_stride, 1}; | ||
std::vector<ck::index_t> reduce_dims{3}; | ||
|
||
auto nop = Nop{}; | ||
auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, &alpha, &beta, | ||
params->input, params->output, nop, nop); | ||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), | ||
impl->GetTypeString(), " does not support ", params->Signature()); | ||
invoker->Run(arg.get(), StreamConfig{params->stream}); | ||
return Status::OK(); | ||
}; | ||
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_softmax_op))); | ||
} | ||
return ret; | ||
} | ||
#endif // USE_COMPOSABLE_KERNEL | ||
|
||
} // namespace rocm | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/common/status.h" | ||
#include "core/providers/rocm/miopen_common.h" | ||
#include "core/providers/rocm/tunable/rocm_tunable.h" | ||
|
||
namespace onnxruntime { | ||
namespace rocm { | ||
|
||
template <typename InputT, typename OutputT> | ||
struct SoftmaxParams : tunable::OpParams { | ||
SoftmaxParams(tunable::RocmTuningContext* tuning_ctx, hipStream_t stream, OutputT* output, const InputT* input, | ||
int softmax_elements, int input_stride, int output_stride, int batch_count, bool is_log_softmax) | ||
: OpParams(tuning_ctx, stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {} | ||
|
||
std::string Signature() const override { | ||
std::string sig = std::to_string(batch_count) + "_" + std::to_string(softmax_elements); | ||
return sig; | ||
} | ||
|
||
OutputT* output; | ||
const InputT* input; | ||
int softmax_elements; | ||
int input_stride; | ||
int output_stride; | ||
int batch_count; | ||
bool is_log_softmax; | ||
}; | ||
|
||
// MIOpen-backed softmax forward over the described tensors.
// NOTE(review): alpha/beta presumably follow MIOpen's usual output-scaling
// convention (y = alpha * op(x) + beta * y) — confirm in the implementation.
Status SoftmaxForward(miopenHandle_t miopen_handle, const void* alpha, const miopenTensorDescriptor_t input_tensor,
                      const void* input_data, const void* beta, const miopenTensorDescriptor_t output_tensor,
                      void* output_data);

// MIOpen-backed softmax backward; is_log_softmax selects the log-softmax
// gradient variant. Writes the input gradient to input_grad_data.
Status SoftmaxBackward(miopenHandle_t miopen_handle, bool is_log_softmax, const void* alpha,
                       const miopenTensorDescriptor_t input_tensor, const void* output_data,
                       const void* output_grad_data, const void* beta, const miopenTensorDescriptor_t output_tensor,
                       void* input_grad_data);
|
||
} // namespace rocm | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <hip/hip_runtime.h> | ||
|
||
#include "core/providers/rocm/cu_inc/common.cuh" | ||
#include "core/providers/rocm/math/softmax_ck.cuh" | ||
#include "core/providers/rocm/math/softmax_common.h" | ||
#include "core/providers/rocm/math/softmax_warpwise_impl.cuh" | ||
#include "core/providers/rocm/math/softmax_blockwise_impl.cuh" | ||
#include "core/providers/rocm/tunable/rocm_tunable.h" | ||
|
||
namespace onnxruntime { | ||
namespace rocm { | ||
|
||
// Launches the blockwise softmax kernel with a fixed vector width (VecSize):
// one thread block per batch row, block size chosen by SoftMax_getBlockSize,
// and block.x * sizeof(AccT) bytes of dynamic shared memory (one AccT per
// thread). Selects the log-softmax or plain-softmax epilogue at compile time
// via the template parameter and returns the HIP status of the launch.
template <typename InputT, typename OutputT, typename AccT, int VecSize>
Status SoftmaxBlockwiseOp(const SoftmaxParams<InputT, OutputT>* params) {
  dim3 grid(params->batch_count);
  dim3 block = SoftMax_getBlockSize(VecSize, params->softmax_elements);
  if (params->is_log_softmax) {
    softmax_block_forward<VecSize, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
        <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
                                                                  params->softmax_elements, params->input_stride,
                                                                  params->output_stride);
  } else {
    softmax_block_forward<VecSize, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
        <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
                                                                  params->softmax_elements, params->input_stride,
                                                                  params->output_stride);
  }
  return HIP_CALL(hipGetLastError());
}
|
||
template <typename InputT, typename OutputT, typename AccT> | ||
Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) { | ||
dim3 grid(params->batch_count); | ||
constexpr int ILP = sizeof(float4) / sizeof(InputT); | ||
dim3 block = SoftMax_getBlockSize(ILP, params->softmax_elements); | ||
if (params->is_log_softmax) { | ||
softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue> | ||
<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input), | ||
params->softmax_elements, params->input_stride, | ||
params->output_stride); | ||
} else { | ||
softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue> | ||
<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input), | ||
params->softmax_elements, params->input_stride, | ||
params->output_stride); | ||
} | ||
return HIP_CALL(hipGetLastError()); | ||
} | ||
|
||
// TunableOp that benchmarks and selects the fastest softmax implementation
// among:
//  - SoftmaxBlockwiseStaticSelection (float4-sized vector width),
//  - blockwise kernels with explicit vector widths 1/2/4/8/16,
//  - every matching composable kernel DeviceSoftmax instance (only when
//    built with USE_COMPOSABLE_KERNEL).
template <typename InputT, typename OutputT, typename AccT>
class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<InputT, OutputT>> {
 public:
  SoftmaxTunableOp() {
    this->RegisterOp(SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>);
    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 1>);
    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 2>);
    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 4>);
    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 8>);
    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 16>);

#ifdef USE_COMPOSABLE_KERNEL
    // The CK type string is only used for identification/debugging; it is
    // intentionally discarded here.
    for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps<InputT, OutputT, AccT>()) {
      ORT_UNUSED_PARAMETER(_);
      this->RegisterOp(std::move(op));
    }
#endif  // USE_COMPOSABLE_KERNEL
  }
};
|
||
} // namespace rocm | ||
} // namespace onnxruntime |
Oops, something went wrong.