From 54254a72b41cc043c13a8059e3555276391b3a3e Mon Sep 17 00:00:00 2001 From: pramods-cad Date: Tue, 30 Apr 2024 01:25:32 -0700 Subject: [PATCH] Enabling Int8 input/filter and Int16 output FullyConnected layer for xtensa (with ref code). --- .../micro/kernels/xtensa/fully_connected.cc | 27 +++++++++++++++++-- .../xtensa/fully_connected_common_xtensa.cc | 3 ++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc index 188b6ef505a..2ed000c3fc8 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc @@ -69,8 +69,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: { switch (filter->type) { case kTfLiteInt8: { - return XtensaEvalFullyConnectedQuantizedInt8( - context, node, data, input, filter, bias, output); + switch (output->type) { + case kTfLiteInt8: { + return XtensaEvalFullyConnectedQuantizedInt8( + context, node, data, input, filter, bias, output); + break; + } + case kTfLiteInt16: { + tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: { + MicroPrintf("Output type %s (%d) not supported.", + TfLiteTypeGetName(output->type), input->type); + return kTfLiteError; + } + } } case kTfLiteInt4: { return XtensaEvalFullyConnectedQuantizedInt8( diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc index 951205db986..5502e960415 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc @@ -104,7 +104,8 @@ TfLiteStatus XtensaPrepareFullyConnected(TfLiteContext* context, TfLiteTensor* output = micro_context->AllocateTempOutputTensor( node, kFullyConnectedOutputTensor); TF_LITE_ENSURE(context, output != nullptr); - TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); + // Disabling this check to support Int8 input, filter and Int16 output variant + // TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); if (filter->type == kTfLiteInt4) { #if defined(HIFI5) && defined(NNLIB_HIFI5)