Add conv fp16 kernel in xnnpack EP #22301

Merged: 25 commits, Oct 10, 2024
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/xnnpack/detail/utils.cc
@@ -4,14 +4,17 @@
#include "utils.h"
#include <unordered_map>
#include <vector>
#include <set>

#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/node_unit.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/node_attr_utils.h"
#include "core/optimizer/initializer.h"
#include "core/providers/xnnpack/xnnpack_init.h"

#include "onnx/defs/attr_proto_util.h"

@@ -111,6 +114,17 @@
auto_pad == AutoPadType::SAME_UPPER;
}

bool IsComputeTypeSupported(int32_t tp) {
#ifdef XNNPACK_FP16_SUPPORTED
  std::set<ONNX_NAMESPACE::TensorProto_DataType> SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
                                                                         ONNX_NAMESPACE::TensorProto_DataType_UINT8,
                                                                         ONNX_NAMESPACE::TensorProto_DataType_INT8,
                                                                         ONNX_NAMESPACE::TensorProto_DataType_FLOAT16};
#else
  std::set<ONNX_NAMESPACE::TensorProto_DataType> SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
                                                                         ONNX_NAMESPACE::TensorProto_DataType_UINT8,
                                                                         ONNX_NAMESPACE::TensorProto_DataType_INT8};
#endif
  ONNX_NAMESPACE::TensorProto_DataType compute_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(tp);
  return std::find(SupportedComputeType.begin(), SupportedComputeType.end(), compute_type) != SupportedComputeType.end();
}

typedef std::string ONNXOpType;

static const std::unordered_map<QuantizedOpType, ONNXOpType> qdq_to_onnx_type_map = {
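For context, a caller-side view of the new helper: FLOAT/UINT8/INT8 always pass, and FLOAT16 passes only in builds where XNNPACK_FP16_SUPPORTED is defined. The wrapper below is purely illustrative and not part of the PR; only IsComputeTypeSupported and its include path come from the change above.

```cpp
// Illustrative only: how the EP is expected to use IsComputeTypeSupported when
// deciding whether to claim a node. The wrapper name is hypothetical.
#include "core/providers/xnnpack/detail/utils.h"

namespace onnxruntime {
namespace xnnpack {

bool CanClaimConvInput(const ONNX_NAMESPACE::TypeProto& x_type) {
  // FLOAT/UINT8/INT8 always pass; FLOAT16 passes only in fp16-enabled builds.
  return IsComputeTypeSupported(x_type.tensor_type().elem_type());
}

}  // namespace xnnpack
}  // namespace onnxruntime
```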
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/xnnpack/detail/utils.h
@@ -77,6 +77,8 @@ struct XnnpackOperatorDeleter {

bool IsPaddingTypeSupported(AutoPadType auto_pad);

bool IsComputeTypeSupported(int32_t tp);

using XnnpackOperator = std::unique_ptr<struct xnn_operator, XnnpackOperatorDeleter>;

std::unique_ptr<IndexedSubGraph::MetaDef> FuseActivation(const NodeUnit& conv_unit, const NodeUnit& activation,
@@ -99,5 +101,6 @@ auto xnn_u8s8_quantize(float val, float scale, T zero_point) {
auto zp = static_cast<float>(zero_point);
return static_cast<T>(lrintf(fminf(fmaxf(val / scale + zp, typed_min), typed_max)));
}

} // namespace xnnpack
} // namespace onnxruntime
29 changes: 22 additions & 7 deletions onnxruntime/core/providers/xnnpack/nn/conv.cc
@@ -10,8 +10,8 @@
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"

namespace onnxruntime {
namespace xnnpack {
@@ -22,9 +22,10 @@
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
(conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W
// Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
conv_type_ == OpComputeType::op_compute_type_fp16);
if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) {
// InputTensors::IN_W, Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
auto orig_shape = tensor.Shape();
const auto rank = orig_shape.NumDimensions();

@@ -52,11 +53,9 @@
}

is_packed = true;

// we can create the kernel now
ORT_RETURN_IF_ERROR(CreateKernel());
}

return Status::OK();
}

@@ -102,8 +101,13 @@
reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_convolution2d_nhwc_f16;
}

  if (op0_.get() == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XNNPACK Conv operator (op0_) was not created before Compute");
  }
auto status = reshape_fn(op0_.get(), N, H, W,
&workspace_size, &workspace_alignment,
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
@@ -112,7 +116,6 @@
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
"returned ", status);
}

workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));

if (conv_type_ == OpComputeType::op_compute_type_fp32) {
Expand All @@ -127,6 +130,9 @@
} else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
status = xnn_setup_convolution2d_nhwc_qs8_qc8w(op0_.get(), workspace.get(), X.Data<int8_t>(),
Y->MutableData<int8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data<MLFloat16>(),
Y->MutableData<MLFloat16>());
}

if (status != xnn_status_success) {
@@ -149,6 +155,15 @@
ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv);
#ifdef XNNPACK_FP16_SUPPORTED
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, MLFloat16, kXnnpackExecutionProvider,
                                        KernelDefBuilder().TypeConstraint(
                                            "T", DataTypeImpl::GetTensorType<MLFloat16>()),
                                        Conv);

ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
Conv);
#endif

ONNX_OPERATOR_TYPED_KERNEL_EX(
QLinearConv,
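To exercise the fp16 Conv kernel registered above end to end, a session only needs the XNNPACK EP appended; in fp16-enabled builds the MLFloat16 registrations are picked up automatically. A minimal sketch follows; the model path, thread count, and the assumption that the build defines XNNPACK_FP16_SUPPORTED are all placeholders, not taken from the PR.

```cpp
// Minimal usage sketch (assumptions: "model_fp16.onnx" is a hypothetical model
// with fp16 Conv nodes, and the build defines XNNPACK_FP16_SUPPORTED).
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "xnnpack_fp16_conv");
  Ort::SessionOptions so;
  // The XNNPACK EP is selected by name; options are string key/value pairs.
  so.AppendExecutionProvider("XNNPACK", {{"intra_op_num_threads", "4"}});
  // Conv nodes whose input type passes IsComputeTypeSupported should now be
  // assigned to the XNNPACK EP instead of falling back to the CPU EP.
  Ort::Session session(env, "model_fp16.onnx", so);  // use an ORTCHAR_T path on Windows
  return 0;
}
```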
37 changes: 31 additions & 6 deletions onnxruntime/core/providers/xnnpack/nn/conv_base.cc
@@ -154,6 +154,26 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs,
flags,
code_cache, weights_cache,
&p);
} else if (conv_type == OpComputeType::op_compute_type_fp16) {
const auto* B_data = Bias ? Bias->Data<MLFloat16>() : nullptr;
const float output_min = -65504.0;
const float output_max = 65504.0;
auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16
: xnn_create_convolution2d_nhwc_f16;
status = create_func(
input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
kernel_height, kernel_width,
subsampling_height, subsampling_width,
dilation_height, dilation_width,
group_count,
group_input_channels,
group_output_channels,
C, M, // input channel stride, output channel stride
Weight.Data<MLFloat16>(), B_data, // kernel, bias
output_min, output_max,
flags,
code_cache, weights_cache,
&p);
}

if (status != xnn_status_success) {
@@ -236,6 +256,13 @@ OpComputeType GetConvCompType(
return op_compute_type_qu8;
}
break;
case TensorTypeFp16:
if (input_datatype == TensorTypeFp16 &&
(!bias_datatype || *bias_datatype == TensorTypeInt32) &&
output_datatype == TensorTypeFp16) {
return op_compute_type_fp16;
}
break;
default:
break;
}
@@ -326,10 +353,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&

// we support float, u8/s8 and, in fp16-enabled builds, fp16
const auto* x_type = x_arg.TypeAsProto();
if (x_type == nullptr ||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) {
break;
}
// require C, H, W to be known so we can construct the xnnpack kernel prior to Compute
@@ -420,9 +444,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
weight_index = 3;
conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
conv_type_ = OpComputeType::op_compute_type_fp16;
} else {
auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto()));
ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype);
ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype);
}

ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight),
Expand Down Expand Up @@ -491,7 +517,6 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)

output_shape_.push_back(M_);
}

// have to delay creating the xnnpack kernel until after the weights are pre-packed.
}

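The -65504/65504 bounds used in the fp16 branch above are the largest finite IEEE binary16 magnitude, (2 - 2^-10) * 2^15, so presumably they act as "no clamping" when no activation is fused. A quick standalone arithmetic check of that value (illustrative, not part of the PR):

```cpp
// Standalone check: 65504 is the largest finite IEEE binary16 value,
// (2 - 2^-10) * 2^15. Not part of the PR; just verifies the constant.
#include <cassert>
#include <cmath>

int main() {
  const float fp16_max = (2.0f - std::ldexp(1.0f, -10)) * std::ldexp(1.0f, 15);
  assert(fp16_max == 65504.0f);  // exact in float arithmetic
  return 0;
}
```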
33 changes: 26 additions & 7 deletions onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc
@@ -7,6 +7,7 @@
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/detail/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/framework/tensorprotoutils.h"

namespace onnxruntime {
@@ -18,8 +19,9 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
(conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W
const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
conv_type_ == OpComputeType::op_compute_type_fp16);
if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W
auto orig_shape = tensor.Shape();
const auto rank = orig_shape.NumDimensions();

@@ -129,6 +131,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8;
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8;
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_deconvolution2d_nhwc_f16;
}

status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1,
Expand All @@ -146,6 +150,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
} else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_deconvolution2d_nhwc_f16(op0_.get(), X.Data<MLFloat16>(), Y->MutableData<MLFloat16>());
}

if (status != xnn_status_success) {
Expand All @@ -161,16 +167,16 @@ Status ConvTranspose::Compute(OpKernelContext* context) const {
return Status::OK();
}

ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<float>()),
ConvTranspose);

ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpackExecutionProvider,
KernelDefBuilder()
.TypeConstraint(
Expand All @@ -179,5 +185,18 @@ ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpack
DataTypeImpl::GetTensorType<int8_t>()}),
ConvTranspose);

#ifdef XNNPACK_FP16_SUPPORTED
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, MLFloat16,
kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<MLFloat16>()),
ConvTranspose);

ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint(
"T", DataTypeImpl::GetTensorType<MLFloat16>()),
ConvTranspose);

#endif
} // namespace xnnpack
} // namespace onnxruntime
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
@@ -64,9 +64,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10,
ConvTranspose);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv);
@@ -161,6 +166,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),

#ifdef XNNPACK_FP16_SUPPORTED
KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, Conv, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(11, MLFloat16, Conv, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, ConvTranspose, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(11, MLFloat16, ConvTranspose, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain),
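The CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16 / CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16 declarations above can sit outside the #ifdef because such guard macros typically expand to nothing when fp16 is disabled, while the KERNEL_CREATE_INFO entries stay inside the guard. A plausible shape for one of these macros is sketched below; the actual definitions live elsewhere in the PR and may differ.

```cpp
// Hypothetical sketch of the guard macro's shape; the real definition in this
// PR may differ in name, placement, or arguments.
#ifdef XNNPACK_FP16_SUPPORTED
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, ver, name) \
  class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, MLFloat16, name)
#else
// Expands to nothing; the trailing semicolon at the call site is harmless.
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, ver, name)
#endif
```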
4 changes: 4 additions & 0 deletions onnxruntime/test/providers/checkers.cc
@@ -80,6 +80,10 @@ struct DefaultTolerance<MLFloat16> {
if (provider_type == kDmlExecutionProvider) {
return 0.005f;
}
if (provider_type == kXnnpackExecutionProvider) {
// To allow tests like ConvTranspose_2D_Bias_1 to pass
return 0.05f;
}
return absolute;
}
};
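The relaxed 0.05 absolute tolerance for the XNNPACK provider is in the right ballpark for fp16 accumulation error. A rough, assumption-laden budget (the reduction length and operand magnitudes below are made up, not taken from the tests):

```cpp
// Rough fp16 error-budget illustration (assumptions: ~100-term reductions with
// operands of order 1; numbers are not taken from the actual test models).
#include <cmath>
#include <cstdio>

int main() {
  const double u = std::ldexp(1.0, -11);  // fp16 unit roundoff, ~4.88e-4
  const int n = 100;                      // assumed dot-product length
  const double v = 1.0;                   // assumed operand magnitude
  std::printf("worst-case accumulated error ~ %.3f\n", n * u * v);  // ~0.049
  return 0;
}
```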
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -3,7 +3,7 @@

#include "core/mlas/inc/mlas.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED)

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"