diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc
index 0f0b827974f66..749e004094ba1 100644
--- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc
+++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc
@@ -5,6 +5,7 @@
 
 #include "core/graph/graph.h"
 #include "core/providers/utils.h"
+#include "core/providers/xnnpack/xnnpack_init.h"
 #include "core/framework/tensorprotoutils.h"
 
 // to sanity check output shape
@@ -54,6 +55,10 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit,
       // input of maxpool could be fp16/fp32/fp64,i8/u8 according to ONNX
       if (x_type == nullptr ||
           (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+// because pool_fp16_op_test can be enabled by other preprocessor defines, e.g. COREML_ENABLE_MLPROGRAM
+#ifdef XNNPACK_FP16_SUPPORTED
+           x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
+#endif
           x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
           x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
         break;
@@ -193,9 +198,19 @@ MaxPool::MaxPool(const OpKernelInfo& info)
                                               stride_height, stride_width,
                                               dilation_height, dilation_width,
                                               output_min, output_max, flags, &p);
+  } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+    maxpool_type_ = OpComputeType::op_compute_type_fp16;
+    const float output_min = -65504.0;
+    const float output_max = 65504.0;
+    status = xnn_create_max_pooling2d_nhwc_f16(input_padding_top, input_padding_right,
+                                               input_padding_bottom, input_padding_left,
+                                               pooling_height, pooling_width,
+                                               stride_height, stride_width,
+                                               dilation_height, dilation_width,
+                                               output_min, output_max, flags, &p);
   } else {
     auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto()));
-    ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8, but got ", stype);
+    ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8|FLOAT16, but got ", stype);
   }
   ORT_ENFORCE(status == xnn_status_success, "xnn_create_max_pooling2d_nhwc_", OpTypeToString(maxpool_type_),
               "failed. Status:", status);
@@ -225,10 +240,12 @@ Status MaxPool::Compute(OpKernelContext* context) const {
   pthreadpool_t threadpool = GetThreadPool();
 
   auto reshape_fn = xnn_reshape_max_pooling2d_nhwc_f32;
-  if (maxpool_type_ == OpComputeType::op_compute_type_qu8)
+  if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
     reshape_fn = xnn_reshape_max_pooling2d_nhwc_u8;
-  else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
+  } else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
     reshape_fn = xnn_reshape_max_pooling2d_nhwc_s8;
+  } else if (maxpool_type_ == OpComputeType::op_compute_type_fp16) {
+    reshape_fn = xnn_reshape_max_pooling2d_nhwc_f16;
   }
 
   auto status = reshape_fn(op0_.get(), N, H, W,
@@ -244,8 +261,10 @@ Status MaxPool::Compute(OpKernelContext* context) const {
     status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
   } else if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
     status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
-  } else {
+  } else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
     status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
+  } else if (maxpool_type_ == OpComputeType::op_compute_type_fp16) {
+    status = xnn_setup_max_pooling2d_nhwc_f16(op0_.get(), X.Data<MLFloat16>(), Y->MutableData<MLFloat16>());
   }
 
   if (status != xnn_status_success) {
@@ -285,5 +304,24 @@ ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionPro
                                               DataTypeImpl::GetTensorType<uint8_t>(),
                                               DataTypeImpl::GetTensorType<int8_t>()}),
                         MaxPool);
+
+#ifdef XNNPACK_FP16_SUPPORTED
+ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, MLFloat16, kXnnpackExecutionProvider,
+                                        KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+                                        MaxPool);
+
+ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, MLFloat16, kXnnpackExecutionProvider,
+                                        KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+                                        MaxPool);
+
+ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, MLFloat16, kXnnpackExecutionProvider,
+                                        KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+                                        MaxPool);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, MLFloat16, kXnnpackExecutionProvider,
+                              KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+                              MaxPool);
+#endif
+
 }  // namespace xnnpack
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
index 12e567e7080b3..df7df0b4376ce 100644
--- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
+++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
@@ -31,6 +31,10 @@ KernelCreateInfo BuildKernelCreateInfo() {
   BuildKernelCreateInfo<                                                            \
       ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)>
 
+#define KERNEL_CREATE_INFO_VERSIONED_TYPED(Start, End, Type, Op, Domain) \
+  BuildKernelCreateInfo<                                                  \
+      ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Type, Op)>
+
 #define KERNEL_CREATE_INFO(Start, Op, Domain) \
   BuildKernelCreateInfo<                      \
       ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)>
@@ -39,6 +43,19 @@ KernelCreateInfo BuildKernelCreateInfo() {
   BuildKernelCreateInfo<                      \
       ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>
 
+#ifdef XNNPACK_FP16_SUPPORTED
+#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
+  class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \
+                                                         startver, endver, MLFloat16, name)
+
+#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
+  class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \
+                                              MLFloat16, name)
+#else
+#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name)
+#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name)
+#endif
+
 // Layout sensitive operators in NHWC domain
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
@@ -68,6 +85,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSIn
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
+CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
+CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
+CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
+CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
 
 // ONNX operators
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm);
@@ -138,6 +159,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain),
 
       KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),
+
+#ifdef XNNPACK_FP16_SUPPORTED
+      KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain),
+      KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain),
+      KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain),
+      KERNEL_CREATE_INFO_TYPED(12, MLFloat16, MaxPool, kMSInternalNHWCDomain),
+#endif
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_init.h b/onnxruntime/core/providers/xnnpack/xnnpack_init.h
index a1e64bf6046b2..ed824939a40da 100644
--- a/onnxruntime/core/providers/xnnpack/xnnpack_init.h
+++ b/onnxruntime/core/providers/xnnpack/xnnpack_init.h
@@ -46,6 +46,18 @@ namespace xnnpack {
 #define XNN_ALLOCATION_ALIGNMENT 16
 #endif
 
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
+#define XNN_ARCH_ARM64 1
+#else
+#define XNN_ARCH_ARM64 0
+#endif
+
+// fp16 support can vary on a kernel by kernel basis. Keep it simple and limit to arm64 for now.
+// e.g. XNNPACK maxpool has x64 and arm64 fp16 kernels.
+#if XNN_ARCH_ARM64
+#define XNNPACK_FP16_SUPPORTED
+#endif
+
 std::pair GetStoredAllocator();
 
 }  // namespace xnnpack
diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
index 7d736d41e804b..46eb1180f4e7e 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
@@ -3,7 +3,7 @@
 
 #include "core/mlas/inc/mlas.h"
 
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED)
 #include "core/providers/cpu/nn/pool.h"
 #include "gtest/gtest.h"
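
For reference, a minimal sketch (not part of the patch) of the style of fp16 MaxPool test that the XNNPACK_FP16_SUPPORTED guard enables in pool_fp16_op_test.cc. It assumes the usual onnxruntime::test::OpTester helper and an MLFloat16(float) constructor; the test name and the to_fp16 helper are illustrative only.

TEST(PoolFp16Test, MaxPool2D_Fp16Sketch) {
  // 1x1x4x4 input, 2x2 window, stride 2 -> 1x1x2x2 output.
  OpTester test("MaxPool", 12);
  test.AddAttribute("kernel_shape", std::vector<int64_t>{2, 2});
  test.AddAttribute("strides", std::vector<int64_t>{2, 2});

  // Convert float reference data to fp16 for the typed input/output.
  auto to_fp16 = [](const std::vector<float>& v) {
    std::vector<MLFloat16> out;
    out.reserve(v.size());
    for (float f : v) out.push_back(MLFloat16(f));
    return out;
  };

  std::vector<float> x = {1, 2, 3, 4,
                          5, 6, 7, 8,
                          9, 10, 11, 12,
                          13, 14, 15, 16};
  std::vector<float> y = {6, 8,
                          14, 16};  // max of each 2x2 block

  test.AddInput<MLFloat16>("X", {1, 1, 4, 4}, to_fp16(x));
  test.AddOutput<MLFloat16>("Y", {1, 1, 2, 2}, to_fp16(y));
  test.Run();
}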