Skip to content

Commit

Permalink
Fixed scratch size calculation for conv for HiFi targets for scenario…
Browse files Browse the repository at this point in the history
…s when input, filter and output heights are 1. (#3004)

Fixed scratch size calculation for conv for HiFi targets for scenarios when input, filter and output heights are 1.
BUG=conv_hifi_scratch_size
  • Loading branch information
cad-audio authored Dec 13, 2024
1 parent aa3f6f3 commit 3bd1dd3
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,60 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
}

const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_height = data->reference_op_data.padding.height;
const int pad_width = data->reference_op_data.padding.width;

int required_scratch = 0;
// TODO(b/277112516): Dilation is currently not supported on HiFi 4 NN Library
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
if (input->type == kTfLiteInt8) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_ASYM8S);
if (input_height == 1 && filter_height == 1 && output_height == 1) {
int inp_h, filt_h, filt_w, str_h, pad_h, out_h;
inp_h = input_width;
filt_h = filter_width;
filt_w = filter_height;
str_h = stride_width;
pad_h = pad_width;
out_h = output_width;
required_scratch = xa_nn_conv2d_std_getsize(
inp_h, input_depth, filt_h, filt_w, str_h, pad_h, out_h,
output_channels, PREC_ASYM8S);
} else {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width,
stride_height, pad_height, output_height, output_channels,
PREC_ASYM8S);
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
if (input->type == kTfLiteInt16) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_SYM16S);
if (input_height == 1 && filter_height == 1 && output_height == 1) {
int inp_h, filt_h, filt_w, str_h, pad_h, out_h;
inp_h = input_width;
filt_h = filter_width;
filt_w = filter_height;
str_h = stride_width;
pad_h = pad_width;
out_h = output_width;
required_scratch = xa_nn_conv2d_std_getsize(
inp_h, input_depth, filt_h, filt_w, str_h, pad_h, out_h,
output_channels, PREC_SYM16S);
} else {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width,
stride_height, pad_height, output_height, output_channels,
PREC_SYM16S);
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
}
Expand Down

0 comments on commit 3bd1dd3

Please sign in to comment.