From 326cf2f5e96ce20dda90b3e3de70b280635644ed Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Mon, 13 Feb 2023 15:56:50 +0800
Subject: [PATCH] [ROCm] add Softmax Tunable Op (#14541)

### Description
Add the Softmax Tunable Op. For now it includes only the blockwise
vectorized implementation and the composable kernel implementation.
Related PRs: https://github.com/microsoft/onnxruntime/pull/14475,
https://github.com/microsoft/onnxruntime/pull/14612
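
A minimal sketch of how a caller might drive the new op, assuming a valid
`RocmTuningContext` and HIP stream are already set up. The helper name
`RunTunableSoftmax` and the float-only instantiation are illustrative and not
part of this patch:

```cpp
#include <hip/hip_runtime.h>

#include "core/providers/rocm/math/softmax_tunable_op.cuh"

namespace ort_rocm = onnxruntime::rocm;

// Hypothetical helper (not part of this patch): softmax over contiguous rows
// of shape [batch_count, softmax_elements] via the new tunable op.
onnxruntime::Status RunTunableSoftmax(ort_rocm::tunable::RocmTuningContext* tuning_ctx,
                                      hipStream_t stream, float* output, const float* input,
                                      int softmax_elements, int batch_count) {
  // Rows are contiguous, so input and output strides equal the row length.
  ort_rocm::SoftmaxParams<float, float> params(
      tuning_ctx, stream, output, input, softmax_elements,
      /*input_stride=*/softmax_elements, /*output_stride=*/softmax_elements,
      batch_count, /*is_log_softmax=*/false);

  // Candidates are the blockwise kernels (vec sizes 1-16), the static
  // selection, and CK instances when USE_COMPOSABLE_KERNEL is defined.
  static ort_rocm::SoftmaxTunableOp<float, float, float> op;
  return op(&params);
}
```

As in the kernel explorer wrapper added below, tuning would typically be
enabled first via `params.TuningContext()->EnableTunableOp()`; otherwise the
tunable op presumably falls back to its default candidate.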

---------

Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
---
 cmake/onnxruntime_providers.cmake             |   1 +
 cmake/onnxruntime_rocm_hipify.cmake           |   1 +
 .../core/providers/rocm/math/softmax_ck.cuh   |  88 ++++++++
 .../core/providers/rocm/math/softmax_common.h |  43 ++++
 .../core/providers/rocm/math/softmax_impl.cu  | 132 ++++++------
 .../rocm/math/softmax_tunable_op.cuh          |  76 +++++++
 .../tools/kernel_explorer/kernel_explorer.cc  |   4 +-
 .../kernel_explorer/kernels/rocm/softmax.cu   | 191 ++++++++++++++++++
 .../kernel_explorer/kernels/rocm/softmax.h    |  14 ++
 .../kernel_explorer/kernels/softmax_test.py   | 148 ++++++++++++++
 10 files changed, 625 insertions(+), 73 deletions(-)
 create mode 100644 onnxruntime/core/providers/rocm/math/softmax_ck.cuh
 create mode 100644 onnxruntime/core/providers/rocm/math/softmax_common.h
 create mode 100644 onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
 create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu
 create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h
 create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py

diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 84e429db302c1..0b9faf8849e06 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -1460,6 +1460,7 @@ if (onnxruntime_USE_ROCM)
       device_gemm_add_fastgelu_instance
       device_gemm_fastgelu_instance
       device_batched_gemm_instance
+      device_softmax_instance
     )
     target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL)
   endif()
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index ec3726048f09e..f13d95474cd94 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -122,6 +122,7 @@ set(provider_excluded_files
   "math/softmax_impl.cu"
   "math/softmax_warpwise_impl.cuh"
   "math/softmax_common.cc"
+  "math/softmax_common.h"
   "math/softmax.cc"
   "nn/conv.cc"
   "nn/conv.h"
diff --git a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh
new file mode 100644
index 0000000000000..060415990a7cd
--- /dev/null
+++ b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh
@@ -0,0 +1,88 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#ifdef USE_COMPOSABLE_KERNEL
+#include "ck/ck.hpp"
+#include "ck/library/tensor_operation_instance/gpu/softmax.hpp"
+#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#endif  // USE_COMPOSABLE_KERNEL
+
+#include "core/providers/rocm/math/softmax_common.h"
+
+namespace onnxruntime {
+namespace rocm {
+
+#ifdef USE_COMPOSABLE_KERNEL
+
+template <typename T>
+struct DataTypeAdaptor {
+  using type = T;
+};
+
+template <>
+struct DataTypeAdaptor<half> {
+  using type = ck::half_t;
+};
+
+template <>
+struct DataTypeAdaptor<BFloat16> {
+  using type = ck::bhalf_t;
+};
+
+using Nop = ck::tensor_operation::element_wise::PassThrough;
+constexpr int Rank = 4;
+constexpr int NumReduceDim = 1;
+
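+// Returns (type string, tunable op) pairs, one per CK DeviceSoftmax instance available for the given data types.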
+template <typename InputT, typename OutputT, typename AccT>
+auto GetCKSoftmaxTypeStringAndOps() {
+  using InDataType = typename DataTypeAdaptor<InputT>::type;
+  using OutDataType = typename DataTypeAdaptor<OutputT>::type;
+  using AccDataType = typename DataTypeAdaptor<AccT>::type;
+  using DeviceSoftmax = ck::tensor_operation::device::
+      DeviceSoftmax<InDataType, AccDataType, OutDataType, Nop, Nop, Rank>;
+  using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceSoftmax>;
+
+  std::vector<std::pair<std::string, tunable::Op<SoftmaxParams<InputT, OutputT>>>> ret;
+  for (auto&& impl : InstanceFactory::GetInstances()) {
+    auto type_string = onnxruntime::MakeString(impl->GetTypeString());
+    auto invoker = impl->MakeInvokerPointer();
+
+    auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams<InputT, OutputT>* params) -> Status {
+      AccDataType alpha{1.0f};
+      AccDataType beta{0.0f};
+
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          params->is_log_softmax,
+          impl->GetTypeString(), " does not support log softmax");
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          impl->GetRank() != Rank || impl->GetNumReduceDim() != NumReduceDim,
+          impl->GetTypeString(), " does not support current Rank or NumReduceDim ", params->Signature());
+
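+      // View the input as a 4D tensor [1, 1, batch_count, softmax_elements] and reduce over the last dimension.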
+      std::vector<ck::index_t> in_lengths{1, 1, params->batch_count, params->softmax_elements};
+      std::vector<ck::index_t> in_strides{params->batch_count * params->input_stride, params->batch_count * params->input_stride, params->input_stride, 1};
+      std::vector<ck::index_t> reduce_dims{3};
+
+      auto nop = Nop{};
+      auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, &alpha, &beta,
+                                           params->input, params->output, nop, nop);
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
+                                                impl->GetTypeString(), " does not support ", params->Signature());
+      invoker->Run(arg.get(), StreamConfig{params->stream});
+      return Status::OK();
+    };
+    ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_softmax_op)));
+  }
+  return ret;
+}
+#endif  // USE_COMPOSABLE_KERNEL
+
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/math/softmax_common.h b/onnxruntime/core/providers/rocm/math/softmax_common.h
new file mode 100644
index 0000000000000..976e7ba69ce4e
--- /dev/null
+++ b/onnxruntime/core/providers/rocm/math/softmax_common.h
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/status.h"
+#include "core/providers/rocm/miopen_common.h"
+#include "core/providers/rocm/tunable/rocm_tunable.h"
+
+namespace onnxruntime {
+namespace rocm {
+
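+// Parameter bundle shared by all softmax tunable op candidates; Signature() identifies the problem shape for tuning.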
+template <typename InputT, typename OutputT>
+struct SoftmaxParams : tunable::OpParams {
+  SoftmaxParams(tunable::RocmTuningContext* tuning_ctx, hipStream_t stream, OutputT* output, const InputT* input,
+                int softmax_elements, int input_stride, int output_stride, int batch_count, bool is_log_softmax)
+      : OpParams(tuning_ctx, stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {}
+
+  std::string Signature() const override {
+    std::string sig = std::to_string(batch_count) + "_" + std::to_string(softmax_elements);
+    return sig;
+  }
+
+  OutputT* output;
+  const InputT* input;
+  int softmax_elements;
+  int input_stride;
+  int output_stride;
+  int batch_count;
+  bool is_log_softmax;
+};
+
+Status SoftmaxForward(miopenHandle_t miopen_handle, const void* alpha, const miopenTensorDescriptor_t input_tensor,
+                      const void* input_data, const void* beta, const miopenTensorDescriptor_t output_tensor,
+                      void* output_data);
+
+Status SoftmaxBackward(miopenHandle_t miopen_handle, bool is_log_softmax, const void* alpha,
+                       const miopenTensorDescriptor_t input_tensor, const void* output_data,
+                       const void* output_grad_data, const void* beta, const miopenTensorDescriptor_t output_tensor,
+                       void* input_grad_data);
+
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
index 1948878e7bb3f..ad36240926f54 100644
--- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu
+++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -1,18 +1,18 @@
 /**
-* Copyright (c) 2016-present, Facebook, Inc.
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 /* Modifications Copyright (c) Microsoft. */
 
@@ -29,8 +29,8 @@
 namespace onnxruntime {
 namespace rocm {
 
-template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
+template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
+Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
   if (softmax_elements == 0) {
     return Status::OK();
   } else {
@@ -51,39 +51,23 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons
     dim3 threads(warp_size, warps_per_block, 1);
     // Launch code would be more elegant if C++ supported FOR CONSTEXPR
     switch (log2_elements) {
-      case 0:  // 1
-        softmax_warp_forward<input_t, output_t, acc_t, 0, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 1:  // 2
-        softmax_warp_forward<input_t, output_t, acc_t, 1, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 2:  // 4
-        softmax_warp_forward<input_t, output_t, acc_t, 2, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 3:  // 8
-        softmax_warp_forward<input_t, output_t, acc_t, 3, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 4:  // 16
-        softmax_warp_forward<input_t, output_t, acc_t, 4, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 5:  // 32
-        softmax_warp_forward<input_t, output_t, acc_t, 5, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 6:  // 64
-        softmax_warp_forward<input_t, output_t, acc_t, 6, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 7:  // 128
-        softmax_warp_forward<input_t, output_t, acc_t, 7, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 8:  // 256
-        softmax_warp_forward<input_t, output_t, acc_t, 8, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 9:  // 512
-        softmax_warp_forward<input_t, output_t, acc_t, 9, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
-      case 10:  // 1024
-        softmax_warp_forward<input_t, output_t, acc_t, 10, is_log_softmax><<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
-        break;
+      #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E)                                                         \
+        case L2E:                                                                                      \
+          softmax_warp_forward<InputT, OutputT, AccT, L2E, IsLogSoftmax>                             \
+              <<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count,                      \
+                                                          softmax_elements_stride, softmax_elements);  \
+          break;
+      LAUNCH_SOFTMAX_WARP_FORWARD(0);   // 1
+      LAUNCH_SOFTMAX_WARP_FORWARD(1);   // 2
+      LAUNCH_SOFTMAX_WARP_FORWARD(2);   // 4
+      LAUNCH_SOFTMAX_WARP_FORWARD(3);   // 8
+      LAUNCH_SOFTMAX_WARP_FORWARD(4);   // 16
+      LAUNCH_SOFTMAX_WARP_FORWARD(5);   // 32
+      LAUNCH_SOFTMAX_WARP_FORWARD(6);   // 64
+      LAUNCH_SOFTMAX_WARP_FORWARD(7);   // 128
+      LAUNCH_SOFTMAX_WARP_FORWARD(8);   // 256
+      LAUNCH_SOFTMAX_WARP_FORWARD(9);   // 512
+      LAUNCH_SOFTMAX_WARP_FORWARD(10);  // 1024
       default:
         break;
     }
@@ -91,39 +75,43 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons
   return HIP_CALL(hipGetLastError());
 }
 
-#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
-template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);
+#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT)                            \
+  template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, false>( \
+      hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements,  \
+      int softmax_elements_stride, int batch_count);                               \
+  template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, true>(  \
+      hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements,  \
+      int softmax_elements_stride, int batch_count);
 
 SPECIALIZED_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_SOFTMAX_IMPL(half, half, float)
 SPECIALIZED_SOFTMAX_IMPL(double, double, double)
 SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 
-template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
-                                        int input_stride, int output_stride, int batch_count) {
+template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
+Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
+                                          int input_stride, int output_stride, int batch_count) {
   dim3 grid(batch_count);
-  constexpr int ILP = sizeof(float4) / sizeof(input_t);
+  constexpr int ILP = sizeof(float4) / sizeof(InputT);
   dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
-  if (is_log_softmax) {
-    softmax_block_forward<ILP, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue>
-    <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
-                                                       softmax_elements, input_stride, output_stride);
+  if (IsLogSoftmax) {
+    softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
+        <<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
+                                                          softmax_elements, input_stride, output_stride);
   } else {
-    softmax_block_forward<ILP, input_t, acc_t, output_t, SoftMaxForwardEpilogue>
-    <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
-                                                       softmax_elements, input_stride, output_stride);
+    softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
+        <<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
+                                                          softmax_elements, input_stride, output_stride);
   }
-  return HIP_CALL(hipGetLastError()); 
+  return HIP_CALL(hipGetLastError());
 }
 
-#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t)                      \
-  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>(    \
-      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements,    \
-      int input_stride, int output_stride, int batch_count);                              \
-  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>(     \
-      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements,    \
+#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT)                      \
+  template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, false>(    \
+      hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
+      int input_stride, int output_stride, int batch_count);                           \
+  template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, true>(     \
+      hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
       int input_stride, int output_stride, int batch_count);
 
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
@@ -135,5 +123,5 @@ SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, float, float)  // used by BeamSearch op
 #endif
 
-}
-}
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
new file mode 100644
index 0000000000000..7347cd2c035b9
--- /dev/null
+++ b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
@@ -0,0 +1,76 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+
+#include "core/providers/rocm/cu_inc/common.cuh"
+#include "core/providers/rocm/math/softmax_ck.cuh"
+#include "core/providers/rocm/math/softmax_common.h"
+#include "core/providers/rocm/math/softmax_warpwise_impl.cuh"
+#include "core/providers/rocm/math/softmax_blockwise_impl.cuh"
+#include "core/providers/rocm/tunable/rocm_tunable.h"
+
+namespace onnxruntime {
+namespace rocm {
+
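+// Blockwise softmax specialized on a fixed vector size; each VecSize below is registered as a separate tuning candidate.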
+template <typename InputT, typename OutputT, typename AccT, int VecSize>
+Status SoftmaxBlockwiseOp(const SoftmaxParams<InputT, OutputT>* params) {
+  dim3 grid(params->batch_count);
+  dim3 block = SoftMax_getBlockSize(VecSize, params->softmax_elements);
+  if (params->is_log_softmax) {
+    softmax_block_forward<VecSize, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
+        <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+                                                                  params->softmax_elements, params->input_stride,
+                                                                  params->output_stride);
+  } else {
+    softmax_block_forward<VecSize, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
+        <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+                                                                  params->softmax_elements, params->input_stride,
+                                                                  params->output_stride);
+  }
+  return HIP_CALL(hipGetLastError());
+}
+
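+// Blockwise softmax with the vector size chosen statically as sizeof(float4) / sizeof(InputT), matching dispatch_blockwise_softmax_forward.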
+template <typename InputT, typename OutputT, typename AccT>
+Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
+  dim3 grid(params->batch_count);
+  constexpr int ILP = sizeof(float4) / sizeof(InputT);
+  dim3 block = SoftMax_getBlockSize(ILP, params->softmax_elements);
+  if (params->is_log_softmax) {
+    softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
+        <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+                                                                  params->softmax_elements, params->input_stride,
+                                                                  params->output_stride);
+  } else {
+    softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
+        <<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+                                                                  params->softmax_elements, params->input_stride,
+                                                                  params->output_stride);
+  }
+  return HIP_CALL(hipGetLastError());
+}
+
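+// Tunable op that profiles the blockwise variants above (and CK instances when available) and picks the fastest.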
+template <typename InputT, typename OutputT, typename AccT>
+class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<InputT, OutputT>> {
+ public:
+  SoftmaxTunableOp() {
+    this->RegisterOp(SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>);
+    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 1>);
+    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 2>);
+    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 4>);
+    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 8>);
+    this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 16>);
+
+#ifdef USE_COMPOSABLE_KERNEL
+    for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps<InputT, OutputT, AccT>()) {
+      ORT_UNUSED_PARAMETER(_);
+      this->RegisterOp(std::move(op));
+    }
+#endif  // USE_COMPOSABLE_KERNEL
+  }
+};
+
+}  // namespace rocm
+}  // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc
index 733c3ee5c523b..93508fcb629e4 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc
+++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc
@@ -7,8 +7,9 @@
 #include "python/tools/kernel_explorer/kernels/vector_add.h"
 #include "python/tools/kernel_explorer/kernels/rocm/fast_gelu.h"
 #include "python/tools/kernel_explorer/kernels/rocm/gemm.h"
-#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
 #include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
+#include "python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.h"
+#include "python/tools/kernel_explorer/kernels/rocm/softmax.h"
 
 namespace py = pybind11;
 
@@ -24,6 +25,7 @@ PYBIND11_MODULE(_kernel_explorer, m) {
   InitGemm(m);
   InitSkipLayerNorm(m);
   InitGemmFastGelu(m);
+  InitSoftmax(m);
 #endif
 
   m.def("is_composable_kernel_available", []() {
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu
new file mode 100644
index 0000000000000..8128f73243804
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu
@@ -0,0 +1,191 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "python/tools/kernel_explorer/kernels/rocm/softmax.h"
+
+#include <hip/hip_fp16.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "core/providers/rocm/math/softmax_ck.cuh"
+#include "core/providers/rocm/math/softmax_tunable_op.cuh"
+#include "core/providers/rocm/shared_inc/accumulation_type.h"
+#include "python/tools/kernel_explorer/device_array.h"
+#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+template <typename T, int VecSize>
+class SoftmaxBlockwise : public IKernelExplorer {
+ public:
+  SoftmaxBlockwise(DeviceArray& output, DeviceArray& input, int softmax_elements,
+                   int input_stride, int output_stride, int batch_count, bool is_log_softmax)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
+                softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
+    type_string_ = "SoftmaxBlockwise_" + std::to_string(VecSize);
+  }
+
+  void Run() override {
+    ORT_THROW_IF_ERROR((rocm::SoftmaxBlockwiseOp<T, T, rocm::AccumulationType_t<T>, VecSize>(&params_)));
+  }
+
+  std::vector<std::string> ListOps() const {
+    return {type_string_};
+  }
+
+  bool SelectOp(const std::string& name) {
+    Status status = rocm::SoftmaxBlockwiseOp<T, T, rocm::AccumulationType_t<T>, VecSize>(&params_);
+    return status.IsOK() && name == type_string_;
+  }
+
+ private:
+  using ParamsT = rocm::SoftmaxParams<T, T>;
+  ParamsT params_{};
+  std::string type_string_{};
+};
+
+template <typename T>
+class SoftmaxBlockwiseStaticSelection : public IKernelExplorer {
+ public:
+  SoftmaxBlockwiseStaticSelection(DeviceArray& output, DeviceArray& input, int softmax_elements,
+                                  int input_stride, int output_stride, int batch_count, bool is_log_softmax)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
+                softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {}
+
+  void Run() override {
+    ORT_THROW_IF_ERROR((rocm::SoftmaxBlockwiseStaticSelection<T, T, rocm::AccumulationType_t<T>>(&params_)));
+  }
+
+  std::vector<std::string> ListOps() const {
+    return {"SoftmaxBlockwiseStaticSelection"};
+  }
+
+  bool SelectOp(const std::string& name) {
+    return name == "SoftmaxBlockwiseStaticSelection";
+  }
+
+ private:
+  using ParamsT = rocm::SoftmaxParams<T, T>;
+  ParamsT params_{};
+};
+
+template <typename T>
+class SoftmaxTunable : public IKernelExplorer {
+ public:
+  SoftmaxTunable(DeviceArray& output, DeviceArray& input, int softmax_elements,
+                 int input_stride, int output_stride, int batch_count, bool is_log_softmax)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
+                softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
+    params_.TuningContext()->EnableTunableOp();
+  }
+
+  void Run() override {
+    ORT_THROW_IF_ERROR(op_(&params_));
+  }
+
+  std::vector<std::string> ListOps() const {
+    return {"SoftmaxTunable"};
+  }
+
+  bool SelectOp(const std::string& name) {
+    return name == "SoftmaxTunable";
+  }
+
+ private:
+  using ParamsT = rocm::SoftmaxParams<T, T>;
+  ParamsT params_{};
+  rocm::SoftmaxTunableOp<T, T, rocm::AccumulationType_t<T>> op_{};
+};
+
+#ifdef USE_COMPOSABLE_KERNEL
+template <typename T>
+class CKSoftmax : public IKernelExplorer {
+ public:
+  CKSoftmax(DeviceArray& output, DeviceArray& input, int softmax_elements,
+            int input_stride, int output_stride, int batch_count, bool is_log_softmax)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
+                softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
+    for (auto&& [type_string, op] : rocm::GetCKSoftmaxTypeStringAndOps<T, T, rocm::AccumulationType_t<T>>()) {
+      type_strings_.emplace_back(std::move(type_string));
+      ops_.emplace_back(std::move(op));
+    }
+  }
+
+  void Run() override {
+    ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
+  }
+
+  std::vector<std::string> ListOps() const {
+    return type_strings_;
+  }
+
+  bool SelectOp(const std::string& name) {
+    for (size_t i = 0; i < ops_.size(); i++) {
+      if (type_strings_[i] == name) {
+        selected_op_ = i;
+        Status status = ops_[i](&params_);
+        return status.IsOK();
+      }
+    }
+
+    ORT_THROW("Cannot find implementation ", name);
+  }
+
+ private:
+  using ParamsT = rocm::SoftmaxParams<T, T>;
+  using OpT = rocm::tunable::Op<ParamsT>;
+  ParamsT params_{};
+  std::vector<OpT> ops_;
+  std::vector<std::string> type_strings_;
+  size_t selected_op_{};
+};
+#endif  // USE_COMPOSABLE_KERNEL
+
+#define REGISTER_OP(name, type, vec_size)                                    \
+  py::class_<name<type, vec_size>>(m, #name "_" #type "_" #vec_size)         \
+      .def(py::init<DeviceArray&, DeviceArray&, int, int, int, int, bool>()) \
+      .def("SetRepeats", &name<type, vec_size>::SetRepeats)                  \
+      .def("Profile", &name<type, vec_size>::Profile)                        \
+      .def("Run", &name<type, vec_size>::Run)                                \
+      .def("ListOps", &name<type, vec_size>::ListOps)                        \
+      .def("SelectOp", &name<type, vec_size>::SelectOp);
+
+#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type) \
+  REGISTER_OP(name, type, 1)                     \
+  REGISTER_OP(name, type, 2)                     \
+  REGISTER_OP(name, type, 4)                     \
+  REGISTER_OP(name, type, 8)                     \
+  REGISTER_OP(name, type, 16)
+
+#define REGISTER_OP_TYPED(name, type)                                        \
+  py::class_<name<type>>(m, #name "_" #type)                                 \
+      .def(py::init<DeviceArray&, DeviceArray&, int, int, int, int, bool>()) \
+      .def("SetRepeats", &name<type>::SetRepeats)                            \
+      .def("Profile", &name<type>::Profile)                                  \
+      .def("Run", &name<type>::Run)                                          \
+      .def("ListOps", &name<type>::ListOps)                                  \
+      .def("SelectOp", &name<type>::SelectOp);
+
+void InitSoftmax(py::module m) {
+  REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half);
+  REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float);
+
+  REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, half);
+  REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, float);
+
+  REGISTER_OP_TYPED(SoftmaxTunable, half);
+  REGISTER_OP_TYPED(SoftmaxTunable, float);
+
+#ifdef USE_COMPOSABLE_KERNEL
+  REGISTER_OP_TYPED(CKSoftmax, half);
+  REGISTER_OP_TYPED(CKSoftmax, float);
+#endif  // USE_COMPOSABLE_KERNEL
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h
new file mode 100644
index 0000000000000..5ae71614e2c7f
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+void InitSoftmax(py::module m);
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py
new file mode 100644
index 0000000000000..cd3998a2826cf
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py
@@ -0,0 +1,148 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import re
+import sys
+from dataclasses import dataclass
+from itertools import product
+
+import kernel_explorer as ke
+import numpy as np
+import pytest
+from utils import dtype_to_bytes, dtype_to_suffix
+
+
+def get_test_sizes():
+    batch_count = [1, 8]
+    softmax_elements = [1, 2, 3, 4, 5, 7, 8, 9, 11, 16, 31, 32, 33, 64, 65, 127, 128, 1024, 1025, 2048, 4096]
+    is_log_softmax = [True, False]
+    return product(batch_count, softmax_elements, is_log_softmax)
+
+
+def dtype_to_funcs(dtype):
+    type_map = {
+        "float16": list(filter(lambda x: re.match("Softmax.*_half.*", x), dir(ke))),
+        "float32": list(filter(lambda x: re.match("Softmax.*_float.*", x), dir(ke))),
+    }
+    return type_map[dtype]
+
+
+def softmax(x, is_log_softmax):
+    x = x - np.max(x, axis=-1, keepdims=True)
+    if is_log_softmax:
+        return x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
+    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)
+
+
+def _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, func):
+    np.random.seed(0)
+    x = np.random.rand(batch_count, softmax_elements).astype(dtype)
+    y = np.random.rand(batch_count, softmax_elements).astype(dtype)
+
+    x_d = ke.DeviceArray(x)
+    y_d = ke.DeviceArray(y)
+    y_ref = softmax(x, is_log_softmax)
+
+    softmax_func = getattr(ke, func)
+    softmax_op = softmax_func(
+        y_d, x_d, softmax_elements, softmax_elements, softmax_elements, batch_count, is_log_softmax
+    )
+    for impl in softmax_op.ListOps():
+        if not softmax_op.SelectOp(impl):
+            continue
+
+        softmax_op.Run()
+        y_d.UpdateHostNumpyArray()
+
+        np.testing.assert_allclose(y_ref, y, rtol=1e-02)
+
+
+dtypes = ["float16", "float32"]
+
+
+@pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes())
+@pytest.mark.parametrize("dtype", dtypes)
+def test_softmax(batch_count, softmax_elements, is_log_softmax, dtype):
+    for f in dtype_to_funcs(dtype):
+        _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, f)
+
+
+@pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes())
+@pytest.mark.parametrize("dtype", dtypes)
+def test_ck_softmax(batch_count, softmax_elements, is_log_softmax, dtype):
+    ck_f_name = "CKSoftmax" + "_" + dtype_to_suffix(dtype)
+    _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name)
+
+
+@dataclass
+class SoftmaxMetric(ke.BandwidthMetric):
+    batch_count: int
+    softmax_elements: int
+    is_log_softmax: bool
+
+    def report(self):
+        prefix = f"{self.name:<110} {self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4}"
+        if self.duration > 0:
+            return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
+        return prefix + "not supported"
+
+
+def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func):
+    np.random.seed(0)
+    x = np.random.rand(batch_count, softmax_elements).astype(dtype)
+    y = np.random.rand(batch_count, softmax_elements).astype(dtype)
+
+    x_d = ke.DeviceArray(x)
+    y_d = ke.DeviceArray(y)
+
+    softmax_func = getattr(ke, func)
+    softmax_op = softmax_func(
+        y_d, x_d, softmax_elements, softmax_elements, softmax_elements, batch_count, is_log_softmax
+    )
+
+    for impl in softmax_op.ListOps():
+        duration_ms = -1
+        if softmax_op.SelectOp(impl):
+            duration_ms = softmax_op.Profile()
+        total_bytes = 2 * batch_count * softmax_elements * dtype_to_bytes(dtype)
+
+        ke.report(SoftmaxMetric(impl, dtype, duration_ms, total_bytes, batch_count, softmax_elements, is_log_softmax))
+
+
+def profile_with_args(batch_count, softmax_elements, is_log_softmax, dtype, sort):
+    with ke.benchmark(sort):
+        for func in dtype_to_funcs(dtype):
+            profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func)
+        # ck function
+        ck_f_name = "CKSoftmax" + "_" + dtype_to_suffix(dtype)
+        profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name)
+
+
+profile_size = [(1, 2048), (8, 2048), (65536, 4096)]
+
+
+def profile():
+    for dtype in dtypes:
+        for batch_count, softmax_elements in profile_size:
+            profile_with_args(batch_count, softmax_elements, False, dtype, True)
+            print()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group("profile with args")
+    group.add_argument("batch_count", type=int)
+    group.add_argument("softmax_elements", type=int)
+    group.add_argument("is_log_softmax", type=int)
+    group.add_argument("dtype", choices=dtypes)
+    group.add_argument("--sort", action="store_true")
+
+    if len(sys.argv) == 1:
+        profile()
+    else:
+        args = parser.parse_args()
+        profile_with_args(args.batch_count, args.softmax_elements, args.is_log_softmax, args.dtype, args.sort)