From 25b1c38e87c6a976849d127eca7e00c8d9956619 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 10 Oct 2024 08:48:09 +0800 Subject: [PATCH] Add conv fp16 kernel in xnnpack EP (#22301) ### Description Add FP16 kernels of Conv and ConvTranspose [AB#50186](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/50186) ### Motivation and Context --------- --- .../core/providers/xnnpack/detail/utils.cc | 7 ++++ .../core/providers/xnnpack/detail/utils.h | 18 ++++++++- onnxruntime/core/providers/xnnpack/nn/conv.cc | 24 +++++++++--- .../core/providers/xnnpack/nn/conv_base.cc | 39 ++++++++++++++++--- .../providers/xnnpack/nn/conv_transpose.cc | 34 ++++++++++++---- .../core/providers/xnnpack/nn/max_pool.cc | 4 +- .../xnnpack/xnnpack_execution_provider.cc | 19 +++++++-- .../core/providers/xnnpack/xnnpack_init.h | 11 ++++-- .../test/providers/cpu/nn/conv_fp16_test.cc | 4 +- .../cpu/nn/conv_transpose_op_test.cc | 25 +++++++++++- tools/ci_build/build.py | 1 + 11 files changed, 155 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index f9cb45ebc8abc..4eef14dddecd3 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -3,15 +3,18 @@ #include "utils.h" #include +#include #include #include "core/common/common.h" #include "core/common/safeint.h" #include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/initializer.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "onnx/defs/attr_proto_util.h" @@ -111,6 +114,10 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { auto_pad == AutoPadType::SAME_UPPER; } +bool IsComputeTypeSupported(int32_t compute_type, const ComputeTypeSet& compute_type_set) { + return std::find(compute_type_set.begin(), compute_type_set.end(), compute_type) != compute_type_set.end(); +} + typedef std::string ONNXOpType; static const std::unordered_map qdq_to_onnx_type_map = { diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index d555ee2286b84..0a80bc0450b99 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -6,14 +6,15 @@ #include #include #include -#include #include +#include #include #include "core/framework/node_unit.h" #include "core/framework/op_kernel.h" #include "core/graph/indexed_sub_graph.h" #include "core/providers/common.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "xnnpack.h" @@ -77,6 +78,20 @@ struct XnnpackOperatorDeleter { bool IsPaddingTypeSupported(AutoPadType auto_pad); +using ComputeTypeSet = std::unordered_set; +#ifdef XNNPACK_FP16_SUPPORTED +bool IsComputeTypeSupported(int32_t compute_type, + const ComputeTypeSet& compute_type_set = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8}); +#else +bool IsComputeTypeSupported(int32_t compute_type, + const ComputeTypeSet& compute_type_set = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8}); +#endif + using XnnpackOperator = std::unique_ptr; std::unique_ptr FuseActivation(const NodeUnit& 
conv_unit, const NodeUnit& activation, @@ -99,5 +114,6 @@ auto xnn_u8s8_quantize(float val, float scale, T zero_point) { auto zp = static_cast(zero_point); return static_cast(lrintf(fminf(fmaxf(val / scale + zp, typed_min), typed_max))); } + } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index b815cc1570c96..6e404c62594fd 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -10,8 +10,8 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" -#include "core/providers/xnnpack/xnnpack_init.h" #include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { namespace xnnpack { @@ -22,8 +22,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || - (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W + const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || + conv_type_ == OpComputeType::op_compute_type_fp16); + if ((conv_type_is_float && input_idx == 1) || + (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -56,7 +58,6 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, // we can create the kernel now ORT_RETURN_IF_ERROR(CreateKernel()); } - return Status::OK(); } @@ -102,6 +103,8 @@ Status Conv::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_convolution2d_nhwc_qu8; } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w; + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_convolution2d_nhwc_f16; } auto status = reshape_fn(op0_.get(), N, H, W, @@ -112,12 +115,14 @@ Status Conv::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_), "returned ", status); } - workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); if (conv_type_ == OpComputeType::op_compute_type_fp32) { status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qs8) { status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); @@ -149,6 +154,15 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, kXnnpackEx ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Conv); +#ifdef XNNPACK_FP16_SUPPORTED +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, MLFloat16, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", 
DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); +#endif ONNX_OPERATOR_TYPED_KERNEL_EX( QLinearConv, diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 2aafc9be7ffd0..d6b9a541fbec7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -81,6 +81,28 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, foutput_min, foutput_max, flags, code_cache, weights_cache, &p); + } else if (conv_type == OpComputeType::op_compute_type_fp16) { + const auto* B_data = Bias ? Bias->Data() : nullptr; + // 65504 is the max value of float16 + // https://en.wikipedia.org/wiki/Half-precision_floating-point_format + const float output_min = clip_min_max ? clip_min_max->first : -65504.0f; + const float output_max = clip_min_max ? clip_min_max->second : 65504.0f; + auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 + : xnn_create_convolution2d_nhwc_f16; + status = create_func( + input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, + kernel_height, kernel_width, + subsampling_height, subsampling_width, + dilation_height, dilation_width, + group_count, + group_input_channels, + group_output_channels, + C, M, // input channel stride, output channel stride + Weight.Data(), B_data, // kernel, bias + output_min, output_max, + flags, + code_cache, weights_cache, + &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; const int8_t output_zero_point = quant_param[2].second; @@ -236,6 +258,13 @@ OpComputeType GetConvCompType( return op_compute_type_qu8; } break; + case TensorTypeFp16: + if (input_datatype == TensorTypeFp16 && + (!bias_datatype || *bias_datatype == TensorTypeInt32) && + output_datatype == TensorTypeFp16) { + return op_compute_type_fp16; + } + break; default: break; } @@ -326,10 +355,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& // we only support float and u8 currently const auto* x_type = x_arg.TypeAsProto(); - if (x_type == nullptr || - (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) { break; } // require C, H, W to be known so we can construct the xnnpack kernel prior to Compute @@ -420,9 +446,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { weight_index = 3; conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype); + } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + conv_type_ = OpComputeType::op_compute_type_fp16; } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto())); - ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype); + ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype); } ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight), @@ -491,7 +519,6 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool 
is_transpose) output_shape_.push_back(M_); } - // have to delay creating the xnnpack kernel until after the weights are pre-packed. } diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 01c8119fea79d..b399311cd8568 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -7,6 +7,7 @@ #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" #include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "core/framework/tensorprotoutils.h" namespace onnxruntime { @@ -18,8 +19,10 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || - (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W + const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || + conv_type_ == OpComputeType::op_compute_type_fp16); + if ((conv_type_is_float && input_idx == 1) || + (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -129,6 +132,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8; } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8; + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_deconvolution2d_nhwc_f16; } status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1, @@ -146,6 +151,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data(), Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_deconvolution2d_nhwc_f16(op0_.get(), X.Data(), Y->MutableData()); } if (status != xnn_status_success) { @@ -161,16 +168,16 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint( - "T", DataTypeImpl::GetTensorType()), - ConvTranspose); - ONNX_OPERATOR_VERSIONED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint( "T", DataTypeImpl::GetTensorType()), ConvTranspose); +ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); + ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpackExecutionProvider, KernelDefBuilder() .TypeConstraint( @@ -179,5 +186,18 @@ ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpack DataTypeImpl::GetTensorType()}), ConvTranspose); +#ifdef XNNPACK_FP16_SUPPORTED +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, MLFloat16, + kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", 
DataTypeImpl::GetTensorType()), + ConvTranspose); + +ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); + +#endif } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc index 749e004094ba1..0d1d9ee1793e3 100644 --- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc @@ -200,8 +200,8 @@ MaxPool::MaxPool(const OpKernelInfo& info) output_min, output_max, flags, &p); } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { maxpool_type_ = OpComputeType::op_compute_type_fp16; - const float output_min = -65504.0; - const float output_max = 65504.0; + const float output_min = clip_min_max_ ? clip_min_max_->first : -65504.0f; + const float output_max = clip_min_max_ ? clip_min_max_->second : 65504.0f; status = xnn_create_max_pooling2d_nhwc_f16(input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, pooling_height, pooling_width, diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index df7df0b4376ce..4515a31eb0da0 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -44,12 +44,12 @@ KernelCreateInfo BuildKernelCreateInfo() { ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)> #ifdef XNNPACK_FP16_SUPPORTED -#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \ - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \ +#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \ + class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, \ startver, endver, MLFloat16, name) -#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \ - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \ +#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \ + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, startver, \ MLFloat16, name) #else #define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) @@ -64,9 +64,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); +CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, + ConvTranspose);
+CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv); @@ -161,6 +166,12 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate), #ifdef XNNPACK_FP16_SUPPORTED + KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, Conv, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_TYPED(11, MLFloat16, Conv, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, ConvTranspose, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_TYPED(11, MLFloat16, ConvTranspose, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_init.h b/onnxruntime/core/providers/xnnpack/xnnpack_init.h index ed824939a40da..89e92d0b99b13 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_init.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_init.h @@ -46,15 +46,20 @@ namespace xnnpack { #define XNN_ALLOCATION_ALIGNMENT 16 #endif +#if defined(__arm__) || defined(_M_ARM) +#define XNN_ARCH_ARM 1 +#else +#define XNN_ARCH_ARM 0 +#endif + #if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC) #define XNN_ARCH_ARM64 1 #else #define XNN_ARCH_ARM64 0 #endif -// fp16 support can vary on a kernel by kernel basis. Keep it simple and limit to arm64 for now. -// e.g. XNNPACK maxpool has x64 and arm64 fp16 kernels. -#if XNN_ARCH_ARM64 +// referenced from xnn_is_f16_compatible_config in XNNPACK/src/xnnpack/hardware-config.h +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 || ((XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE) #define XNNPACK_FP16_SUPPORTED #endif diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index ce1ac7591ec34..5cd43d21aacad 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -2,8 +2,10 @@ // Licensed under the MIT License. #include "core/mlas/inc/mlas.h" +#include "core/providers/xnnpack/xnnpack_init.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) +// XNNPACK_FP16_SUPPORTED scope is too big, so add USE_XNNPACK to avoid the FP16 tests enabled for other EPs +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || (defined(USE_XNNPACK) && defined(XNNPACK_FP16_SUPPORTED)) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 29525f89ef544..0ce87fb65898b 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/xnnpack/xnnpack_init.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -28,6 +29,8 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const vector>& input_shapes, const std::vector& expected_output, const vector& expected_output_shape, + float rel_error = 0.0, + float abs_error = 0.0, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", @@ -64,7 +67,7 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, for (size_t i = 0; i < inputs.size(); i++) { test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); + test.AddOutput("Y", expected_output_shape, expected_output, false, rel_error, abs_error); test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } @@ -78,12 +81,16 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", const std::unordered_set& excluded_provider_types = - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, + float rel_error = 0.0, + float abs_error = 0.0) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, true, expect_result, err_str, extra_exclude_openvino_for_initializer_filter); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, false, expect_result, err_str, excluded_provider_types); } @@ -245,8 +252,22 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; +#ifdef XNNPACK_FP16_SUPPORTED + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape, + OpTester::ExpectResult::kExpectSuccess, "", // defalut value + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, // default value + 0.5, 0.5); + } else { + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); + } + +#else TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); +#endif } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 4cef66b26f0f4..c90f97ac529d4 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -761,6 +761,7 @@ def convert_arg_line_to_args(self, arg_line): ) parser.add_argument("--use_xnnpack", action="store_true", help="Enable xnnpack EP.") + parser.add_argument("--use_avx512", action="store_true", 
help="Enable AVX512 instructions") parser.add_argument("--use_azure", action="store_true", help="Enable azure EP.") parser.add_argument("--use_cache", action="store_true", help="Use compiler cache in CI")