
[ROCm] BFloat16 support #10416
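
This PR adds BFloat16 support to the ROCm execution provider: a BFloat16 path for Softmax that bypasses MIOpen, BFloat16 reduction kernels (including a dedicated ReduceKernel::ComputeImpl that computes in float), and enabling the previously commented-out BFloat16 registrations for Add, Sub, Mul, Div, Cast, Softmax, and ReduceSum.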

Merged (6 commits) on Jan 29, 2022
19 changes: 19 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax.cc
@@ -45,6 +45,24 @@ SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
// SPECIALIZED_SOFTMAX_HELPER_IMPL(double)
SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16)

// miopenSoftmaxForward/Backward don't support BFloat16.
#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax) \
template <> \
Status SoftMaxComputeHelper<BFloat16, is_log_softmax>(hipStream_t stream, const BFloat16* X, \
const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \
typedef typename ToHipType<BFloat16>::MappedType HipT; \
int64_t N = input_shape.SizeToDimension(axis); \
int64_t D = input_shape.SizeFromDimension(axis); \
auto Y_data = reinterpret_cast<HipT*>(Y); \
auto X_data = reinterpret_cast<const HipT*>(X); \
dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>( \
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N)); \
return Status::OK(); \
}

SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true)
SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
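
For readability, here is roughly what SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true) expands to (a sketch with explanatory comments added; not part of the diff):

template <>
Status SoftMaxComputeHelper<BFloat16, true>(hipStream_t stream, const BFloat16* X,
                                            const TensorShape& input_shape, BFloat16* Y, int64_t axis) {
  typedef typename ToHipType<BFloat16>::MappedType HipT;
  // N = product of dims before 'axis' (batch count); D = product of dims from 'axis' on (softmax size).
  int64_t N = input_shape.SizeToDimension(axis);
  int64_t D = input_shape.SizeFromDimension(axis);
  auto Y_data = reinterpret_cast<HipT*>(Y);
  auto X_data = reinterpret_cast<const HipT*>(X);
  // Bypass MIOpen and dispatch the handwritten warp-wise kernel, accumulating in float.
  dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, true>(
      stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
  return Status::OK();
}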

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Softmax, \
@@ -203,6 +221,7 @@ SPECIALIZED_COMPUTE(float)
// MIOpen double data type not supported
// SPECIALIZED_COMPUTE(double)
SPECIALIZED_COMPUTE(MLFloat16)
SPECIALIZED_COMPUTE(BFloat16)

} // namespace rocm
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -97,6 +97,7 @@ template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(
SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
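// Like half, BFloat16 accumulates in float (the acc_t parameter); only double accumulates in double.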
SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements, int softmax_elements_stride, int batch_count) {
@@ -119,6 +120,7 @@ template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)


}
111 changes: 109 additions & 2 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
@@ -379,7 +379,6 @@ Status PrepareForReduce(const Tensor* X,

const auto input_dims = input_shape.GetDims();
InlinedShapeVector<bool> reduced(rank, false);
prepare_reduce_metadata.output_dims.reserve(input_dims.size());
if (axes.size() > 0) {
prepare_reduce_metadata.output_dims = input_shape.AsShapeVector();
for (auto axis : axes) {
@@ -393,6 +392,7 @@
}
} else {
// no axes provided (i.e., default axes) => reduce over all dims
prepare_reduce_metadata.output_dims.reserve(input_dims.size());
for (auto dim : input_dims) {
ORT_ENFORCE(keepdims || dim != 0,
"Can't reduce on dim with value of 0 if 'keepdims' is false. "
@@ -823,6 +823,111 @@ SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int64_t)
SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t)
SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t)

Review thread on this function:

centwang (Contributor, Jan 28, 2022):
Why do we need a new function for BFloat16? Can the default one take the job? The default one can handle MLFloat16, and BFloat16 is quite similar to MLFloat16. #WontFix

Author:
I get a library loading error if I remove this block, so I'll leave it as is for now.

centwang (Contributor):
I read the code again. The current BFloat16 reduce casts the data to float for the calculation and casts it back at the end. cuDNN should already support BFloat16 directly since CUDA 11, so for better performance we should go through the default ComputeImpl, but making that work requires fixing a few more places. Let me open a new PR to do that.

Author:
Oh, OK, thanks a lot.

template <>
template <>
Status ReduceKernel<true>::ComputeImpl<BFloat16, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const {
typedef typename ToHipType<BFloat16>::MappedType HipT;
const Tensor* X = ctx->Input<Tensor>(0);
TensorShapeVector axes;
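// Opset 13+ ReduceSum supplies 'axes' as an optional second input tensor;
// otherwise fall back to the 'axes' attribute.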
size_t num_inputs = ctx->InputCount();
if (num_inputs == 2) {
const Tensor* axes_tensor = ctx->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor.");
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
const auto* data = axes_tensor->template Data<int64_t>();
axes.assign(data, data + nDims);
} else {
axes.assign(axes_.begin(), axes_.end());
}

if (axes.empty() && noop_with_empty_axes_) {
auto* Y = ctx->Output(0, X->Shape());
HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData<BFloat16>(), X->template Data<BFloat16>(),
X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream()));
return Status::OK();
}

PrepareReduceMetadata prepare_reduce_metadata;
ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata));

Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);

int64_t input_count = prepare_reduce_metadata.input_count;
int64_t output_count = prepare_reduce_metadata.output_count;
auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen;
auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen;

if (input_count == 0) {
assert(Y->Shape().Size() == 0);
return Status::OK();
}

// If nothing is actually reduced (input and output element counts match),
// the op reduces to a copy.
if (input_count == output_count) {
if (Y->template MutableData<BFloat16>() != X->template Data<BFloat16>()) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData<BFloat16>(), X->template Data<BFloat16>(),
input_count * sizeof(BFloat16), hipMemcpyDeviceToDevice, Stream()));
}
return Status::OK();
}

// When deterministic compute is not required, try the specialized matrix
// row/column reduction kernels before falling back to MIOpen.
if (fast_reduction_ && !ctx->GetUseDeterministicCompute()) {
int m{}, n{};
const auto applicable_matrix_reduction =
get_applicable_matrix_reduction(miopen_reduce_op, X->Shape().GetDims(), axes, m, n);
switch (applicable_matrix_reduction) {
case ApplicableMatrixReduction::Rows: {
return reduce_matrix_rows(Stream(), reinterpret_cast<const HipT*>(X->template Data<BFloat16>()),
reinterpret_cast<HipT*>(Y->template MutableData<BFloat16>()), m, n);
}
case ApplicableMatrixReduction::Columns: {
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<HipT>(m, n);
auto buffer = rocm_ep_->GetScratchBuffer<void>(buffer_size_bytes);
return reduce_matrix_columns(Stream(), reinterpret_cast<const HipT*>(X->template Data<BFloat16>()),
reinterpret_cast<HipT*>(Y->template MutableData<BFloat16>()), m, n, buffer.get(),
buffer_size_bytes);
}
default:
break;
}
}

HIP_RETURN_IF_ERROR(hipMemsetAsync(Y->MutableDataRaw(), 0, Y->SizeInBytes(), Stream()));

size_t indices_bytes = 0;
size_t workspace_bytes = 0;
MiopenTensor input_tensor;
MiopenTensor output_tensor;
MiopenReduceDescriptor reduce_desc;

// MIOpen reductions have no BFloat16 path, so up-cast the input to float,
// reduce in float, and cast the result back to BFloat16 at the end.
miopenDataType_t miopen_type_X = miopenFloat;
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count);
Impl_Cast<HipT, float>(Stream(), reinterpret_cast<const HipT*>(X->template Data<BFloat16>()), temp_X.get(),
                       X->Shape().Size());

ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_NO_INDICES));
ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X));
ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X));
MIOPEN_RETURN_IF_ERROR(
    miopenGetReductionIndicesSize(MiopenHandle(), reduce_desc, input_tensor, output_tensor, &indices_bytes));
MIOPEN_RETURN_IF_ERROR(
    miopenGetReductionWorkspaceSize(MiopenHandle(), reduce_desc, input_tensor, output_tensor, &workspace_bytes));
IAllocatorUniquePtr<uint32_t> indices_rocm = GetScratchBuffer<uint32_t>(indices_bytes);
IAllocatorUniquePtr<HipT> workspace_rocm = GetScratchBuffer<HipT>(workspace_bytes);

const auto one = Consts<float>::One;
const auto zero = Consts<float>::Zero;
auto temp_Y = GetScratchBuffer<float>(output_count);
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(MiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes, &one, input_tensor, temp_X.get(),
&zero, output_tensor, temp_Y.get()));

Impl_Cast<float, HipT>(Stream(), temp_Y.get(), reinterpret_cast<HipT*>(Y->template MutableData<BFloat16>()), output_count);

return Status::OK();
}

namespace ReductionOps {

template <typename T, miopenReduceTensorIndices_t ReduceTensorIndices>
@@ -880,7 +985,8 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N

#define REGISTER_KERNEL_HFD(name) \
REGISTER_KERNEL_TYPED(name, MLFloat16) \
REGISTER_KERNEL_TYPED(name, float)
REGISTER_KERNEL_TYPED(name, float) \
REGISTER_KERNEL_TYPED(name, BFloat16)
// REGISTER_KERNEL_TYPED(name, double)

#define REGISTER_KERNEL_HFD_11(name) \
@@ -926,6 +1032,7 @@ REGISTER_KERNEL_TYPED_13(ReduceSum, float)
// REGISTER_KERNEL_TYPED_13(ReduceSum, double)
REGISTER_KERNEL_TYPED_13(ReduceSum, int32_t)
REGISTER_KERNEL_TYPED_13(ReduceSum, int64_t)
REGISTER_KERNEL_TYPED_13(ReduceSum, BFloat16)

REGISTER_KERNEL_HFD(ReduceLogSum)
REGISTER_KERNEL_HFD(ReduceSumSquare)
44 changes: 22 additions & 22 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1120,18 +1120,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace);

// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum);

// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
@@ -1188,10 +1188,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin);

// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Relu);

// OpSet 15
@@ -1964,18 +1964,18 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum)>,

// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum)>,
@@ -2031,10 +2031,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Relu)>,

// OpSet 15