Skip to content

Commit

Permalink
decompression unit test improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Oct 18, 2024
1 parent b318421 commit 81ecf2e
Showing 1 changed file with 146 additions and 38 deletions.
184 changes: 146 additions & 38 deletions tensorflow/lite/micro/kernels/decompress_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct TestingInfo {
};

template <typename T>
struct TestingData7_2_256 {
struct TestingData {
static constexpr size_t kBitWidth = 7;
static constexpr size_t kChannels = 2;
static constexpr size_t kElementsPerChannel = 256;
Expand All @@ -63,14 +63,45 @@ struct TestingData7_2_256 {
T goldens[kTotalElements];
};

TestingData7_2_256<bool> TestingData7_2_256_Bool;
#ifdef notyet
TestingData7_2_256<float> TestingData7_2_256_Float32;
TestingData7_2_256<int8_t> TestingData7_2_256_Int8;
TestingData7_2_256<int16_t> TestingData7_2_256_Int16;
TestingData7_2_256<int32_t> TestingData7_2_256_Int32;
TestingData7_2_256<int64_t> TestingData7_2_256_Int64;
#endif // notyet
TestingData<bool> TestingData_Bool;
TestingData<float> TestingData_Float32;
TestingData<int8_t> TestingData_Int8;
TestingData<int16_t> TestingData_Int16;
TestingData<int32_t> TestingData_Int32;
TestingData<int64_t> TestingData_Int64;

template <typename T>
TestingData<T>* GetTestingData();

template <>
TestingData<bool>* GetTestingData() {
return &TestingData_Bool;
}

template <>
TestingData<float>* GetTestingData() {
return &TestingData_Float32;
}

template <>
TestingData<int8_t>* GetTestingData() {
return &TestingData_Int8;
}

template <>
TestingData<int16_t>* GetTestingData() {
return &TestingData_Int16;
}

template <>
TestingData<int32_t>* GetTestingData() {
return &TestingData_Int32;
}

template <>
TestingData<int64_t>* GetTestingData() {
return &TestingData_Int64;
}

template <typename T>
void FillValueTable(const size_t total_elements, T* value_table) {
Expand All @@ -81,7 +112,6 @@ void FillValueTable(const size_t total_elements, T* value_table) {
}
}

#ifdef notyet
template <>
void FillValueTable(const size_t total_elements, float* value_table) {
float fill_value = -1.1f;
Expand All @@ -90,7 +120,6 @@ void FillValueTable(const size_t total_elements, float* value_table) {
fill_value -= 1.0f;
}
}
#endif // notyet

template <>
void FillValueTable(const size_t total_elements, bool* value_table) {
Expand Down Expand Up @@ -127,8 +156,8 @@ void FillGoldens(const size_t total_elements, T* goldens,
for (size_t channel = 0; channel < channels; channel++) {
for (size_t i = 0; i < elements_per_channel; i++) {
goldens[(channel * elements_per_channel) + i] =
value_table[(channel * value_table_stride) + value_table_index++];
if (value_table_index == value_table_stride) {
value_table[(channel * value_table_stride) + value_table_index];
if (++value_table_index == value_table_stride) {
value_table_index = 0;
}
}
Expand Down Expand Up @@ -163,7 +192,7 @@ void FillCompressed(uint8_t* compressed, const size_t total_golden_elements,
for (size_t group = 0; group < golden_element_groups; group++) {
for (size_t channel = 0; channel < channels; channel++) {
size_t value_table_index = FindValueTableIndex(
goldens[(group * golden_element_groups) + channel],
goldens[(group * channels) + channel],
&value_table[channel * value_table_stride], value_table_stride);
bits |= value_table_index << (16 - bits_accumulated - bit_width);
bits_accumulated += bit_width;
Expand Down Expand Up @@ -193,10 +222,14 @@ void FillCompressed(uint8_t* compressed, const size_t total_golden_elements,
value_table += value_table_stride;
}
}

if (bits_accumulated > 0) {
*compressed = static_cast<uint8_t>(bits >> 8);
}
}

template <typename T>
TfLiteStatus TestDecompression(TestingInfo<T>* info) {
void TestDecompression(TestingInfo<T>* info) {
CompressionTensorData ctd = {};
LookupTableData lut_data = {};
ctd.scheme = CompressionScheme::kBinQuant;
Expand All @@ -211,36 +244,22 @@ TfLiteStatus TestDecompression(TestingInfo<T>* info) {
DecompressionState ds(info->compressed, info->total_elements, ctd,
info->channel_count);

std::fill_n(info->output, info->total_elements, ~0ULL);
std::fill_n(info->output, info->total_elements, static_cast<T>(~0ULL));
ds.DecompressToBuffer<T>(info->output);

bool saved_fail_state = micro_test::did_test_fail;
micro_test::did_test_fail = false;
for (size_t i = 0; i < info->total_elements; i++) {
TF_LITE_MICRO_EXPECT_EQ(info->goldens[i], info->output[i]);
TF_LITE_MICRO_CHECK_FAIL();
if (micro_test::did_test_fail) {
return;
}
}

return kTfLiteOk;
micro_test::did_test_fail = saved_fail_state;
}

template <typename T>
void TestBitWidth(size_t bit_width) {
MicroPrintf(" Testing bit width %d", bit_width);

TestingInfo<T> info = {};

if (std::is_same<T, bool>::value) {
info.output = TestingData7_2_256_Bool.output;
info.goldens = TestingData7_2_256_Bool.goldens;
info.compressed = TestingData7_2_256_Bool.compressed;
info.value_table = TestingData7_2_256_Bool.value_table;
}

info.bit_width = bit_width;
info.channel_count = 1;
info.total_elements = 16;
info.total_value_table_elements = 1 << bit_width;
info.use_alt_axis = false;

void GenerateData(TestingInfo<T>& info) {
FillValueTable(info.total_value_table_elements, info.value_table);
FillGoldens(info.total_elements, info.goldens,
info.total_value_table_elements, info.value_table,
Expand All @@ -249,14 +268,98 @@ void TestBitWidth(size_t bit_width) {
info.total_value_table_elements / info.channel_count,
info.value_table, info.channel_count, info.use_alt_axis,
info.bit_width);
}

template <typename T>
void TestDataSetup(TestingInfo<T>* info, TestingData<T>* data) {
info->output = data->output;
info->goldens = data->goldens;
info->compressed = data->compressed;
info->value_table = data->value_table;
}

template <typename T>
void TestValueTable2n(TestingInfo<T>& info) {
info.total_elements = 16;
if (std::is_same<T, bool>::value) {
info.total_value_table_elements = 2 * info.channel_count;
} else {
info.total_value_table_elements =
(1 << info.bit_width) * info.channel_count;
info.total_value_table_elements =
std::min(info.total_value_table_elements, info.total_elements);
}

MicroPrintf(" Testing value table 2^n: %d",
info.total_value_table_elements);
GenerateData(info);
TestDecompression(&info);
}

template <typename T>
void TestValueTable2nMinus1(TestingInfo<T>& info) {
info.total_elements = 16;
if (std::is_same<T, bool>::value) {
info.total_value_table_elements = 1 * info.channel_count;
} else {
info.total_value_table_elements =
((1 << info.bit_width) - 1) * info.channel_count;
info.total_value_table_elements =
std::min(info.total_value_table_elements, info.total_elements);
}

MicroPrintf(" Testing value table 2^n-1: %d",
info.total_value_table_elements);
GenerateData(info);
TestDecompression(&info);
}

template <typename T>
void TestSingleChannel(TestingInfo<T>& info) {
info.channel_count = 1;

MicroPrintf(" Testing single channel");
TestValueTable2n(info);
TestValueTable2nMinus1(info);
}

template <typename T>
void TestMultiChannel(TestingInfo<T>& info) {
info.channel_count = 2;

MicroPrintf(" Testing multiple channels: %d", info.channel_count);
TestValueTable2n(info);
TestValueTable2nMinus1(info);
}

template <typename T>
void TestBitWidth(TestingInfo<T>& info) {
info.use_alt_axis = false;

MicroPrintf(" Testing bit width %d", info.bit_width);
TestSingleChannel(info);
TestMultiChannel(info);
}

template <typename T>
void TestBitWidthAltAxis(TestingInfo<T>& info) {
info.use_alt_axis = true;

MicroPrintf(" Testing alt-axis bit width %d", info.bit_width);
TestSingleChannel(info);
TestMultiChannel(info);
}

template <typename T>
void TestAllBitWidths() {
TestingInfo<T> info = {};
TestDataSetup<T>(&info, GetTestingData<T>());

for (size_t bw = 1; bw <= 7; bw++) {
TestBitWidth<T>(bw);
info.bit_width = bw;

TestBitWidth<T>(info);
TestBitWidthAltAxis<T>(info);
}
}

Expand All @@ -267,6 +370,11 @@ void TestAllBitWidths() {
TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(TestBool) { tflite::testing::TestAllBitWidths<bool>(); }
TF_LITE_MICRO_TEST(TestFloat) { tflite::testing::TestAllBitWidths<float>(); }
TF_LITE_MICRO_TEST(TestInt8) { tflite::testing::TestAllBitWidths<int8_t>(); }
TF_LITE_MICRO_TEST(TestInt16) { tflite::testing::TestAllBitWidths<int16_t>(); }
TF_LITE_MICRO_TEST(TestInt32) { tflite::testing::TestAllBitWidths<int32_t>(); }
TF_LITE_MICRO_TEST(TestInt64) { tflite::testing::TestAllBitWidths<int64_t>(); }

TF_LITE_MICRO_TESTS_END

Expand Down

0 comments on commit 81ecf2e

Please sign in to comment.