[ROCm] add Softmax Tunable Op (#14541)
### Description
Add a Softmax Tunable Op. This change only includes the blockwise vectorized implementation and the composable kernel implementation.
Related PRs: #14475, #14612
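
For readers new to the tunable-op mechanism: a tunable op keeps a list of candidate kernels for the same computation and, when tuning is enabled, times each candidate for the current problem shape and caches the fastest one under the problem's signature. The sketch below is a simplified illustration of that pattern, not the actual onnxruntime::rocm::tunable::TunableOp implementation; all names in it are made up for the example, and the host-side timing stands in for proper stream-synchronized GPU timing.

// Simplified illustration of the tunable-op pattern (illustrative only).
#include <chrono>
#include <functional>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

template <typename ParamsT>
class SimpleTunableOp {
 public:
  // A candidate returns false when it does not support the given problem shape.
  using Candidate = std::function<bool(const ParamsT*)>;

  void RegisterOp(Candidate op) { candidates_.push_back(std::move(op)); }

  bool operator()(const ParamsT* params) {
    auto it = best_.find(params->Signature());
    if (it == best_.end()) {
      it = best_.emplace(params->Signature(), FindFastest(params)).first;
    }
    return candidates_[it->second](params);
  }

 private:
  size_t FindFastest(const ParamsT* params) {
    size_t best_index = 0;  // fall back to the first candidate if none is timed
    double best_time = std::numeric_limits<double>::max();
    for (size_t i = 0; i < candidates_.size(); ++i) {
      auto start = std::chrono::steady_clock::now();
      if (!candidates_[i](params)) continue;  // skip unsupported candidates
      double elapsed = std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count();
      if (elapsed < best_time) { best_time = elapsed; best_index = i; }
    }
    return best_index;
  }

  std::vector<Candidate> candidates_;
  std::unordered_map<std::string, size_t> best_;  // Signature() -> fastest candidate index
};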

---------

Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
PeixuanZuo and peixuanzuo authored Feb 13, 2023
1 parent 12d9117 commit 326cf2f
Showing 10 changed files with 625 additions and 73 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers.cmake
@@ -1460,6 +1460,7 @@ if (onnxruntime_USE_ROCM)
device_gemm_add_fastgelu_instance
device_gemm_fastgelu_instance
device_batched_gemm_instance
device_softmax_instance
)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL)
endif()
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -122,6 +122,7 @@ set(provider_excluded_files
"math/softmax_impl.cu"
"math/softmax_warpwise_impl.cuh"
"math/softmax_common.cc"
"math/softmax_common.h"
"math/softmax.cc"
"nn/conv.cc"
"nn/conv.h"
88 changes: 88 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax_ck.cuh
@@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <utility>
#include <vector>

#ifdef USE_COMPOSABLE_KERNEL
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#endif // USE_COMPOSABLE_KERNEL

#include "core/providers/rocm/math/softmax_common.h"

namespace onnxruntime {
namespace rocm {

#ifdef USE_COMPOSABLE_KERNEL

template <typename T>
struct DataTypeAdaptor {
using type = T;
};

template <>
struct DataTypeAdaptor<half> {
using type = ck::half_t;
};

template <>
struct DataTypeAdaptor<BFloat16> {
using type = ck::bhalf_t;
};

using Nop = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumReduceDim = 1;

template <typename InputT, typename OutputT, typename AccT>
auto GetCKSoftmaxTypeStringAndOps() {
using InDataType = typename DataTypeAdaptor<InputT>::type;
using OutDataType = typename DataTypeAdaptor<OutputT>::type;
using AccDataType = typename DataTypeAdaptor<AccT>::type;
using DeviceSoftmax = ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, Nop, Nop, Rank>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceSoftmax>;

std::vector<std::pair<std::string, tunable::Op<SoftmaxParams<InputT, OutputT>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = onnxruntime::MakeString(impl->GetTypeString());
auto invoker = impl->MakeInvokerPointer();

auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams<InputT, OutputT>* params) -> Status {
AccDataType alpha{1.0f};
AccDataType beta{0.0f};

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->is_log_softmax,
impl->GetTypeString(), " does not support log softmax");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
impl->GetRank() != Rank || impl->GetNumReduceDim() != NumReduceDim,
impl->GetTypeString(), " does not support current Rank or NumReduceDim ", params->Signature());

std::vector<ck::index_t> in_lengths{1, 1, params->batch_count, params->softmax_elements};
std::vector<ck::index_t> in_strides{params->batch_count * params->input_stride, params->batch_count * params->input_stride, params->input_stride, 1};
std::vector<ck::index_t> reduce_dims{3};

auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, &alpha, &beta,
params->input, params->output, nop, nop);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->stream});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_softmax_op)));
}
return ret;
}
#endif // USE_COMPOSABLE_KERNEL

} // namespace rocm
} // namespace onnxruntime
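
As a concrete illustration of the rank-4 mapping used in the lambda above (a hedged fragment, not part of the diff): a batched softmax with batch_count = 8, softmax_elements = 128 and input_stride = 128 is presented to the CK DeviceSoftmax as a rank-4 tensor that reduces over its innermost dimension.

// Hedged worked example of the rank-4 problem description; numbers are illustrative only.
int batch_count = 8, softmax_elements = 128, input_stride = 128;
std::vector<ck::index_t> in_lengths{1, 1, batch_count, softmax_elements};  // {1, 1, 8, 128}
std::vector<ck::index_t> in_strides{batch_count * input_stride,            // 1024
                                    batch_count * input_stride,            // 1024
                                    input_stride,                          // 128
                                    1};
std::vector<ck::index_t> reduce_dims{3};  // softmax reduces over the innermost dimension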
43 changes: 43 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax_common.h
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/status.h"
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace rocm {

template <typename InputT, typename OutputT>
struct SoftmaxParams : tunable::OpParams {
SoftmaxParams(tunable::RocmTuningContext* tuning_ctx, hipStream_t stream, OutputT* output, const InputT* input,
int softmax_elements, int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: OpParams(tuning_ctx, stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {}

std::string Signature() const override {
std::string sig = std::to_string(batch_count) + "_" + std::to_string(softmax_elements);
return sig;
}

OutputT* output;
const InputT* input;
int softmax_elements;
int input_stride;
int output_stride;
int batch_count;
bool is_log_softmax;
};

Status SoftmaxForward(miopenHandle_t miopen_handle, const void* alpha, const miopenTensorDescriptor_t input_tensor,
const void* input_data, const void* beta, const miopenTensorDescriptor_t output_tensor,
void* output_data);

Status SoftmaxBackward(miopenHandle_t miopen_handle, bool is_log_softmax, const void* alpha,
const miopenTensorDescriptor_t input_tensor, const void* output_data,
const void* output_grad_data, const void* beta, const miopenTensorDescriptor_t output_tensor,
void* input_grad_data);

} // namespace rocm
} // namespace onnxruntime
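
For illustration (not part of this commit): constructing the params for a non-log softmax over 32 rows of 768 elements might look like the fragment below; out and in are assumed device pointers, and tuning_ctx and stream are assumed to come from the execution provider.

// Hedged example; out/in, tuning_ctx and stream are assumptions for the sketch.
SoftmaxParams<half, half> params(tuning_ctx, stream, out, in,
                                 /*softmax_elements=*/768, /*input_stride=*/768,
                                 /*output_stride=*/768, /*batch_count=*/32,
                                 /*is_log_softmax=*/false);
// params.Signature() == "32_768", the key under which tuning results are cached.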
132 changes: 60 additions & 72 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -1,18 +1,18 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/* Modifications Copyright (c) Microsoft. */

@@ -29,8 +29,8 @@
namespace onnxruntime {
namespace rocm {

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return Status::OK();
} else {
@@ -51,79 +51,67 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
softmax_warp_forward<input_t, output_t, acc_t, 0, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
softmax_warp_forward<input_t, output_t, acc_t, 1, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
softmax_warp_forward<input_t, output_t, acc_t, 2, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
softmax_warp_forward<input_t, output_t, acc_t, 3, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
softmax_warp_forward<input_t, output_t, acc_t, 4, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
softmax_warp_forward<input_t, output_t, acc_t, 5, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
softmax_warp_forward<input_t, output_t, acc_t, 6, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
softmax_warp_forward<input_t, output_t, acc_t, 7, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
softmax_warp_forward<input_t, output_t, acc_t, 8, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
softmax_warp_forward<input_t, output_t, acc_t, 9, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
softmax_warp_forward<input_t, output_t, acc_t, 10, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \
case L2E: \
softmax_warp_forward<InputT, OutputT, AccT, L2E, IsLogSoftmax> \
<<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, \
softmax_elements_stride, softmax_elements); \
break;
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
default:
break;
}
}
return HIP_CALL(hipGetLastError());
}

#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);
#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \
template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, false>( \
hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count); \
template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, true>( \
hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count);

SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
dim3 grid(batch_count);
constexpr int ILP = sizeof(float4) / sizeof(input_t);
constexpr int ILP = sizeof(float4) / sizeof(InputT);
dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
if (is_log_softmax) {
softmax_block_forward<ILP, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
softmax_elements, input_stride, output_stride);
if (IsLogSoftmax) {
softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
softmax_elements, input_stride, output_stride);
} else {
softmax_block_forward<ILP, input_t, acc_t, output_t, SoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
softmax_elements, input_stride, output_stride);
softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
softmax_elements, input_stride, output_stride);
}
return HIP_CALL(hipGetLastError());
return HIP_CALL(hipGetLastError());
}

#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \
template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, false>( \
hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, true>( \
hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count);

SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
@@ -135,5 +123,5 @@ SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, float, float) // used by BeamSearch op
#endif

}
}
} // namespace rocm
} // namespace onnxruntime
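
Side note on the ILP constant used by the blockwise dispatch above (an illustrative fragment, not from the diff): ILP is the number of elements each thread handles per vectorized access, derived from a 16-byte float4.

// ILP = sizeof(float4) / sizeof(InputT): elements per 16-byte vectorized access.
static_assert(sizeof(float4) == 16, "float4 is 16 bytes");
static_assert(sizeof(float4) / sizeof(float) == 4, "float  -> ILP 4");
static_assert(sizeof(float4) / sizeof(double) == 2, "double -> ILP 2");
// For 2-byte types such as half/BFloat16 the same formula yields ILP 8.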
76 changes: 76 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <hip/hip_runtime.h>

#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/math/softmax_ck.cuh"
#include "core/providers/rocm/math/softmax_common.h"
#include "core/providers/rocm/math/softmax_warpwise_impl.cuh"
#include "core/providers/rocm/math/softmax_blockwise_impl.cuh"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace rocm {

template <typename InputT, typename OutputT, typename AccT, int VecSize>
Status SoftmaxBlockwiseOp(const SoftmaxParams<InputT, OutputT>* params) {
dim3 grid(params->batch_count);
dim3 block = SoftMax_getBlockSize(VecSize, params->softmax_elements);
if (params->is_log_softmax) {
softmax_block_forward<VecSize, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
params->softmax_elements, params->input_stride,
params->output_stride);
} else {
softmax_block_forward<VecSize, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
params->softmax_elements, params->input_stride,
params->output_stride);
}
return HIP_CALL(hipGetLastError());
}

template <typename InputT, typename OutputT, typename AccT>
Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
dim3 grid(params->batch_count);
constexpr int ILP = sizeof(float4) / sizeof(InputT);
dim3 block = SoftMax_getBlockSize(ILP, params->softmax_elements);
if (params->is_log_softmax) {
softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
params->softmax_elements, params->input_stride,
params->output_stride);
} else {
softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
params->softmax_elements, params->input_stride,
params->output_stride);
}
return HIP_CALL(hipGetLastError());
}

template <typename InputT, typename OutputT, typename AccT>
class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<InputT, OutputT>> {
public:
SoftmaxTunableOp() {
this->RegisterOp(SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 1>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 2>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 4>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 8>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 16>);

#ifdef USE_COMPOSABLE_KERNEL
for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps<InputT, OutputT, AccT>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif // USE_COMPOSABLE_KERNEL
}
};

} // namespace rocm
} // namespace onnxruntime
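
To show how this class might be consumed (a hedged sketch, not part of this commit): a launcher would build a SoftmaxParams, instantiate the tunable op once so its tuning cache persists across calls, and invoke it. The operator() call assumes the usual TunableOp interface; the function name is made up for the example.

// Hedged usage sketch (not in this PR). Assumes TunableOp exposes
// operator()(const ParamsT*), as other tunable ops are invoked.
template <typename InputT, typename OutputT, typename AccT>
Status SoftmaxTunable(tunable::RocmTuningContext* tuning_ctx, hipStream_t stream,
                      OutputT* output, const InputT* input, int softmax_elements,
                      int input_stride, int output_stride, int batch_count,
                      bool is_log_softmax) {
  SoftmaxParams<InputT, OutputT> params(tuning_ctx, stream, output, input, softmax_elements,
                                        input_stride, output_stride, batch_count, is_log_softmax);
  static SoftmaxTunableOp<InputT, OutputT, AccT> op;  // static: keep the tuning cache across calls
  return op(&params);
}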