diff --git a/tensorflow/lite/micro/kernels/conv_test.h b/tensorflow/lite/micro/kernels/conv_test.h
index 5cb71bbf7b6..642f4c76d7a 100644
--- a/tensorflow/lite/micro/kernels/conv_test.h
+++ b/tensorflow/lite/micro/kernels/conv_test.h
@@ -121,31 +121,7 @@ TfLiteStatus TestConvFloat(
     const TestCompressionInfo* filter_comp_info = nullptr,
     const TestCompressionInfo* bias_comp_info = nullptr
 #endif  // USE_TFLM_COMPRESSION
-) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
-  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-  const int output_dims_count = ElementCount(*output_dims);
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateTensor(input_data, input_dims),
-      CreateTensor(filter_data, filter_dims),
-      CreateTensor(bias_data, bias_dims),
-      CreateTensor(output_data, output_dims),
-  };
-
-  return ValidateConvGoldens(tensors, tensors_size, expected_output_data,
-                             output_dims_count, conv_params, registration,
-                             output_data
-#ifdef USE_TFLM_COMPRESSION
-                             ,
-                             1e-5f, filter_comp_info, bias_comp_info
-#endif  // USE_TFLM_COMPRESSION
-  );
-}
+);
 
 TfLiteStatus TestConvQuantizedPerChannel(
     int* input_dims_data, const float* input_data, int8_t* input_quantized,
diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc
index 7b6f71a8fc3..3825e05373c 100644
--- a/tensorflow/lite/micro/kernels/conv_test_common.cc
+++ b/tensorflow/lite/micro/kernels/conv_test_common.cc
@@ -18,6 +18,43 @@ limitations under the License.
 namespace tflite {
 namespace testing {
 
+TfLiteStatus TestConvFloat(
+    int* input_dims_data, const float* input_data, int* filter_dims_data,
+    const float* filter_data, int* bias_dims_data, const float* bias_data,
+    int* output_dims_data, const float* expected_output_data,
+    TfLiteConvParams* conv_params, TFLMRegistration registration,
+    float* output_data
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo* filter_comp_info,
+    const TestCompressionInfo* bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(input_data, input_dims),
+      CreateTensor(filter_data, filter_dims),
+      CreateTensor(bias_data, bias_dims),
+      CreateTensor(output_data, output_dims),
+  };
+
+  return ValidateConvGoldens(tensors, tensors_size, expected_output_data,
+                             output_dims_count, conv_params, registration,
+                             output_data
+#ifdef USE_TFLM_COMPRESSION
+                             ,
+                             1e-5f, filter_comp_info, bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+  );
+}
+
 template <typename T>
 TfLiteStatus TestConvQuantizedPerChannel(
     int* input_dims_data, const float* input_data, T* input_quantized,
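For context, a minimal sketch of how a kernel test might call the relocated TestConvFloat helper, now that only its declaration remains in conv_test.h and the definition lives in conv_test_common.cc. The test name, shapes, data, and expected values below are invented for illustration; only the TestConvFloat signature, Register_CONV_2D, and the TFLM test macros come from the tree, and the example assumes the non-compression build (no USE_TFLM_COMPRESSION).

// Illustrative sketch only; not part of this change.
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/conv_test.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(HypotheticalSimpleConvFloat) {
  // Dims arrays use the IntArrayFromInts layout: {rank, d0, d1, ...}.
  int input_dims[] = {4, 1, 2, 2, 1};   // 1x2x2x1 input
  int filter_dims[] = {4, 1, 1, 1, 1};  // one 1x1x1 filter
  int bias_dims[] = {1, 1};
  int output_dims[] = {4, 1, 2, 2, 1};

  const float input_data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float filter_data[] = {2.0f};
  const float bias_data[] = {0.5f};
  // 1x1 convolution: each output element = input * 2.0 + 0.5.
  const float expected_output[] = {2.5f, 4.5f, 6.5f, 8.5f};
  float output_data[4];

  TfLiteConvParams conv_params = {};
  conv_params.padding = kTfLitePaddingValid;
  conv_params.stride_width = 1;
  conv_params.stride_height = 1;
  conv_params.dilation_width_factor = 1;
  conv_params.dilation_height_factor = 1;
  conv_params.activation = kTfLiteActNone;

  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk,
      tflite::testing::TestConvFloat(input_dims, input_data, filter_dims,
                                     filter_data, bias_dims, bias_data,
                                     output_dims, expected_output, &conv_params,
                                     tflite::Register_CONV_2D(), output_data));
}

TF_LITE_MICRO_TESTS_END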