Add conv fp16 kernel in xnnpack EP (#22301)
### Description
Add FP16 kernels for Conv and ConvTranspose in the XNNPACK execution provider.

[AB#50186](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/50186)
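For context when reviewing: below is a minimal sketch (not part of the commit) of running a model through the XNNPACK EP so these kernels can be selected. The model path is hypothetical, and the call uses the standard ORT C++ API (`SessionOptions::AppendExecutionProvider` with provider name `"XNNPACK"`); model paths are `ORTCHAR_T*`, i.e. wide strings on Windows. The fp16 kernels are only registered on builds where `XNNPACK_FP16_SUPPORTED` is defined (see `xnnpack_init.h` further down).

```cpp
// Illustrative only: enable the XNNPACK EP and load an fp16 model.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "xnnpack_fp16_demo"};
  Ort::SessionOptions so;
  // "XNNPACK" is the provider name accepted by AppendExecutionProvider.
  so.AppendExecutionProvider("XNNPACK", {{"intra_op_num_threads", "1"}});
  Ort::Session session{env, "conv_fp16.onnx", so};  // hypothetical fp16 model path
  // Inputs/outputs would use ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 tensors.
  return 0;
}
```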



### Motivation and Context
Extends the XNNPACK EP's FP16 coverage (MaxPool already has an fp16 path) to Conv and ConvTranspose, so fp16 models can run these ops through XNNPACK on fp16-capable hardware instead of falling back to the CPU EP.

---------
mszhanyi authored Oct 10, 2024
1 parent 2b0ea6c commit 25b1c38
Showing 11 changed files with 155 additions and 31 deletions.
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/xnnpack/detail/utils.cc
@@ -3,15 +3,18 @@

#include "utils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/node_attr_utils.h"
#include "core/optimizer/initializer.h"
#include "core/providers/xnnpack/xnnpack_init.h"

#include "onnx/defs/attr_proto_util.h"

@@ -111,6 +114,10 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) {
auto_pad == AutoPadType::SAME_UPPER;
}

bool IsComputeTypeSupported(int32_t compute_type, const ComputeTypeSet& compute_type_set) {
return std::find(compute_type_set.begin(), compute_type_set.end(), compute_type) != compute_type_set.end();
}

typedef std::string ONNXOpType;

static const std::unordered_map<QuantizedOpType, ONNXOpType> qdq_to_onnx_type_map = {
18 changes: 17 additions & 1 deletion onnxruntime/core/providers/xnnpack/detail/utils.h
@@ -6,14 +6,15 @@
#include <limits>
#include <memory>
#include <unordered_map>
#include <vector>
#include <string>
#include <unordered_set>
#include <utility>

#include "core/framework/node_unit.h"
#include "core/framework/op_kernel.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/providers/common.h"
#include "core/providers/xnnpack/xnnpack_init.h"

#include "xnnpack.h"

@@ -77,6 +78,20 @@ struct XnnpackOperatorDeleter {

bool IsPaddingTypeSupported(AutoPadType auto_pad);

using ComputeTypeSet = std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>;
#ifdef XNNPACK_FP16_SUPPORTED
bool IsComputeTypeSupported(int32_t compute_type,
const ComputeTypeSet& compute_type_set = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_INT8});
#else
bool IsComputeTypeSupported(int32_t compute_type,
const ComputeTypeSet& compute_type_set = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_INT8});
#endif

using XnnpackOperator = std::unique_ptr<struct xnn_operator, XnnpackOperatorDeleter>;

std::unique_ptr<IndexedSubGraph::MetaDef> FuseActivation(const NodeUnit& conv_unit, const NodeUnit& activation,
@@ -99,5 +114,6 @@ auto xnn_u8s8_quantize(float val, float scale, T zero_point) {
auto zp = static_cast<float>(zero_point);
return static_cast<T>(lrintf(fminf(fmaxf(val / scale + zp, typed_min), typed_max)));
}

} // namespace xnnpack
} // namespace onnxruntime
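A stand-alone illustration (not from the diff) of how the defaulted `ComputeTypeSet` above is meant to be used: callers pass the node's input element type and, optionally, an op-specific allow-list. Plain `int32_t` values stand in for the `TensorProto_DataType` enum; the values match onnx.proto.

```cpp
// Mock of the IsComputeTypeSupported(...) pattern declared above.
#include <cstdint>
#include <iostream>
#include <unordered_set>

using ComputeTypeSet = std::unordered_set<int32_t>;
constexpr int32_t kFloat = 1, kUint8 = 2, kInt8 = 3, kFloat16 = 10;  // onnx.proto values

bool IsComputeTypeSupported(int32_t compute_type,
                            const ComputeTypeSet& allowed = {kFloat, kFloat16, kUint8, kInt8}) {
  return allowed.count(compute_type) != 0;
}

int main() {
  std::cout << IsComputeTypeSupported(kFloat16) << "\n";            // 1 with the fp16-enabled default set
  std::cout << IsComputeTypeSupported(kFloat16, {kFloat}) << "\n";  // 0 with an op-specific allow-list
  return 0;
}
```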
24 changes: 19 additions & 5 deletions onnxruntime/core/providers/xnnpack/nn/conv.cc
@@ -10,8 +10,8 @@
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"

namespace onnxruntime {
namespace xnnpack {
@@ -22,8 +22,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
(conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W
const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
conv_type_ == OpComputeType::op_compute_type_fp16);
if ((conv_type_is_float && input_idx == 1) ||
(!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W
// Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
auto orig_shape = tensor.Shape();
const auto rank = orig_shape.NumDimensions();
@@ -56,7 +58,6 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
// we can create the kernel now
ORT_RETURN_IF_ERROR(CreateKernel());
}

return Status::OK();
}

@@ -102,6 +103,8 @@ Status Conv::Compute(OpKernelContext* context) const {
reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_convolution2d_nhwc_f16;
}

auto status = reshape_fn(op0_.get(), N, H, W,
@@ -112,12 +115,14 @@ Status Conv::Compute(OpKernelContext* context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
"returned ", status);
}

workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));

if (conv_type_ == OpComputeType::op_compute_type_fp32) {
status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data<float>(),
Y->MutableData<float>());
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data<MLFloat16>(),
Y->MutableData<MLFloat16>());
} else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data<int8_t>(),
Y->MutableData<int8_t>());
@@ -149,6 +154,15 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, kXnnpackEx
ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv);
#ifdef XNNPACK_FP16_SUPPORTED
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
Conv);

ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
Conv);
#endif

ONNX_OPERATOR_TYPED_KERNEL_EX(
QLinearConv,
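Aside, not part of the commit: the PrePack path above transposes the weight from `{M, C/group, kH, kW}` to `{M, kH, kW, C/group}` for both the fp32 and fp16 cases. A self-contained sketch of that index math (the EP itself uses its transpose helpers rather than this loop):

```cpp
// Illustrative OIHW-style {M, C/group, kH, kW} -> {M, kH, kW, C/group} re-layout.
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> TransposeWeightForXnnpack(const std::vector<T>& w,
                                         size_t M, size_t C_per_group,
                                         size_t kH, size_t kW) {
  std::vector<T> out(w.size());
  for (size_t m = 0; m < M; ++m)
    for (size_t c = 0; c < C_per_group; ++c)
      for (size_t h = 0; h < kH; ++h)
        for (size_t x = 0; x < kW; ++x) {
          const size_t src = ((m * C_per_group + c) * kH + h) * kW + x;  // {M, C/g, kH, kW}
          const size_t dst = ((m * kH + h) * kW + x) * C_per_group + c;  // {M, kH, kW, C/g}
          out[dst] = w[src];
        }
  return out;
}

int main() {
  // 1 output channel, 2 input channels, 1x1 kernel: layout is unchanged in this tiny case.
  std::vector<float> w{1.f, 2.f};
  auto t = TransposeWeightForXnnpack(w, /*M=*/1, /*C_per_group=*/2, /*kH=*/1, /*kW=*/1);
  return (t[0] == 1.f && t[1] == 2.f) ? 0 : 1;
}
```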
39 changes: 33 additions & 6 deletions onnxruntime/core/providers/xnnpack/nn/conv_base.cc
@@ -81,6 +81,28 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs,
foutput_min, foutput_max, flags,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_fp16) {
const auto* B_data = Bias ? Bias->Data<MLFloat16>() : nullptr;
// 65504 is the max value of float16
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
const float output_min = clip_min_max ? clip_min_max->first : -65504.0f;
const float output_max = clip_min_max ? clip_min_max->second : 65504.0f;
auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16
: xnn_create_convolution2d_nhwc_f16;
status = create_func(
input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
kernel_height, kernel_width,
subsampling_height, subsampling_width,
dilation_height, dilation_width,
group_count,
group_input_channels,
group_output_channels,
C, M, // input channel stride, output channel stride
Weight.Data<MLFloat16>(), B_data, // kernel, bias
output_min, output_max,
flags,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_qs8) {
const float output_scale = quant_param[2].first[0];
const int8_t output_zero_point = quant_param[2].second;
@@ -236,6 +258,13 @@ OpComputeType GetConvCompType(
return op_compute_type_qu8;
}
break;
case TensorTypeFp16:
if (input_datatype == TensorTypeFp16 &&
(!bias_datatype || *bias_datatype == TensorTypeInt32) &&
output_datatype == TensorTypeFp16) {
return op_compute_type_fp16;
}
break;
default:
break;
}
@@ -326,10 +355,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&

// supported compute types: float, float16 (when XNNPACK_FP16_SUPPORTED), u8 and s8
const auto* x_type = x_arg.TypeAsProto();
if (x_type == nullptr ||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) {
break;
}
// require C, H, W to be known so we can construct the xnnpack kernel prior to Compute
@@ -420,9 +446,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
weight_index = 3;
conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
conv_type_ = OpComputeType::op_compute_type_fp16;
} else {
auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto()));
ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype);
ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype);
}

ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight),
@@ -491,7 +519,6 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)

output_shape_.push_back(M_);
}

// have to delay creating the xnnpack kernel until after the weights are pre-packed.
}

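Aside, not part of the commit: a condensed, stand-alone version of the compute-type selection the `ConvBase` constructor now performs, with the quantized paths collapsed (the real code parses quantization params via `ParseQuantParamAndConType`). Integer values stand in for the ONNX `TensorProto_DataType` enum.

```cpp
// Illustrative compute-type dispatch; FLOAT16 -> fp16 is what this commit adds.
#include <cstdint>
#include <stdexcept>

enum class OpComputeType { fp32, fp16, qs8, qu8 };
constexpr int32_t kFloat = 1, kUint8 = 2, kInt8 = 3, kFloat16 = 10;  // onnx.proto values

OpComputeType SelectConvComputeType(int32_t input_elem_type) {
  switch (input_elem_type) {
    case kFloat:   return OpComputeType::fp32;
    case kFloat16: return OpComputeType::fp16;  // new with this commit
    case kInt8:    return OpComputeType::qs8;   // quant params handled separately in the EP
    case kUint8:   return OpComputeType::qu8;
    default:
      throw std::runtime_error("unsupported Conv in XnnpackEP: FLOAT|UINT8|INT8|FLOAT16 expected");
  }
}

int main() { return SelectConvComputeType(kFloat16) == OpComputeType::fp16 ? 0 : 1; }
```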
34 changes: 27 additions & 7 deletions onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc
@@ -7,6 +7,7 @@
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/framework/tensorprotoutils.h"

namespace onnxruntime {
@@ -18,8 +19,10 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
(conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W
const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
conv_type_ == OpComputeType::op_compute_type_fp16);
if ((conv_type_is_float && input_idx == 1) ||
(!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W
auto orig_shape = tensor.Shape();
const auto rank = orig_shape.NumDimensions();

@@ -129,6 +132,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8;
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_f16;
}

status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1,
@@ -146,6 +151,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_deconvolution2d_nhwc_f16(op0_.get(), X.Data<MLFloat16>(), Y->MutableData<MLFloat16>());
}

if (status != xnn_status_success) {
Expand All @@ -161,16 +168,16 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
return Status::OK();
}

ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpackExecutionProvider,
KernelDefBuilder()
.TypeConstraint(
@@ -179,5 +186,18 @@ ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpack
DataTypeImpl::GetTensorType<int8_t>()}),
ConvTranspose);

#ifdef XNNPACK_FP16_SUPPORTED
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, MLFloat16,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<MLFloat16>()),
ConvTranspose);

ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<MLFloat16>()),
ConvTranspose);

#endif
} // namespace xnnpack
} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/xnnpack/nn/max_pool.cc
@@ -200,8 +200,8 @@ MaxPool::MaxPool(const OpKernelInfo& info)
output_min, output_max, flags, &p);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
maxpool_type_ = OpComputeType::op_compute_type_fp16;
const float output_min = -65504.0;
const float output_max = 65504.0;
const float output_min = clip_min_max_ ? clip_min_max_->first : -65504.0f;
const float output_max = clip_min_max_ ? clip_min_max_->second : 65504.0f;
status = xnn_create_max_pooling2d_nhwc_f16(input_padding_top, input_padding_right,
input_padding_bottom, input_padding_left,
pooling_height, pooling_width,
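Aside, not part of the commit: the fp16 kernels clamp outputs to the fused activation's `(min, max)` when one is present and otherwise to the full half-precision range `[-65504, 65504]`, as in the Conv and MaxPool changes above. A tiny sketch of that default:

```cpp
// Illustrative fp16 output clamp selection.
#include <optional>
#include <utility>

std::pair<float, float> Fp16OutputClamp(const std::optional<std::pair<float, float>>& clip_min_max) {
  constexpr float kFp16Max = 65504.0f;  // largest finite half-precision value
  if (clip_min_max) return {clip_min_max->first, clip_min_max->second};
  return {-kFp16Max, kFp16Max};
}

int main() {
  auto c = Fp16OutputClamp(std::nullopt);
  return (c.first == -65504.0f && c.second == 65504.0f) ? 0 : 1;
}
```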
19 changes: 15 additions & 4 deletions onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
@@ -44,12 +44,12 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>

#ifdef XNNPACK_FP16_SUPPORTED
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, \
startver, endver, MLFloat16, name)

#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, startver, \
MLFloat16, name)
#else
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name)
@@ -64,9 +64,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10,
ConvTranspose);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv);
@@ -161,6 +166,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),

#ifdef XNNPACK_FP16_SUPPORTED
KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, Conv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(11, MLFloat16, Conv, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, ConvTranspose, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(11, MLFloat16, ConvTranspose, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain),
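Aside, not part of the commit: a stand-alone illustration of the macro pattern used above, where the FP16 class-name macros expand to a typed declaration only when `XNNPACK_FP16_SUPPORTED` is defined and to nothing otherwise, so the surrounding registration code can list the fp16 entries unconditionally. The macro and struct names here are invented for the demo.

```cpp
// Demo of conditional kernel-class declaration via macros.
#include <iostream>

#define XNNPACK_FP16_SUPPORTED  // pretend this is an fp16-capable build

#ifdef XNNPACK_FP16_SUPPORTED
#define DECLARE_FP16_KERNEL(name) struct name##_fp16 { static constexpr bool registered = true; };
#else
#define DECLARE_FP16_KERNEL(name)  // expands to nothing on non-fp16 builds
#endif

DECLARE_FP16_KERNEL(Conv)           // declares Conv_fp16 only on fp16 builds
DECLARE_FP16_KERNEL(ConvTranspose)  // declares ConvTranspose_fp16 only on fp16 builds

int main() {
#ifdef XNNPACK_FP16_SUPPORTED
  std::cout << Conv_fp16::registered << " " << ConvTranspose_fp16::registered << "\n";
#endif
  return 0;
}
```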
11 changes: 8 additions & 3 deletions onnxruntime/core/providers/xnnpack/xnnpack_init.h
@@ -46,15 +46,20 @@ namespace xnnpack {
#define XNN_ALLOCATION_ALIGNMENT 16
#endif

#if defined(__arm__) || defined(_M_ARM)
#define XNN_ARCH_ARM 1
#else
#define XNN_ARCH_ARM 0
#endif

#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
#define XNN_ARCH_ARM64 1
#else
#define XNN_ARCH_ARM64 0
#endif

// fp16 support can vary on a kernel by kernel basis. Keep it simple and limit to arm64 for now.
// e.g. XNNPACK maxpool has x64 and arm64 fp16 kernels.
#if XNN_ARCH_ARM64
// referenced from xnn_is_f16_compatible_config in XNNPACK/src/xnnpack/hardware-config.h
#if XNN_ARCH_ARM || XNN_ARCH_ARM64 || ((XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE)
#define XNNPACK_FP16_SUPPORTED
#endif

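Aside, not part of the commit: a quick stand-alone probe of the same architecture gating `xnnpack_init.h` now uses to decide whether `XNNPACK_FP16_SUPPORTED` gets defined. The x86 "non-mobile platform" condition in the real check is simplified away here, so treat this as an approximation.

```cpp
// Print whether this build's target architecture would be considered for fp16 kernels.
#include <iostream>

#if defined(__arm__) || defined(_M_ARM)
#define DEMO_ARCH_ARM 1
#else
#define DEMO_ARCH_ARM 0
#endif

#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
#define DEMO_ARCH_ARM64 1
#else
#define DEMO_ARCH_ARM64 0
#endif

#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
#define DEMO_ARCH_X86ISH 1
#else
#define DEMO_ARCH_X86ISH 0
#endif

int main() {
  const bool fp16_candidate = DEMO_ARCH_ARM || DEMO_ARCH_ARM64 || DEMO_ARCH_X86ISH;
  std::cout << "fp16 kernels would be considered on this arch: " << fp16_candidate << "\n";
  return 0;
}
```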
4 changes: 3 additions & 1 deletion onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -2,8 +2,10 @@
// Licensed under the MIT License.

#include "core/mlas/inc/mlas.h"
#include "core/providers/xnnpack/xnnpack_init.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
// XNNPACK_FP16_SUPPORTED is defined too broadly, so also require USE_XNNPACK to avoid enabling these FP16 tests for other EPs
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || (defined(USE_XNNPACK) && defined(XNNPACK_FP16_SUPPORTED))

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"