Skip to content

Commit

Permalink
Fixed scratch size calculation for conv for HiFi targets for scenarios when input, filter and output heights are 1.
Browse files
  • Loading branch information
cad-audio committed Dec 10, 2024
1 parent 4a8bb6b commit 5b14dd1
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,64 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
}

const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_height = data->reference_op_data.padding.height;
const int pad_width = data->reference_op_data.padding.width;

int required_scratch = 0;
// TODO(b/277112516): Dilation is currently not supported on HiFi 4 NN Library
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
if (input->type == kTfLiteInt8) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_ASYM8S);
if (input_height == 1 && filter_height == 1 && output_height == 1)
{
int inp_h, filt_h, filt_w, str_h, pad_h, out_h;
inp_h = input_width;
filt_h = filter_width;
filt_w = filter_height;
str_h = stride_width;
pad_h = pad_width;
out_h = output_width;
required_scratch = xa_nn_conv2d_std_getsize(
inp_h, input_depth, filt_h, filt_w, str_h,
pad_h, out_h, output_channels, PREC_ASYM8S);
}
else
{
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_ASYM8S);
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
if (input->type == kTfLiteInt16) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_SYM16S);
if (input_height == 1 && filter_height == 1 && output_height == 1)
{
int inp_h, filt_h, filt_w, str_h, pad_h, out_h;
inp_h = input_width;
filt_h = filter_width;
filt_w = filter_height;
str_h = stride_width;
pad_h = pad_width;
out_h = output_width;
required_scratch = xa_nn_conv2d_std_getsize(
inp_h, input_depth, filt_h, filt_w, str_h,
pad_h, out_h, output_channels, PREC_SYM16S);
}
else
{
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_SYM16S);
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
}
Expand Down

0 comments on commit 5b14dd1

Please sign in to comment.