diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc
index e1d9aa83f6d08..ecbd918d7f866 100644
--- a/onnxruntime/core/providers/rocm/math/softmax.cc
+++ b/onnxruntime/core/providers/rocm/math/softmax.cc
@@ -45,6 +45,24 @@ SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
 // SPECIALIZED_SOFTMAX_HELPER_IMPL(double)
 SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16)
 
+// miopenSoftmaxForward/Backward doesn't support BFloat16.
+#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax)                                                     \
+  template <>                                                                                                        \
+  Status SoftMaxComputeHelper<BFloat16, is_log_softmax>(hipStream_t stream, const BFloat16* X,                       \
+                                                        const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \
+    typedef typename ToHipType<BFloat16>::MappedType HipT;                                                           \
+    int64_t N = input_shape.SizeToDimension(axis);                                                                   \
+    int64_t D = input_shape.SizeFromDimension(axis);                                                                 \
+    auto Y_data = reinterpret_cast<HipT*>(Y);                                                                        \
+    auto X_data = reinterpret_cast<const HipT*>(X);                                                                  \
+    dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(                         \
+        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));       \
+    return Status::OK();                                                                                             \
+  }
+
+SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true)
+SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
+
 #define REGISTER_KERNEL_TYPED(T)           \
   ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
       Softmax,                             \
@@ -203,6 +221,7 @@ SPECIALIZED_COMPUTE(float)
 // MIOpen double data type not supported
 // SPECIALIZED_COMPUTE(double)
 SPECIALIZED_COMPUTE(MLFloat16)
+SPECIALIZED_COMPUTE(BFloat16)
 
 }  // namespace rocm
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
index d892b9fb86371..68f04070f3fb0 100644
--- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu
+++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -97,6 +97,7 @@ template void dispatch_warpwise_softmax_forward(
 SPECIALIZED_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_SOFTMAX_IMPL(half, half, float)
 SPECIALIZED_SOFTMAX_IMPL(double, double, double)
+SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 
 template <typename input_t, typename output_t, typename acc_t>
 void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements, int softmax_elements_stride, int batch_count) {
@@ -119,6 +120,7 @@ template void dispatch_blockwise_softmax_forward(
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float)
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(double, double, double)
+SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 
 }
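A note on the warp-wise dispatch used by the BFloat16 helper above: SoftMaxComputeHelper first flattens the input to an N x D matrix around `axis`, so the kernel sees N independent rows of D softmax elements. Below is a minimal sketch of that flattening; the free functions are illustrative stand-ins for ORT's TensorShape::SizeToDimension/SizeFromDimension, not the real API.

// Sketch: N = product of dims before `axis` (row count),
//         D = product of dims from `axis` onward (softmax width).
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int64_t size_to_dim(const std::vector<int64_t>& dims, int64_t axis) {
  return std::accumulate(dims.begin(), dims.begin() + axis, int64_t{1},
                         std::multiplies<int64_t>());
}

int64_t size_from_dim(const std::vector<int64_t>& dims, int64_t axis) {
  return std::accumulate(dims.begin() + axis, dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> shape{8, 16, 1024};  // e.g. [batch, heads, hidden]
  int64_t axis = 2;
  std::cout << "N=" << size_to_dim(shape, axis)      // 128 rows
            << " D=" << size_from_dim(shape, axis)   // 1024 elements per row
            << "\n";
}

Passing D both as the element count and as the stride (as the helper does) encodes the assumption that each row is contiguous in memory.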
" @@ -823,6 +823,111 @@ SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int64_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t) +template <> +template <> +Status ReduceKernel::ComputeImpl( + OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { + typedef typename ToHipType::MappedType HipT; + const Tensor* X = ctx->Input(0); + TensorShapeVector axes; + size_t num_inputs = ctx->InputCount(); + if (num_inputs == 2) { + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->template Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + if (axes.empty() && noop_with_empty_axes_) { + auto* Y = ctx->Output(0, X->Shape()); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), + X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream())); + return Status::OK(); + } + + PrepareReduceMetadata prepare_reduce_metadata; + ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); + + Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); + + int64_t input_count = prepare_reduce_metadata.input_count; + int64_t output_count = prepare_reduce_metadata.output_count; + auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; + auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; + + if (input_count == 0) { + assert(Y->Shape().Size() == 0); + return Status::OK(); + } + + if (input_count == output_count) { + if (Y->template MutableData() != X->template Data()) { + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), + input_count * sizeof(BFloat16), hipMemcpyDeviceToDevice, Stream())); + } + return Status::OK(); + } + + if (fast_reduction_ && !ctx->GetUseDeterministicCompute()) { + int m{}, n{}; + const auto applicable_matrix_reduction = + get_applicable_matrix_reduction(miopen_reduce_op, X->Shape().GetDims(), axes, m, n); + switch (applicable_matrix_reduction) { + case ApplicableMatrixReduction::Rows: { + return reduce_matrix_rows(Stream(), reinterpret_cast(X->template Data()), + reinterpret_cast(Y->template MutableData()), m, n); + } + case ApplicableMatrixReduction::Columns: { + const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); + auto buffer = rocm_ep_->GetScratchBuffer(buffer_size_bytes); + return reduce_matrix_columns(Stream(), reinterpret_cast(X->template Data()), + reinterpret_cast(Y->template MutableData()), m, n, buffer.get(), + buffer_size_bytes); + } + default: + break; + } + } + + HIP_RETURN_IF_ERROR(hipMemsetAsync(Y->MutableDataRaw(), 0, Y->SizeInBytes(), Stream())); + + size_t indices_bytes = 0; + size_t workspace_bytes = 0; + MiopenTensor input_tensor; + MiopenTensor output_tensor; + MiopenReduceDescriptor reduce_desc; + + miopenDataType_t miopen_type_X = miopenFloat; + IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count); + Impl_Cast(Stream(), reinterpret_cast(X->template Data()), temp_X.get(), + X->Shape().Size()); + + ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_NO_INDICES)); + ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); + ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X)); + MIOPEN_RETURN_IF_ERROR( + 
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ ... @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<...>,
       BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
       BuildKernelCreateInfo<...>,
       // BuildKernelCreateInfo<...>,
       // BuildKernelCreateInfo<...>,
       // BuildKernelCreateInfo<...>,
       BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
 
       // OpSet 14
       BuildKernelCreateInfo<...>,
@@ -2031,10 +2031,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<...>,
       BuildKernelCreateInfo<...>,
       BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
-      // BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
+      BuildKernelCreateInfo<...>,
       // BuildKernelCreateInfo<...>,
 
       // OpSet 15
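The template arguments of the BuildKernelCreateInfo entries above were lost when this diff was extracted (shown as <...>); each names a typed kernel class for one (op, opset, type) combination, and the PR flips a set of BFloat16 entries from commented-out to registered. Below is a rough sketch of the registration pattern with stand-in types, not ORT's real ones: each table entry is a factory function that manufactures a kernel-creation record.

// Sketch: a BuildKernelCreateInfo-style factory table. All names here are
// illustrative stand-ins for the onnxruntime machinery.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct KernelCreateInfo {
  std::string op;    // e.g. "Add"
  std::string type;  // e.g. "BFloat16"
};

template <typename KernelClass>
KernelCreateInfo BuildKernelCreateInfo() {
  return {KernelClass::Op(), KernelClass::Type()};
}

struct AddBF16 {  // hypothetical kernel class
  static std::string Op() { return "Add"; }
  static std::string Type() { return "BFloat16"; }
};

int main() {
  using Factory = std::function<KernelCreateInfo()>;
  std::vector<Factory> table{&BuildKernelCreateInfo<AddBF16>};
  for (auto& f : table) {
    auto info = f();
    std::cout << "registered " << info.op << "<" << info.type << ">\n";
  }
}

Keeping unregistered kernels as commented-out entries, as this file does, makes enabling a type a one-line change per op and keeps the table in a stable order.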
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index f47d80f9abf7e..43e9ead580377 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -18,7 +18,7 @@ std::vector<MLFloat16> MakeMLFloat16(const std::initializer_list<float>& input)
   return output;
 }
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
 void TestFloat16(const char* op_name, const std::vector<int64_t>& lhs_dim,
                  const std::initializer_list<float>& lhs_values, const std::vector<int64_t>& rhs_dim,
                  const std::initializer_list<float>& rhs_values, const std::vector<int64_t>& out_dim,
@@ -29,7 +29,11 @@ void TestFloat16(const char* op_name, const std::vector<int64_t>& lhs_dim,
     tester.AddInput<MLFloat16>("B", rhs_dim, MakeMLFloat16(rhs_values));
     tester.AddOutput<MLFloat16>("C", out_dim, MakeMLFloat16(out_values));
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
     execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+    execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
     tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 
@@ -39,7 +43,11 @@ void TestFloat16(const char* op_name, const std::vector<int64_t>& lhs_dim,
     tester.AddInput<BFloat16>("B", rhs_dim, MakeBFloat16(rhs_values));
     tester.AddOutput<BFloat16>("C", out_dim, MakeBFloat16(out_values));
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
     execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+    execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
     tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 }
@@ -128,7 +136,7 @@ TEST(MathOpTest, Add_float) {
   test.Run();
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values);
 #endif
 }
@@ -163,7 +171,7 @@ TEST(MathOpTest, Add_Broadcast_Axis) {
   test.AddOutput<float>("C", dims, out_values);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "");
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values);
 #endif
 }
@@ -186,7 +194,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) {
            {kTensorrtExecutionProvider});  // TensorRT: got C with shape [3, 1]
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values);
 #endif
 }
@@ -208,7 +216,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) {
            {kTensorrtExecutionProvider});  // TensorRT: got C with shape [3, 1]
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values);
 #endif
 }
@@ -404,7 +412,7 @@ TEST(MathOpTest, Sub) {
   test.Run();
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values);
 #endif
 }
@@ -462,7 +470,7 @@ TEST(MathOpTest, Mul) {
   test.Run();
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values);
 #endif
 }
@@ -501,7 +509,7 @@ TEST(MathOpTest, Div) {
   test.Run();
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
   TestFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values);
 #endif
 }
diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc
index 2411f2a14b42d..e96fbc1c9c929 100644
--- a/onnxruntime/test/providers/cpu/math/gemm_test.cc
+++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -106,9 +106,9 @@ TEST(GemmOpTest, GemmNoTrans_bfloat16) {
   test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}));
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
 #ifdef USE_CUDA
-    execution_providers.push_back(DefaultCudaExecutionProvider());
+  execution_providers.push_back(DefaultCudaExecutionProvider());
 #elif USE_ROCM
-    execution_providers.push_back(DefaultRocmExecutionProvider());
+  execution_providers.push_back(DefaultRocmExecutionProvider());
 #endif
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc
index 862ef375b06cc..6d2afcdc53df7 100644
--- a/onnxruntime/test/providers/cpu/math/matmul_test.cc
+++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc
@@ -210,9 +210,9 @@ TEST(MathOpTest, MatMul_BFloat16) {
   test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({10.0f, 10.0f, 10.0f, -10.0f, -10.0f, -10.0f}));
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
 #ifdef USE_CUDA
-    execution_providers.push_back(DefaultCudaExecutionProvider());
+  execution_providers.push_back(DefaultCudaExecutionProvider());
 #elif USE_ROCM
-    execution_providers.push_back(DefaultRocmExecutionProvider());
+  execution_providers.push_back(DefaultRocmExecutionProvider());
 #endif
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
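All of these tests pick the execution provider with the same preprocessor pattern. One detail worth noting: `#elif USE_ROCM` evaluates the macro's value (which works because the build defines it as a nonzero constant), while `#if defined(USE_CUDA) || defined(USE_ROCM)` tests only definedness; both spellings appear in the hunks above. A self-contained sketch of the pattern, with stand-in provider types rather than ORT's:

// Sketch: compile-time execution-provider selection, as in the tests above.
#include <iostream>
#include <memory>
#include <vector>

struct IExecutionProvider {
  virtual ~IExecutionProvider() = default;
  virtual const char* Name() const = 0;
};
struct CudaEP : IExecutionProvider {
  const char* Name() const override { return "CUDAExecutionProvider"; }
};
struct RocmEP : IExecutionProvider {
  const char* Name() const override { return "ROCMExecutionProvider"; }
};

std::unique_ptr<IExecutionProvider> DefaultProvider() {
#if defined(USE_CUDA)
  return std::make_unique<CudaEP>();
#elif defined(USE_ROCM)
  return std::make_unique<RocmEP>();
#else
  return nullptr;  // CPU-only build: the float16/bfloat16 tests are skipped
#endif
}

int main() {
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  if (auto ep = DefaultProvider()) execution_providers.push_back(std::move(ep));
  std::cout << "providers: " << execution_providers.size() << "\n";
}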
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
index 2c36b80ecc561..c0298f74aed55 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -1491,7 +1491,7 @@ TEST(ReductionOpTest, ReduceSum_half_bert) {
 // Add more UTs for half as needed
 #endif
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
 TEST(ReductionOpTest, ReduceSumBFloat16) {
   OpTester test("ReduceSum", 14);
   test.AddAttribute("keepdims", (int64_t)0);
@@ -1500,7 +1500,11 @@ TEST(ReductionOpTest, ReduceSumBFloat16) {
   test.AddInput<int64_t>("axes", {2}, std::vector<int64_t>{0, 1});
   test.AddOutput<BFloat16>("reduced", {2}, MakeBFloat16({36.0f, 42.0f}));
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
   execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
 #endif
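For reference, the ReduceSum semantics this test exercises: with axes={0, 1} and keepdims=0, the first two dimensions are summed away and squeezed, leaving only the trailing dimension. A sketch of that reference computation with made-up input values (the test's actual `data` line falls outside the hunk context above, so it is not reproduced here):

// Sketch: reference ReduceSum over axes {0, 1} of a {2, 2, 2} tensor,
// keepdims=0, so the output shape squeezes to {2}.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};  // row-major {2, 2, 2}
  int64_t d0 = 2, d1 = 2, d2 = 2;
  std::vector<float> reduced(d2, 0.0f);  // output shape {2}
  for (int64_t i = 0; i < d0; ++i)
    for (int64_t j = 0; j < d1; ++j)
      for (int64_t k = 0; k < d2; ++k)
        reduced[k] += data[(i * d1 + j) * d2 + k];
  std::cout << reduced[0] << " " << reduced[1] << "\n";  // 16 20
}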