From 5b14dd190ec493fd4dc60ad30ddb753e23b980a0 Mon Sep 17 00:00:00 2001 From: cad-audio Date: Mon, 9 Dec 2024 22:05:19 -0800 Subject: [PATCH 1/2] Fixed scratch size calculation for conv for HiFi targets for scenarios when input, filter and output heights are 1. --- .../lite/micro/kernels/xtensa/conv_hifi.cc | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc index 1d2d7ec253e..e43ef03be44 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc @@ -77,28 +77,64 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) { } const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); const int filter_height = filter_shape.Dims(1); const int filter_width = filter_shape.Dims(2); const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); const int output_channels = output_shape.Dims(3); const int stride_height = params->stride_height; + const int stride_width = params->stride_width; const int pad_height = data->reference_op_data.padding.height; + const int pad_width = data->reference_op_data.padding.width; int required_scratch = 0; // TODO(b/277112516): Dilation is currently not supported on HiFi 4 NN Library if ((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { if (input->type == kTfLiteInt8) { - required_scratch = xa_nn_conv2d_std_getsize( - input_height, input_depth, filter_height, filter_width, stride_height, - pad_height, output_height, output_channels, PREC_ASYM8S); + if (input_height == 1 && filter_height == 1 && output_height == 1) + { + int inp_h, filt_h, filt_w, str_h, pad_h, out_h; + inp_h = input_width; + filt_h = filter_width; + filt_w = filter_height; + str_h = stride_width; + pad_h = pad_width; + out_h = output_width; + required_scratch = xa_nn_conv2d_std_getsize( + inp_h, input_depth, filt_h, filt_w, str_h, + pad_h, out_h, output_channels, PREC_ASYM8S); + } + else + { + required_scratch = xa_nn_conv2d_std_getsize( + input_height, input_depth, filter_height, filter_width, stride_height, + pad_height, output_height, output_channels, PREC_ASYM8S); + } TF_LITE_ENSURE(context, required_scratch > 0); } if (input->type == kTfLiteInt16) { - required_scratch = xa_nn_conv2d_std_getsize( - input_height, input_depth, filter_height, filter_width, stride_height, - pad_height, output_height, output_channels, PREC_SYM16S); + if (input_height == 1 && filter_height == 1 && output_height == 1) + { + int inp_h, filt_h, filt_w, str_h, pad_h, out_h; + inp_h = input_width; + filt_h = filter_width; + filt_w = filter_height; + str_h = stride_width; + pad_h = pad_width; + out_h = output_width; + required_scratch = xa_nn_conv2d_std_getsize( + inp_h, input_depth, filt_h, filt_w, str_h, + pad_h, out_h, output_channels, PREC_SYM16S); + } + else + { + required_scratch = xa_nn_conv2d_std_getsize( + input_height, input_depth, filter_height, filter_width, stride_height, + pad_height, output_height, output_channels, PREC_SYM16S); + } TF_LITE_ENSURE(context, required_scratch > 0); } } From 0d3531308c74596d3064731d805c3209a695c33e Mon Sep 17 00:00:00 2001 From: cad-audio Date: Fri, 13 Dec 2024 10:23:36 -0800 Subject: [PATCH 2/2] Corrected formatting. --- .../lite/micro/kernels/xtensa/conv_hifi.cc | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc index e43ef03be44..f17809484d6 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc @@ -94,8 +94,7 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) { if ((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { if (input->type == kTfLiteInt8) { - if (input_height == 1 && filter_height == 1 && output_height == 1) - { + if (input_height == 1 && filter_height == 1 && output_height == 1) { int inp_h, filt_h, filt_w, str_h, pad_h, out_h; inp_h = input_width; filt_h = filter_width; @@ -104,20 +103,18 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) { pad_h = pad_width; out_h = output_width; required_scratch = xa_nn_conv2d_std_getsize( - inp_h, input_depth, filt_h, filt_w, str_h, - pad_h, out_h, output_channels, PREC_ASYM8S); - } - else - { + inp_h, input_depth, filt_h, filt_w, str_h, pad_h, out_h, + output_channels, PREC_ASYM8S); + } else { required_scratch = xa_nn_conv2d_std_getsize( - input_height, input_depth, filter_height, filter_width, stride_height, - pad_height, output_height, output_channels, PREC_ASYM8S); + input_height, input_depth, filter_height, filter_width, + stride_height, pad_height, output_height, output_channels, + PREC_ASYM8S); } TF_LITE_ENSURE(context, required_scratch > 0); } if (input->type == kTfLiteInt16) { - if (input_height == 1 && filter_height == 1 && output_height == 1) - { + if (input_height == 1 && filter_height == 1 && output_height == 1) { int inp_h, filt_h, filt_w, str_h, pad_h, out_h; inp_h = input_width; filt_h = filter_width; @@ -126,14 +123,13 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) { pad_h = pad_width; out_h = output_width; required_scratch = xa_nn_conv2d_std_getsize( - inp_h, input_depth, filt_h, filt_w, str_h, - pad_h, out_h, output_channels, PREC_SYM16S); - } - else - { + inp_h, input_depth, filt_h, filt_w, str_h, pad_h, out_h, + output_channels, PREC_SYM16S); + } else { required_scratch = xa_nn_conv2d_std_getsize( - input_height, input_depth, filter_height, filter_width, stride_height, - pad_height, output_height, output_channels, PREC_SYM16S); + input_height, input_depth, filter_height, filter_width, + stride_height, pad_height, output_height, output_channels, + PREC_SYM16S); } TF_LITE_ENSURE(context, required_scratch > 0); }