From 326cf2f5e96ce20dda90b3e3de70b280635644ed Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Mon, 13 Feb 2023 15:56:50 +0800 Subject: [PATCH] [ROCm] add Softmax Tunable Op (#14541) ### Description Add a Softmax TunableOp; for now it includes only the blockwise vectorized implementation and the composable kernel (CK) implementation. Related PRs: https://github.com/microsoft/onnxruntime/pull/14475, https://github.com/microsoft/onnxruntime/pull/14612 --------- Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net> --- cmake/onnxruntime_providers.cmake | 1 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../core/providers/rocm/math/softmax_ck.cuh | 88 ++++++++ .../core/providers/rocm/math/softmax_common.h | 43 ++++ .../core/providers/rocm/math/softmax_impl.cu | 132 ++++++------ .../rocm/math/softmax_tunable_op.cuh | 76 +++++++ .../tools/kernel_explorer/kernel_explorer.cc | 4 +- .../kernel_explorer/kernels/rocm/softmax.cu | 191 ++++++++++++++++++ .../kernel_explorer/kernels/rocm/softmax.h | 14 ++ .../kernel_explorer/kernels/softmax_test.py | 148 ++++++++++++++ 10 files changed, 625 insertions(+), 73 deletions(-) create mode 100644 onnxruntime/core/providers/rocm/math/softmax_ck.cuh create mode 100644 onnxruntime/core/providers/rocm/math/softmax_common.h create mode 100644 onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 84e429db302c1..0b9faf8849e06 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1460,6 +1460,7 @@ if (onnxruntime_USE_ROCM) device_gemm_add_fastgelu_instance device_gemm_fastgelu_instance device_batched_gemm_instance + device_softmax_instance ) target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL) endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index ec3726048f09e..f13d95474cd94 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -122,6 +122,7 @@ set(provider_excluded_files "math/softmax_impl.cu" "math/softmax_warpwise_impl.cuh" "math/softmax_common.cc" + "math/softmax_common.h" "math/softmax.cc" "nn/conv.cc" "nn/conv.h" diff --git a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh new file mode 100644 index 0000000000000..060415990a7cd --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include <string> +#include <utility> +#include <vector> + +#ifdef USE_COMPOSABLE_KERNEL +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif // USE_COMPOSABLE_KERNEL + +#include "core/providers/rocm/math/softmax_common.h" + +namespace onnxruntime { +namespace rocm { + +#ifdef USE_COMPOSABLE_KERNEL + +template <typename T> +struct DataTypeAdaptor { + using type = T; +}; + +template <> +struct DataTypeAdaptor<half> { + using type = ck::half_t; +}; + +template <> +struct DataTypeAdaptor<BFloat16> { + using type = ck::bhalf16_t; +}; + +using Nop = ck::tensor_operation::element_wise::PassThrough; +constexpr int Rank = 4; +constexpr int NumReduceDim = 1; + +template <typename InputT, typename OutputT, typename AccT> +auto GetCKSoftmaxTypeStringAndOps() { + using InDataType = typename DataTypeAdaptor<InputT>::type; + using OutDataType = typename DataTypeAdaptor<OutputT>::type; + using AccDataType = typename DataTypeAdaptor<AccT>::type; + using DeviceSoftmax = ck::tensor_operation::device:: + DeviceSoftmax<InDataType, AccDataType, OutDataType, Nop, Nop, Rank>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceSoftmax>; + + std::vector<std::pair<std::string, tunable::Op<SoftmaxParams<InputT, OutputT>>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString(impl->GetTypeString()); + auto invoker = impl->MakeInvokerPointer(); + + auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams<InputT, OutputT>* params) -> Status { + AccDataType alpha{1.0f}; + AccDataType beta{0.0f}; + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->is_log_softmax, + impl->GetTypeString(), " does not support log softmax"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + impl->GetRank() != Rank || impl->GetNumReduceDim() != NumReduceDim, + impl->GetTypeString(), " does not support current Rank or NumReduceDim ", params->Signature()); + + std::vector<ck::index_t> in_lengths{1, 1, params->batch_count, params->softmax_elements}; + std::vector<ck::index_t> in_strides{params->batch_count * params->input_stride, params->batch_count * params->input_stride, params->input_stride, 1}; + std::vector<ck::index_t> reduce_dims{3}; + + auto nop = Nop{}; + auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, &alpha, &beta, + params->input, params->output, nop, nop); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_softmax_op))); + } + return ret; +} +#endif // USE_COMPOSABLE_KERNEL + +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_common.h b/onnxruntime/core/providers/rocm/math/softmax_common.h new file mode 100644 index 0000000000000..976e7ba69ce4e --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/softmax_common.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/status.h" +#include "core/providers/rocm/miopen_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace rocm { + +template <typename InputT, typename OutputT> +struct SoftmaxParams : tunable::OpParams { + SoftmaxParams(tunable::RocmTuningContext* tuning_ctx, hipStream_t stream, OutputT* output, const InputT* input, + int softmax_elements, int input_stride, int output_stride, int batch_count, bool is_log_softmax) + : OpParams(tuning_ctx, stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {} + + std::string Signature() const override { + std::string sig = std::to_string(batch_count) + "_" + std::to_string(softmax_elements); + return sig; + } + + OutputT* output; + const InputT* input; + int softmax_elements; + int input_stride; + int output_stride; + int batch_count; + bool is_log_softmax; +}; + +Status SoftmaxForward(miopenHandle_t miopen_handle, const void* alpha, const miopenTensorDescriptor_t input_tensor, + const void* input_data, const void* beta, const miopenTensorDescriptor_t output_tensor, + void* output_data); + +Status SoftmaxBackward(miopenHandle_t miopen_handle, bool is_log_softmax, const void* alpha, + const miopenTensorDescriptor_t input_tensor, const void* output_data, + const void* output_grad_data, const void* beta, const miopenTensorDescriptor_t output_tensor, + void* input_grad_data); + +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu index 1948878e7bb3f..ad36240926f54 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu +++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu @@ -1,18 +1,18 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /* Modifications Copyright (c) Microsoft. 
*/ @@ -29,8 +29,8 @@ namespace onnxruntime { namespace rocm { -template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax> -Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) { +template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax> +Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, int softmax_elements_stride, int batch_count) { if (softmax_elements == 0) { return Status::OK(); } else { @@ -51,39 +51,23 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_forward<input_t, output_t, acc_t, 0, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_forward<input_t, output_t, acc_t, 1, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_forward<input_t, output_t, acc_t, 2, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_forward<input_t, output_t, acc_t, 3, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_forward<input_t, output_t, acc_t, 4, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_forward<input_t, output_t, acc_t, 5, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_forward<input_t, output_t, acc_t, 6, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_forward<input_t, output_t, acc_t, 7, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_forward<input_t, output_t, acc_t, 8, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_forward<input_t, output_t, acc_t, 9, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_forward<input_t, output_t, acc_t, 10, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \ + case L2E: \ + softmax_warp_forward<InputT, OutputT, AccT, L2E, IsLogSoftmax> \ + <<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, \ + softmax_elements_stride, softmax_elements); \ + break; + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 
+ LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024 default: break; } @@ -91,39 +75,43 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons return HIP_CALL(hipGetLastError()); } -#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \ -template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \ -template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); +#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \ + template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, false>( \ + hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \ + int softmax_elements_stride, int batch_count); \ + template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, true>( \ + hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \ + int softmax_elements_stride, int batch_count); SPECIALIZED_SOFTMAX_IMPL(float, float, float) SPECIALIZED_SOFTMAX_IMPL(half, half, float) SPECIALIZED_SOFTMAX_IMPL(double, double, double) SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float) -template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax> -Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements, - int input_stride, int output_stride, int batch_count) { +template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax> +Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements, + int input_stride, int output_stride, int batch_count) { dim3 grid(batch_count); - constexpr int ILP = sizeof(float4) / sizeof(input_t); + constexpr int ILP = sizeof(float4) / sizeof(InputT); dim3 block = SoftMax_getBlockSize(ILP, softmax_elements); - if (is_log_softmax) { - softmax_block_forward<ILP, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue> - <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input), - softmax_elements, input_stride, output_stride); + if (IsLogSoftmax) { + softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue> + <<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input), + softmax_elements, input_stride, output_stride); } else { - softmax_block_forward<ILP, input_t, acc_t, output_t, SoftMaxForwardEpilogue> - <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input), - softmax_elements, input_stride, output_stride); + softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue> + <<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input), + softmax_elements, input_stride, output_stride); } - return HIP_CALL(hipGetLastError()); + return HIP_CALL(hipGetLastError()); } -#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ - template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \ - hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \ - int 
input_stride, int output_stride, int batch_count); \ - template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \ - hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \ +#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \ + template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, false>( \ + hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \ + int input_stride, int output_stride, int batch_count); \ + template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, true>( \ + hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \ int input_stride, int output_stride, int batch_count); SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float) @@ -135,5 +123,5 @@ SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float) SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, float, float) // used by BeamSearch op #endif -} -} +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh new file mode 100644 index 0000000000000..7347cd2c035b9 --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include <hip/hip_runtime.h> + +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/math/softmax_ck.cuh" +#include "core/providers/rocm/math/softmax_common.h" +#include "core/providers/rocm/math/softmax_warpwise_impl.cuh" +#include "core/providers/rocm/math/softmax_blockwise_impl.cuh" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace rocm { + +template <typename InputT, typename OutputT, typename AccT, int VecSize> +Status SoftmaxBlockwiseOp(const SoftmaxParams<InputT, OutputT>* params) { + dim3 grid(params->batch_count); + dim3 block = SoftMax_getBlockSize(VecSize, params->softmax_elements); + if (params->is_log_softmax) { + softmax_block_forward<VecSize, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue> + <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input), + params->softmax_elements, params->input_stride, + params->output_stride); + } else { + softmax_block_forward<VecSize, InputT, AccT, OutputT, SoftMaxForwardEpilogue> + <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input), + params->softmax_elements, params->input_stride, + params->output_stride); + } + return HIP_CALL(hipGetLastError()); +} + +template <typename InputT, typename OutputT, typename AccT> +Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) { + dim3 grid(params->batch_count); + constexpr int ILP = sizeof(float4) / sizeof(InputT); + dim3 block = SoftMax_getBlockSize(ILP, params->softmax_elements); + if (params->is_log_softmax) { + softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue> + <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input), + params->softmax_elements, params->input_stride, + params->output_stride); + } else { + softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue> + <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input), + 
params->softmax_elements, params->input_stride, + params->output_stride); + } + return HIP_CALL(hipGetLastError()); +} + +template <typename InputT, typename OutputT, typename AccT> +class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<InputT, OutputT>> { + public: + SoftmaxTunableOp() { + this->RegisterOp(SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>); + this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 1>); + this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 2>); + this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 4>); + this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 8>); + this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 16>); + +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps<InputT, OutputT, AccT>()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif // USE_COMPOSABLE_KERNEL + } +}; + +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc index 733c3ee5c523b..93508fcb629e4 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc @@ -7,8 +7,9 @@ #include "python/tools/kernel_explorer/kernels/vector_add.h" #include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h" #include "python/tools/kernel_explorer/kernels/rocm/gemm.h" -#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h" #include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" +#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h" +#include "python/tools/kernel_explorer/kernels/rocm/softmax.h" namespace py = pybind11; @@ -24,6 +25,7 @@ PYBIND11_MODULE(_kernel_explorer, m) { InitGemm(m); InitSkipLayerNorm(m); InitGemmFastGelu(m); + InitSoftmax(m); #endif m.def("is_composable_kernel_available", []() { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu new file mode 100644 index 0000000000000..8128f73243804 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "python/tools/kernel_explorer/kernels/rocm/softmax.h" + +#include <hip/hip_fp16.h> +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <utility> +#include <vector> + +#include "core/providers/rocm/math/softmax_ck.cuh" +#include "core/providers/rocm/math/softmax_tunable_op.cuh" +#include "core/providers/rocm/shared_inc/accumulation_type.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +namespace py = pybind11; + +namespace onnxruntime { + +template <typename T, int VecSize> +class SoftmaxBlockwise : public IKernelExplorer { + public: + SoftmaxBlockwise(DeviceArray& output, DeviceArray& input, int softmax_elements, + int input_stride, int output_stride, int batch_count, bool is_log_softmax) + : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()), + softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { + type_string_ = "SoftmaxBlockwise_" + std::to_string(VecSize); + } + + void Run() override { + ORT_THROW_IF_ERROR((rocm::SoftmaxBlockwiseOp<T, T, rocm::AccumulationType_t<T>, VecSize>(¶ms_))); + } + + std::vector<std::string> ListOps() const { + return {type_string_}; + } + + bool SelectOp(const std::string& name) { + Status status = rocm::SoftmaxBlockwiseOp<T, T, rocm::AccumulationType_t<T>, VecSize>(¶ms_); + return status.IsOK() && name == type_string_; + } + + private: + using ParamsT = rocm::SoftmaxParams<T, T>; + ParamsT params_{}; + std::string type_string_{}; +}; + +template <typename T> +class SoftmaxBlockwiseStaticSelection : public IKernelExplorer { + public: + SoftmaxBlockwiseStaticSelection(DeviceArray& output, DeviceArray& input, int softmax_elements, + int input_stride, int output_stride, int batch_count, bool is_log_softmax) + : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()), + softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {} + + void Run() override { + ORT_THROW_IF_ERROR((rocm::SoftmaxBlockwiseStaticSelection<T, T, rocm::AccumulationType_t<T>>(¶ms_))); + } + + std::vector<std::string> ListOps() const { + return {"SoftmaxBlockwiseStaticSelection"}; + } + + bool SelectOp(const std::string& name) { + return name == "SoftmaxBlockwiseStaticSelection"; + } + + private: + using ParamsT = rocm::SoftmaxParams<T, T>; + ParamsT params_{}; +}; + +template <typename T> +class SoftmaxTunable : public IKernelExplorer { + public: + SoftmaxTunable(DeviceArray& output, DeviceArray& input, int softmax_elements, + int input_stride, int output_stride, int batch_count, bool is_log_softmax) + : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()), + softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { + params_.TuningContext()->EnableTunableOp(); + } + + void Run() override { + ORT_THROW_IF_ERROR(op_(¶ms_)); + } + + std::vector<std::string> ListOps() const { + return {"SoftmaxTunable"}; + } + + bool SelectOp(const std::string& name) { + return name == "SoftmaxTunable"; + } + + private: + using ParamsT = rocm::SoftmaxParams<T, T>; + ParamsT params_{}; + rocm::SoftmaxTunableOp<T, T, rocm::AccumulationType_t<T>> op_{}; +}; + +#ifdef USE_COMPOSABLE_KERNEL +template <typename T> +class CKSoftmax : public IKernelExplorer { + public: + CKSoftmax(DeviceArray& output, DeviceArray& input, int softmax_elements, + int input_stride, int output_stride, int batch_count, bool 
is_log_softmax) + : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()), + softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { + for (auto&& [type_string, op] : rocm::GetCKSoftmaxTypeStringAndOps<T, T, rocm::AccumulationType_t<T>>()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + } + + void Run() override { + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); + } + + std::vector<std::string> ListOps() const { + return type_strings_; + } + + bool SelectOp(const std::string& name) { + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); + } + + private: + using ParamsT = rocm::SoftmaxParams<T, T>; + using OpT = rocm::tunable::Op<ParamsT>; + ParamsT params_{}; + std::vector<OpT> ops_; + std::vector<std::string> type_strings_; + size_t selected_op_{}; +}; +#endif // USE_COMPOSABLE_KERNEL + +#define REGISTER_OP(name, type, vec_size) \ + py::class_<name<type, vec_size>>(m, #name "_" #type "_" #vec_size) \ + .def(py::init<DeviceArray&, DeviceArray&, int, int, int, int, bool>()) \ + .def("SetRepeats", &name<type, vec_size>::SetRepeats) \ + .def("Profile", &name<type, vec_size>::Profile) \ + .def("Run", &name<type, vec_size>::Run) \ + .def("ListOps", &name<type, vec_size>::ListOps) \ + .def("SelectOp", &name<type, vec_size>::SelectOp); + +#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type) \ + REGISTER_OP(name, type, 1) \ + REGISTER_OP(name, type, 2) \ + REGISTER_OP(name, type, 4) \ + REGISTER_OP(name, type, 8) \ + REGISTER_OP(name, type, 16) + +#define REGISTER_OP_TYPED(name, type) \ + py::class_<name<type>>(m, #name "_" #type) \ + .def(py::init<DeviceArray&, DeviceArray&, int, int, int, int, bool>()) \ + .def("SetRepeats", &name<type>::SetRepeats) \ + .def("Profile", &name<type>::Profile) \ + .def("Run", &name<type>::Run) \ + .def("ListOps", &name<type>::ListOps) \ + .def("SelectOp", &name<type>::SelectOp); + +void InitSoftmax(py::module m) { + REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half); + REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float); + + REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, half); + REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, float); + + REGISTER_OP_TYPED(SoftmaxTunable, half); + REGISTER_OP_TYPED(SoftmaxTunable, float); + +#ifdef USE_COMPOSABLE_KERNEL + REGISTER_OP_TYPED(CKSoftmax, half); + REGISTER_OP_TYPED(CKSoftmax, float); +#endif // USE_COMPOSABLE_KERNEL +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h new file mode 100644 index 0000000000000..5ae71614e2c7f --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include <pybind11/pybind11.h> + +namespace py = pybind11; + +namespace onnxruntime { + +void InitSoftmax(py::module m); + +} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py new file mode 100644 index 0000000000000..cd3998a2826cf --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py @@ -0,0 +1,148 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import re +import sys +from dataclasses import dataclass +from itertools import product + +import kernel_explorer as ke +import numpy as np +import pytest +from utils import dtype_to_bytes, dtype_to_suffix + + +def get_test_sizes(): + batch_count = [1, 8] + softmax_elements = [1, 2, 3, 4, 5, 7, 8, 9, 11, 16, 31, 32, 33, 64, 65, 127, 128, 1024, 1025, 2048, 4096] + is_log_softmax = [True, False] + return product(batch_count, softmax_elements, is_log_softmax) + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: re.match("Softmax.*_half.*", x), dir(ke))), + "float32": list(filter(lambda x: re.match("Softmax.*_float.*", x), dir(ke))), + } + return type_map[dtype] + + +def softmax(x, is_log_softmax): + x = x - np.max(x, axis=-1, keepdims=1) + if is_log_softmax: + return x - np.log(np.sum(np.exp(x), axis=-1, keepdims=1)) + return (np.exp(x)) / np.sum(np.exp(x), axis=-1, keepdims=1) + + +def _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, func): + np.random.seed(0) + x = np.random.rand(batch_count, softmax_elements).astype(dtype) + y = np.random.rand(batch_count, softmax_elements).astype(dtype) + + x_d = ke.DeviceArray(x) + y_d = ke.DeviceArray(y) + y_ref = softmax(x, is_log_softmax) + + softmax_func = getattr(ke, func) + softmax_op = softmax_func( + y_d, x_d, softmax_elements, softmax_elements, softmax_elements, batch_count, is_log_softmax + ) + for impl in softmax_op.ListOps(): + if not softmax_op.SelectOp(impl): + continue + + softmax_op.Run() + y_d.UpdateHostNumpyArray() + + np.testing.assert_allclose(y_ref, y, rtol=1e-02) + + +dtypes = ["float16", "float32"] + + +@pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes()) +@pytest.mark.parametrize("dtype", dtypes) +def test_softmax(batch_count, softmax_elements, is_log_softmax, dtype): + for f in dtype_to_funcs(dtype): + _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, f) + + +@pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes()) +@pytest.mark.parametrize("dtype", dtypes) +def test_ck_softmax(batch_count, softmax_elements, is_log_softmax, dtype): + ck_f_name = "CKSoftmax" + "_" + dtype_to_suffix(dtype) + _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name) + + +@dataclass +class SoftmaxMetric(ke.BandwidthMetric): + batch_count: int + softmax_elements: int + is_log_softmax: bool + + def report(self): + prefix = f"{self.name:<110} {self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4}" + if self.duration > 0: + return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s" + return prefix + "not supported" + + +def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func): + np.random.seed(0) + x 
= np.random.rand(batch_count, softmax_elements).astype(dtype) + y = np.random.rand(batch_count, softmax_elements).astype(dtype) + + x_d = ke.DeviceArray(x) + y_d = ke.DeviceArray(y) + + softmax_func = getattr(ke, func) + softmax_op = softmax_func( + y_d, x_d, softmax_elements, softmax_elements, softmax_elements, batch_count, is_log_softmax + ) + + for impl in softmax_op.ListOps(): + duration_ms = -1 + if softmax_op.SelectOp(impl): + duration_ms = softmax_op.Profile() + total_bytes = 2 * batch_count * softmax_elements * dtype_to_bytes(dtype) + + ke.report(SoftmaxMetric(impl, dtype, duration_ms, total_bytes, batch_count, softmax_elements, is_log_softmax)) + + +def profile_with_args(batch_count, softmax_elements, is_log_softmax, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func) + # ck function + ck_f_name = "CKSoftmax" + "_" + dtype_to_suffix(dtype) + profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name) + + +profile_size = [(1, 2048), (8, 2048), (65536, 4096)] + + +def profile(): + for dtype in dtypes: + for batch_count, softmax_elements in profile_size: + profile_with_args(batch_count, softmax_elements, False, dtype, True) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("batch_count", type=int) + group.add_argument("softmax_elements", type=int) + group.add_argument("is_log_softmax", type=int) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.batch_count, args.softmax_elements, args.is_log_softmax, args.dtype, args.sort)
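
A minimal usage sketch (not part of the patch), assuming the _kernel_explorer extension has been built with ROCm support and is importable as kernel_explorer. The class name SoftmaxTunable_half follows the REGISTER_OP_TYPED(SoftmaxTunable, half) naming in softmax.cu above, and the constructor arguments mirror the binding: (output, input, softmax_elements, input_stride, output_stride, batch_count, is_log_softmax).

    # Usage sketch: drive the SoftmaxTunable binding directly (assumes a ROCm device
    # and that the kernel_explorer extension built by this PR is on the Python path).
    import numpy as np
    import kernel_explorer as ke

    batch_count, softmax_elements = 8, 2048
    x = np.random.rand(batch_count, softmax_elements).astype("float16")
    y = np.zeros_like(x)

    x_d, y_d = ke.DeviceArray(x), ke.DeviceArray(y)
    # Contiguous rows, so input_stride == output_stride == softmax_elements.
    op = ke.SoftmaxTunable_half(y_d, x_d, softmax_elements, softmax_elements,
                                softmax_elements, batch_count, False)
    op.Run()
    y_d.UpdateHostNumpyArray()  # copy the softmax result back into y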