From 3bd1dd3083323914c6b0cb50c659a891c1af55b6 Mon Sep 17 00:00:00 2001 From: cad-audio <86048415+cad-audio@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:46:57 -0800 Subject: [PATCH] Fixed scratch size calculation for conv for HiFi targets for scenarios when input, filter and output heights are 1. (#3004) Fixed scratch size calculation for conv for HiFi targets for scenarios when input, filter and output heights are 1. BUG=conv_hifi_scratch_size --- .../lite/micro/kernels/xtensa/conv_hifi.cc | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc index 1d2d7ec253e..f17809484d6 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc @@ -77,28 +77,60 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) { } const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); const int filter_height = filter_shape.Dims(1); const int filter_width = filter_shape.Dims(2); const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); const int output_channels = output_shape.Dims(3); const int stride_height = params->stride_height; + const int stride_width = params->stride_width; const int pad_height = data->reference_op_data.padding.height; + const int pad_width = data->reference_op_data.padding.width; int required_scratch = 0; // TODO(b/277112516): Dilation is currently not supported on HiFi 4 NN Library if ((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { if (input->type == kTfLiteInt8) { - required_scratch = xa_nn_conv2d_std_getsize( - input_height, input_depth, filter_height, filter_width, stride_height, - pad_height, output_height, output_channels, PREC_ASYM8S); + if (input_height == 1 && filter_height == 1 && output_height == 1) { + int inp_h, filt_h, filt_w, str_h, pad_h, out_h; + inp_h = input_width; + filt_h = filter_width; + filt_w = filter_height; + str_h = stride_width; + pad_h = pad_width; + out_h = output_width; + required_scratch = xa_nn_conv2d_std_getsize( + inp_h, input_depth, filt_h, filt_w, str_h, pad_h, out_h, + output_channels, PREC_ASYM8S); + } else { + required_scratch = xa_nn_conv2d_std_getsize( + input_height, input_depth, filter_height, filter_width, + stride_height, pad_height, output_height, output_channels, + PREC_ASYM8S); + } TF_LITE_ENSURE(context, required_scratch > 0); } if (input->type == kTfLiteInt16) { - required_scratch = xa_nn_conv2d_std_getsize( - input_height, input_depth, filter_height, filter_width, stride_height, - pad_height, output_height, output_channels, PREC_SYM16S); + if (input_height == 1 && filter_height == 1 && output_height == 1) { + int inp_h, filt_h, filt_w, str_h, pad_h, out_h; + inp_h = input_width; + filt_h = filter_width; + filt_w = filter_height; + str_h = stride_width; + pad_h = pad_width; + out_h = output_width; + required_scratch = xa_nn_conv2d_std_getsize( + inp_h, input_depth, filt_h, filt_w, str_h, pad_h, out_h, + output_channels, PREC_SYM16S); + } else { + required_scratch = xa_nn_conv2d_std_getsize( + input_height, input_depth, filter_height, filter_width, + stride_height, pad_height, output_height, output_channels, + PREC_SYM16S); + } TF_LITE_ENSURE(context, required_scratch > 0); } }