Skip to content

Commit

Permalink
initial refactor of compression unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Dec 6, 2024
1 parent 2788d32 commit b9c62b9
Show file tree
Hide file tree
Showing 6 changed files with 490 additions and 384 deletions.
148 changes: 79 additions & 69 deletions tensorflow/lite/micro/kernels/conv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,29 +228,31 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelCompressed) {
int8_t golden_quantized[tflite::testing::kOutputElementsQ1];
int8_t output_quantized[tflite::testing::kOutputElementsQ1];

tflite::testing::TestCompressionQuantizedInfo<int32_t> comp_info = {};
comp_info.scheme = tflite::CompressionScheme::kBinQuant;
tflite::testing::TestCompressionQuantizedInfo2<int8_t> filter_comp_info = {};
tflite::testing::TestCompressionQuantizedInfo2<int32_t> bias_comp_info = {};

comp_info.filter_value_table = filter_quantized;
comp_info.filter_value_table_stride =
filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
filter_comp_info.value_table = filter_quantized;
filter_comp_info.value_table_stride =
tflite::testing::kBinQuantFilterValueTableElementsQ1 /
tflite::testing::kNumChannelsQ1;
comp_info.filter_bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
comp_info.filter_compressed = tflite::testing::kBinQuantFilterDataQ1;
comp_info.filter_data = tflite::testing::kBinQuantFilterValueTableQ1;
comp_info.filter_dims_data = tflite::testing::kFilterShapeQ1;
comp_info.filter_scales = filter_scales;
comp_info.filter_zero_points = filter_zero_points;

comp_info.bias_value_table = bias_quantized;
comp_info.bias_value_table_stride =
filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1;
filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
filter_comp_info.scales = filter_scales;
filter_comp_info.zero_points = filter_zero_points;

bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
bias_comp_info.value_table = bias_quantized;
bias_comp_info.value_table_stride =
tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
comp_info.bias_bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
comp_info.bias_compressed = tflite::testing::kBinQuantBiasDataQ1;
comp_info.bias_data = tflite::testing::kBiasDataQ1;
comp_info.bias_dims_data = tflite::testing::kBiasShapeQ1;
comp_info.bias_scales = bias_scales;
comp_info.bias_zero_points = bias_zero_points;
bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
bias_comp_info.data = tflite::testing::kBiasDataQ1;
bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
bias_comp_info.scales = bias_scales;
bias_comp_info.zero_points = bias_zero_points;

TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
Expand All @@ -260,7 +262,7 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelCompressed) {
tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1,
golden_quantized, output_quantized, output_scale, output_zero_point,
&tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(),
&comp_info));
&filter_comp_info, &bias_comp_info));
}

#endif // USE_TFLM_COMPRESSION
Expand All @@ -282,15 +284,19 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) {
#ifdef USE_TFLM_COMPRESSION

TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) {
tflite::testing::TestCompressionInfo<const float, const float> comp_info = {};
comp_info.scheme = tflite::CompressionScheme::kBinQuant;
comp_info.filter_value_table = tflite::testing::kBinQuantFilterValueTable;
comp_info.filter_value_table_stride =
tflite::testing::TestCompressionInfo<const float> filter_comp_info = {};
tflite::testing::TestCompressionInfo<const float> bias_comp_info = {};

filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
filter_comp_info.value_table = tflite::testing::kBinQuantFilterValueTable;
filter_comp_info.value_table_stride =
tflite::testing::kBinQuantFilterValueTableElements;
comp_info.filter_bit_width = tflite::testing::kBinQuantFilterBitWidth;
comp_info.bias_value_table = tflite::testing::kBiasData;
comp_info.bias_value_table_stride = tflite::testing::kBiasElements;
comp_info.bias_bit_width = tflite::testing::kBinQuantBiasBitWidth;
filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidth;

bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
bias_comp_info.value_table = tflite::testing::kBiasData;
bias_comp_info.value_table_stride = tflite::testing::kBiasElements;
bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth;

float output_data[tflite::testing::kOutputElements];

Expand All @@ -304,7 +310,7 @@ TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) {
reinterpret_cast<const float*>(tflite::testing::kBinQuantBiasData),
tflite::testing::kOutputShape, tflite::testing::kGoldenData,
&tflite::testing::common_conv_params, tflite::Register_CONV_2D(),
output_data, &comp_info));
output_data, &filter_comp_info, &bias_comp_info));
}

#endif
Expand Down Expand Up @@ -439,29 +445,31 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBiasCompressed) {
int16_t golden_quantized[tflite::testing::kOutputElementsQ1];
int16_t output_quantized[tflite::testing::kOutputElementsQ1];

tflite::testing::TestCompressionQuantizedInfo<int64_t> comp_info = {};
comp_info.scheme = tflite::CompressionScheme::kBinQuant;
tflite::testing::TestCompressionQuantizedInfo2<int8_t> filter_comp_info = {};
tflite::testing::TestCompressionQuantizedInfo2<int64_t> bias_comp_info = {};

comp_info.filter_value_table = filter_quantized;
comp_info.filter_value_table_stride =
filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
filter_comp_info.value_table = filter_quantized;
filter_comp_info.value_table_stride =
tflite::testing::kBinQuantFilterValueTableElementsQ1 /
tflite::testing::kNumChannelsQ1;
comp_info.filter_bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
comp_info.filter_compressed = tflite::testing::kBinQuantFilterDataQ1;
comp_info.filter_data = tflite::testing::kBinQuantFilterValueTableQ1;
comp_info.filter_dims_data = tflite::testing::kFilterShapeQ1;
comp_info.filter_scales = filter_scales;
comp_info.filter_zero_points = filter_zero_points;

comp_info.bias_value_table = bias_quantized;
comp_info.bias_value_table_stride =
filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1;
filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
filter_comp_info.scales = filter_scales;
filter_comp_info.zero_points = filter_zero_points;

bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
bias_comp_info.value_table = bias_quantized;
bias_comp_info.value_table_stride =
tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
comp_info.bias_bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
comp_info.bias_compressed = tflite::testing::kBinQuantBiasDataQ1;
comp_info.bias_data = tflite::testing::kBiasDataQ1;
comp_info.bias_dims_data = tflite::testing::kBiasShapeQ1;
comp_info.bias_scales = bias_scales;
comp_info.bias_zero_points = bias_zero_points;
bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
bias_comp_info.data = tflite::testing::kBiasDataQ1;
bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
bias_comp_info.scales = bias_scales;
bias_comp_info.zero_points = bias_zero_points;

TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
Expand All @@ -471,7 +479,7 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBiasCompressed) {
tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16,
golden_quantized, output_quantized, output_scale, output_zero_point,
&tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(),
&comp_info));
&filter_comp_info, &bias_comp_info));
}

#endif // USE_TFLM_COMPRESSION
Expand Down Expand Up @@ -526,29 +534,31 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBiasCompressed) {
int16_t golden_quantized[tflite::testing::kOutputElementsQ1];
int16_t output_quantized[tflite::testing::kOutputElementsQ1];

tflite::testing::TestCompressionQuantizedInfo<int32_t> comp_info = {};
comp_info.scheme = tflite::CompressionScheme::kBinQuant;
tflite::testing::TestCompressionQuantizedInfo2<int8_t> filter_comp_info = {};
tflite::testing::TestCompressionQuantizedInfo2<int32_t> bias_comp_info = {};

comp_info.filter_value_table = filter_quantized;
comp_info.filter_value_table_stride =
filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
filter_comp_info.value_table = filter_quantized;
filter_comp_info.value_table_stride =
tflite::testing::kBinQuantFilterValueTableElementsQ1 /
tflite::testing::kNumChannelsQ1;
comp_info.filter_bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
comp_info.filter_compressed = tflite::testing::kBinQuantFilterDataQ1;
comp_info.filter_data = tflite::testing::kBinQuantFilterValueTableQ1;
comp_info.filter_dims_data = tflite::testing::kFilterShapeQ1;
comp_info.filter_scales = filter_scales;
comp_info.filter_zero_points = filter_zero_points;

comp_info.bias_value_table = bias_quantized;
comp_info.bias_value_table_stride =
filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1;
filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
filter_comp_info.scales = filter_scales;
filter_comp_info.zero_points = filter_zero_points;

bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
bias_comp_info.value_table = bias_quantized;
bias_comp_info.value_table_stride =
tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
comp_info.bias_bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
comp_info.bias_compressed = tflite::testing::kBinQuantBiasDataQ1;
comp_info.bias_data = tflite::testing::kBiasDataQ1;
comp_info.bias_dims_data = tflite::testing::kBiasShapeQ1;
comp_info.bias_scales = bias_scales;
comp_info.bias_zero_points = bias_zero_points;
bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
bias_comp_info.data = tflite::testing::kBiasDataQ1;
bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
bias_comp_info.scales = bias_scales;
bias_comp_info.zero_points = bias_zero_points;

TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
Expand All @@ -558,7 +568,7 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBiasCompressed) {
tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16,
golden_quantized, output_quantized, output_scale, output_zero_point,
&tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(),
&comp_info));
&filter_comp_info, &bias_comp_info));
}

#endif // USE_TFLM_COMPRESSION
Expand Down
61 changes: 32 additions & 29 deletions tensorflow/lite/micro/kernels/conv_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,35 @@ TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
return runner.Invoke();
}

template <typename T, typename CTF = void, typename CTB = void>
template <typename T, typename TF = void, typename TB = void>
TfLiteStatus ValidateConvGoldens(
TfLiteTensor* tensors, int tensors_size, const T* expected_output_data,
int output_length, const TfLiteConvParams* conv_params,
TFLMRegistration registration, T* output_data, float tolerance = 1e-5
#ifdef USE_TFLM_COMPRESSION
,
const TestCompressionInfo<CTF, CTB>* comp_info = nullptr
const TestCompressionInfo<TF>* filter_comp_info = nullptr,
const TestCompressionInfo<TB>* bias_comp_info = nullptr
#endif // USE_TFLM_COMPRESSION
) {
#ifdef USE_TFLM_COMPRESSION

TestCompressedList<kConvMaxInputTensors, CTF, CTB> tcl;
const CompressedTensorList* comp_list_p = nullptr;

if (comp_info != nullptr) {
TestCompressedList<kConvMaxInputTensors, TF, TB> tcl;
if (filter_comp_info != nullptr) {
TF_LITE_MICRO_EXPECT_EQ(
tcl.AddWeight(*comp_info, tensors[kConvWeightsTensor],
kConvWeightsTensor),
tcl.AddInput(*filter_comp_info, tensors[kConvWeightsTensor],
kConvWeightsTensor),
kTfLiteOk);
TF_LITE_MICRO_CHECK_FAIL();
}
if (bias_comp_info) {
TF_LITE_MICRO_EXPECT_EQ(
tcl.AddBias(*comp_info, tensors[kConvBiasTensor], kConvBiasTensor),
tcl.AddInput(*bias_comp_info, tensors[kConvBiasTensor],
kConvBiasTensor),
kTfLiteOk);
TF_LITE_MICRO_CHECK_FAIL();
comp_list_p = tcl.GetCompressedTensorList();
}
const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList();

#endif // USE_TFLM_COMPRESSION

Expand All @@ -108,7 +110,6 @@ TfLiteStatus ValidateConvGoldens(
return kTfLiteOk;
}

template <typename CTF = void, typename CTB = void>
TfLiteStatus TestConvFloat(
int* input_dims_data, const float* input_data, int* filter_dims_data,
const float* filter_data, int* bias_dims_data, const float* bias_data,
Expand All @@ -117,7 +118,8 @@ TfLiteStatus TestConvFloat(
float* output_data
#ifdef USE_TFLM_COMPRESSION
,
const TestCompressionInfo<CTF, CTB>* comp_info = nullptr
const TestCompressionInfo<const float>* filter_comp_info = nullptr,
const TestCompressionInfo<const float>* bias_comp_info = nullptr
#endif // USE_TFLM_COMPRESSION
) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
Expand All @@ -140,7 +142,7 @@ TfLiteStatus TestConvFloat(
output_data
#ifdef USE_TFLM_COMPRESSION
,
1e-5f, comp_info
1e-5f, filter_comp_info, bias_comp_info
#endif // USE_TFLM_COMPRESSION
);
}
Expand Down Expand Up @@ -179,48 +181,49 @@ TfLiteStatus TestConvQuantizedPerChannel(

#ifdef USE_TFLM_COMPRESSION

template <typename TIO, typename CTB>
template <typename TIO, typename TBIAS>
TfLiteStatus TestConvQuantizedPerChannelCompressed(
int* input_dims_data, const float* input_data, TIO* input_quantized,
float input_scale, int input_zero_point, int* output_dims_data,
const float* expected_output_data, TIO* expected_output_quantized,
TIO* output_quantized, float output_scale, int output_zero_point,
const TfLiteConvParams* conv_params, TFLMRegistration registration,
const TestCompressionQuantizedInfo<CTB>* comp_info) {
const TestCompressionQuantizedInfo2<int8_t>* filter_comp_info,
const TestCompressionQuantizedInfo2<TBIAS>* bias_comp_info) {
// TODO(b/358165875): account for optional bias tensor
// bool null_bias = comp_info->bias_data == nullptr ? true : false;

TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* filter_dims = IntArrayFromInts(comp_info->filter_dims_data);
TfLiteIntArray* bias_dims = IntArrayFromInts(comp_info->bias_dims_data);
TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data);
TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);

TfLiteFloatArray* filter_scales =
FloatArrayFromFloats(comp_info->filter_scales);
FloatArrayFromFloats(filter_comp_info->scales);
TfLiteIntArray* filter_zero_points =
IntArrayFromInts(comp_info->filter_zero_points);
TfLiteFloatArray* bias_scales = FloatArrayFromFloats(comp_info->bias_scales);
IntArrayFromInts(filter_comp_info->zero_points);
TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales);
TfLiteIntArray* bias_zero_points =
IntArrayFromInts(comp_info->bias_zero_points);
IntArrayFromInts(bias_comp_info->zero_points);

TfLiteAffineQuantization filter_quant = {};
TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor(
comp_info->filter_compressed, filter_dims, filter_scales,
filter_comp_info->compressed, filter_dims, filter_scales,
filter_zero_points, &filter_quant, kConvQuantizedDimension,
false /* is_variable */, kTfLiteInt8);
SymmetricPerChannelQuantize(
comp_info->filter_data, comp_info->filter_value_table,
filter_scales->size * comp_info->filter_value_table_stride,
filter_comp_info->data, filter_comp_info->value_table,
filter_scales->size * filter_comp_info->value_table_stride,
filter_scales->size, filter_scales->data);

TfLiteAffineQuantization bias_quant = {};
TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor(
comp_info->bias_compressed, bias_dims, input_scale, filter_scales,
bias_comp_info->compressed, bias_dims, input_scale, filter_scales,
bias_scales, bias_zero_points, &bias_quant, kConvQuantizedDimension,
false /* is_variable */, typeToTfLiteType<CTB>());
false /* is_variable */, typeToTfLiteType<TBIAS>());
SymmetricPerChannelQuantize(
comp_info->bias_data, comp_info->bias_value_table,
bias_scales->size * comp_info->bias_value_table_stride, bias_scales->size,
bias_comp_info->data, bias_comp_info->value_table,
bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size,
bias_scales->data);

constexpr int tensors_size = kConvMaxTensors;
Expand All @@ -239,7 +242,7 @@ TfLiteStatus TestConvQuantizedPerChannelCompressed(
return ValidateConvGoldens(tensors, tensors_size, expected_output_quantized,
output_dims_count, conv_params, registration,
output_quantized, 1.0e-5f /* tolerance */,
comp_info);
filter_comp_info, bias_comp_info);
}

#endif // USE_TFLM_COMPRESSION
Expand Down
Loading

0 comments on commit b9c62b9

Please sign in to comment.