Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conv fp16 kernel in xnnpack EP #22301

Merged
merged 25 commits into from
Oct 10, 2024
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/xnnpack/detail/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@

#include "utils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/node_attr_utils.h"
#include "core/optimizer/initializer.h"
#include "core/providers/xnnpack/xnnpack_init.h"

#include "onnx/defs/attr_proto_util.h"

Expand Down Expand Up @@ -111,6 +114,20 @@
auto_pad == AutoPadType::SAME_UPPER;
}

bool IsComputeTypeSupported(int32_t compute_type,
std::optional<std::reference_wrapper<COMPUTE_TYPE_SETS>> compute_type_set) {

Check warning on line 118 in onnxruntime/core/providers/xnnpack/detail/utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/detail/utils.cc:118: Do not indent within a namespace. [whitespace/indent_namespace] [4]
std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> default_supported_types{
mszhanyi marked this conversation as resolved.
Show resolved Hide resolved
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_INT8};
auto supported_types = compute_type_set == std::nullopt ? default_supported_types : compute_type_set->get();
#ifdef XNNPACK_FP16_SUPPORTED
supported_types.insert(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
#endif
ONNX_NAMESPACE::TensorProto_DataType tp = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(compute_type);
return std::find(supported_types.begin(), supported_types.end(), tp) != supported_types.end();
}

typedef std::string ONNXOpType;

static const std::unordered_map<QuantizedOpType, ONNXOpType> qdq_to_onnx_type_map = {
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/xnnpack/detail/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <limits>
#include <memory>
#include <unordered_map>
#include <vector>
#include <string>
#include <unordered_set>
#include <utility>

#include "core/framework/node_unit.h"
Expand Down Expand Up @@ -77,6 +77,10 @@

bool IsPaddingTypeSupported(AutoPadType auto_pad);

using COMPUTE_TYPE_SETS = std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>;
mszhanyi marked this conversation as resolved.
Show resolved Hide resolved
bool IsComputeTypeSupported(int32_t compute_type,
std::optional<std::reference_wrapper<COMPUTE_TYPE_SETS>> compute_type_set = std::nullopt);

Check warning on line 82 in onnxruntime/core/providers/xnnpack/detail/utils.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/detail/utils.h:82: Do not indent within a namespace. [whitespace/indent_namespace] [4]

using XnnpackOperator = std::unique_ptr<struct xnn_operator, XnnpackOperatorDeleter>;

std::unique_ptr<IndexedSubGraph::MetaDef> FuseActivation(const NodeUnit& conv_unit, const NodeUnit& activation,
Expand All @@ -99,5 +103,6 @@
auto zp = static_cast<float>(zero_point);
return static_cast<T>(lrintf(fminf(fmaxf(val / scale + zp, typed_min), typed_max)));
}

} // namespace xnnpack
} // namespace onnxruntime
24 changes: 19 additions & 5 deletions onnxruntime/core/providers/xnnpack/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"

namespace onnxruntime {
namespace xnnpack {
Expand All @@ -22,8 +22,10 @@
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
(conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W
const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
conv_type_ == OpComputeType::op_compute_type_fp16);
if ((conv_type_is_float && input_idx == 1) ||
(!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W
// Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
auto orig_shape = tensor.Shape();
const auto rank = orig_shape.NumDimensions();
Expand Down Expand Up @@ -56,7 +58,6 @@
// we can create the kernel now
ORT_RETURN_IF_ERROR(CreateKernel());
}
mszhanyi marked this conversation as resolved.
Show resolved Hide resolved

return Status::OK();
}

Expand Down Expand Up @@ -102,6 +103,8 @@
reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_convolution2d_nhwc_f16;
}

auto status = reshape_fn(op0_.get(), N, H, W,
Expand All @@ -112,12 +115,14 @@
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
"returned ", status);
}

workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));

if (conv_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data<float>(),
Y->MutableData<float>());
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data<MLFloat16>(),
Y->MutableData<MLFloat16>());
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data<int8_t>(),
Y->MutableData<int8_t>());
Expand Down Expand Up @@ -149,6 +154,15 @@
ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv);
#ifdef XNNPACK_FP16_SUPPORTED
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),

Check warning on line 159 in onnxruntime/core/providers/xnnpack/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv.cc:159: Do not indent within a namespace. [whitespace/indent_namespace] [4]

Check warning on line 159 in onnxruntime/core/providers/xnnpack/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv.cc:159: Lines should be <= 120 characters long [whitespace/line_length] [2]
Conv);

Check warning on line 160 in onnxruntime/core/providers/xnnpack/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv.cc:160: Do not indent within a namespace. [whitespace/indent_namespace] [4]

ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),

Check warning on line 163 in onnxruntime/core/providers/xnnpack/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv.cc:163: Do not indent within a namespace. [whitespace/indent_namespace] [4]
Conv);

Check warning on line 164 in onnxruntime/core/providers/xnnpack/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv.cc:164: Do not indent within a namespace. [whitespace/indent_namespace] [4]
#endif

ONNX_OPERATOR_TYPED_KERNEL_EX(
QLinearConv,
Expand Down
39 changes: 33 additions & 6 deletions onnxruntime/core/providers/xnnpack/nn/conv_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,28 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs,
foutput_min, foutput_max, flags,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_fp16) {
const auto* B_data = Bias ? Bias->Data<MLFloat16>() : nullptr;
// 65504 is the max value of float16
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
auto output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0f;
auto output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0f;
auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16
: xnn_create_convolution2d_nhwc_f16;
status = create_func(
input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
kernel_height, kernel_width,
subsampling_height, subsampling_width,
dilation_height, dilation_width,
group_count,
group_input_channels,
group_output_channels,
C, M, // input channel stride, output channel stride
Weight.Data<MLFloat16>(), B_data, // kernel, bias
output_min, output_max,
flags,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_qs8) {
const float output_scale = quant_param[2].first[0];
const int8_t output_zero_point = quant_param[2].second;
Expand Down Expand Up @@ -236,6 +258,13 @@ OpComputeType GetConvCompType(
return op_compute_type_qu8;
}
break;
case TensorTypeFp16:
if (input_datatype == TensorTypeFp16 &&
(!bias_datatype || *bias_datatype == TensorTypeInt32) &&
output_datatype == TensorTypeFp16) {
return op_compute_type_fp16;
}
break;
default:
break;
}
Expand Down Expand Up @@ -326,10 +355,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&

// we only support float and u8 currently
const auto* x_type = x_arg.TypeAsProto();
if (x_type == nullptr ||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) {
mszhanyi marked this conversation as resolved.
Show resolved Hide resolved
break;
}
// require C, H, W to be known so we can construct the xnnpack kernel prior to Compute
Expand Down Expand Up @@ -420,9 +446,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
weight_index = 3;
conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
conv_type_ = OpComputeType::op_compute_type_fp16;
} else {
auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto()));
ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype);
ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype);
}

ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight),
Expand Down Expand Up @@ -491,7 +519,6 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)

output_shape_.push_back(M_);
}

// have to delay creating the xnnpack kernel until after the weights are pre-packed.
}

Expand Down
34 changes: 27 additions & 7 deletions onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/framework/tensorprotoutils.h"

namespace onnxruntime {
Expand All @@ -18,8 +19,10 @@
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
(conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W
const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
conv_type_ == OpComputeType::op_compute_type_fp16);
if ((conv_type_is_float && input_idx == 1) ||
(!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W
auto orig_shape = tensor.Shape();
const auto rank = orig_shape.NumDimensions();

Expand Down Expand Up @@ -129,6 +132,8 @@
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8;
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_f16;
}

status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1,
Expand All @@ -146,6 +151,8 @@
status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_deconvolution2d_nhwc_f16(op0_.get(), X.Data<MLFloat16>(), Y->MutableData<MLFloat16>());
}

if (status != xnn_status_success) {
Expand All @@ -161,16 +168,16 @@
return Status::OK();
}

ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(

Check warning on line 177 in onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc:177: Do not indent within a namespace. [whitespace/indent_namespace] [4]
"T", DataTypeImpl::GetTensorType<float>()),

Check warning on line 178 in onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc:178: Do not indent within a namespace. [whitespace/indent_namespace] [4]
ConvTranspose);

Check warning on line 179 in onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc:179: Do not indent within a namespace. [whitespace/indent_namespace] [4]

ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpackExecutionProvider,
KernelDefBuilder()
.TypeConstraint(
Expand All @@ -179,5 +186,18 @@
DataTypeImpl::GetTensorType<int8_t>()}),
ConvTranspose);

#ifdef XNNPACK_FP16_SUPPORTED
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, MLFloat16,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<MLFloat16>()),
ConvTranspose);

ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<MLFloat16>()),
ConvTranspose);

#endif
} // namespace xnnpack
} // namespace onnxruntime
19 changes: 15 additions & 4 deletions onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>

#ifdef XNNPACK_FP16_SUPPORTED
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, \
startver, endver, MLFloat16, name)

#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, startver, \
MLFloat16, name)
#else
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name)
Expand All @@ -64,9 +64,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10,
ConvTranspose);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv);
Expand Down Expand Up @@ -161,6 +166,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),

#ifdef XNNPACK_FP16_SUPPORTED
KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, Conv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(11, MLFloat16, Conv, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, ConvTranspose, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(11, MLFloat16, ConvTranspose, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain),
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
// Licensed under the MIT License.

#include "core/mlas/inc/mlas.h"
#include "core/providers/xnnpack/xnnpack_init.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED)

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
Expand Down
Loading
Loading