From 2388549c478a0eed4c93604de73d1743618d03ed Mon Sep 17 00:00:00 2001
From: ddavis-2015
Date: Thu, 10 Oct 2024 17:50:29 -0700
Subject: [PATCH] refactor decompression code into reference and
 platform-specific implementations

Apply some Xtensa acceleration code changes
---
 tensorflow/lite/micro/kernels/decompress.cc   |  65 ++
 tensorflow/lite/micro/kernels/decompress.h    |  82 ++
 .../lite/micro/kernels/decompress_common.cc   | 561 +++++++++++++
 .../lite/micro/kernels/xtensa/decompress.cc   | 245 ++++++
 tensorflow/lite/micro/micro_context.cc        | 739 +-----------------
 tensorflow/lite/micro/tools/make/Makefile     |   2 +
 6 files changed, 956 insertions(+), 738 deletions(-)
 create mode 100644 tensorflow/lite/micro/kernels/decompress.cc
 create mode 100644 tensorflow/lite/micro/kernels/decompress.h
 create mode 100644 tensorflow/lite/micro/kernels/decompress_common.cc
 create mode 100644 tensorflow/lite/micro/kernels/xtensa/decompress.cc

diff --git a/tensorflow/lite/micro/kernels/decompress.cc b/tensorflow/lite/micro/kernels/decompress.cc
new file mode 100644
index 00000000000..78ee84931a7
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decompress.cc
@@ -0,0 +1,65 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef USE_TFLM_COMPRESSION
+
+#include "tensorflow/lite/micro/kernels/decompress.h"
+
+#include <cstdint>
+#include <type_traits>
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/micro_common.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+#include "tensorflow/lite/portable_type_to_tflitetype.h"
+
+namespace tflite {
+
+template <typename T>
+T* DecompressionState::DecompressToBuffer(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  if (std::is_same<T, int8_t>::value &&
+      comp_data_.data.lut_data->compressed_bit_width == 4 &&
+      !comp_data_.data.lut_data->use_alternate_axis) {
+    DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
+  } else if (std::is_same<T, int8_t>::value &&
+             comp_data_.data.lut_data->compressed_bit_width == 3 &&
+             !comp_data_.data.lut_data->use_alternate_axis) {
+    DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
+  } else if (std::is_same<T, int8_t>::value &&
+             comp_data_.data.lut_data->compressed_bit_width == 2 &&
+             !comp_data_.data.lut_data->use_alternate_axis) {
+    DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
+  } else {
+    DecompressToBufferWidthAny(static_cast<T*>(buffer));
+  }
+
+  return static_cast<T*>(buffer);
+}
+
+template bool* DecompressionState::DecompressToBuffer<bool>(void*);
+template float* DecompressionState::DecompressToBuffer<float>(void*);
+template int8_t* DecompressionState::DecompressToBuffer<int8_t>(void*);
+template int16_t* DecompressionState::DecompressToBuffer<int16_t>(void*);
+template int32_t* DecompressionState::DecompressToBuffer<int32_t>(void*);
+template int64_t* DecompressionState::DecompressToBuffer<int64_t>(void*);
+
+}  // namespace tflite
+
+#endif  // USE_TFLM_COMPRESSION
diff --git a/tensorflow/lite/micro/kernels/decompress.h b/tensorflow/lite/micro/kernels/decompress.h
new file mode 100644
index 00000000000..2ce9501cb90
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decompress.h
@@ -0,0 +1,82 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+
+#include "tensorflow/lite/micro/micro_context.h"
+
+namespace tflite {
+
+#ifdef USE_TFLM_COMPRESSION
+
+struct DecompressionState {
+  DecompressionState() = delete;
+
+  DecompressionState(const uint8_t* compressed_indices,
+                     const size_t count_indices,
+                     const CompressionTensorData& comp_data,
+                     const size_t num_channels, MicroContext* micro_context)
+      : compressed_indices_(compressed_indices),
+        count_indices_(count_indices),
+        comp_data_(comp_data),
+        num_channels_(num_channels),
+        micro_context_(micro_context) {}
+
+  DecompressionState(const DecompressionState& other)
+      : compressed_indices_(other.compressed_indices_),
+        count_indices_(other.count_indices_),
+        comp_data_(other.comp_data_),
+        num_channels_(other.num_channels_),
+        micro_context_(other.micro_context_) {}
+
+  template <typename T>
+  T* DecompressToBuffer(void* buffer);
+
+ protected:
+  // optimized C++ for INT8, use_alt_axis == false
+  void DecompressToBufferWidth4_16(int8_t* buffer);
+  void DecompressToBufferWidth3_32(int8_t* buffer);
+  void DecompressToBufferWidth2_16(int8_t* buffer);
+
+  // generic C++ for any bit width and value table type
+  template <typename T>
+  void DecompressToBufferWidthAny(T* buffer);
+
+  // Optimized C++ table index fetch
+  inline size_t GetNextTableIndexWidth7(const size_t current_offset);
+  inline size_t GetNextTableIndexWidth6(const size_t current_offset);
+  inline size_t GetNextTableIndexWidth5(const size_t current_offset);
+  inline size_t GetNextTableIndexWidth4(const size_t current_offset);
+  inline size_t GetNextTableIndexWidth3(const size_t current_offset);
+  inline size_t GetNextTableIndexWidth2(const size_t current_offset);
+  inline size_t GetNextTableIndexWidth1(const size_t current_offset);
+
+ protected:
+  const uint8_t* compressed_indices_;
+  const size_t count_indices_;
+  const CompressionTensorData& comp_data_;
+  const size_t num_channels_;
+  const size_t compressed_bit_width_ =
+      comp_data_.data.lut_data->compressed_bit_width;
+  const size_t elements_per_channel_ =
+      comp_data_.data.lut_data->use_alternate_axis
+          ? 1
+          : count_indices_ / num_channels_;
+  MicroContext* micro_context_;
+};
+
+#endif  // USE_TFLM_COMPRESSION
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/decompress_common.cc b/tensorflow/lite/micro/kernels/decompress_common.cc
new file mode 100644
index 00000000000..ce8deda6e84
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decompress_common.cc
@@ -0,0 +1,561 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/kernels/decompress.h" +#include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" + +namespace tflite { + +void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) { + MicroProfiler* profiler = + static_cast(micro_context_->external_context()); + ScopedMicroProfiler scoped_profiler(__func__, profiler); + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint64_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint64_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 1]); + + while (count >= 16) { + count -= 16; + uint64_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 4) & 0x0F]); + value |= static_cast(value_table[index & 0x0F]) << 8; + value |= static_cast(value_table[(index >> 12) & 0x0F]) << 16; + value |= static_cast(value_table[(index >> 8) & 0x0F]) << 24; + value |= static_cast(value_table[(index >> 20) & 0x0F]) << 32; + value |= static_cast(value_table[(index >> 16) & 0x0F]) << 40; + value |= static_cast(value_table[(index >> 28) & 0x0F]) << 48; + value |= static_cast(value_table[(index >> 24) & 0x0F]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 36) & 0x0F]); + value2 |= static_cast(value_table[(index >> 32) & 0x0F]) << 8; + value2 |= static_cast(value_table[(index >> 44) & 0x0F]) + << 16; + value2 |= static_cast(value_table[(index >> 40) & 0x0F]) + << 24; + value2 |= static_cast(value_table[(index >> 52) & 0x0F]) + << 32; + value2 |= static_cast(value_table[(index >> 48) & 0x0F]) + << 40; + value2 |= static_cast(value_table[(index >> 60) & 0x0F]) + << 48; + value2 |= static_cast(value_table[(index >> 56) & 0x0F]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 1; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = 
value_table[index]; + } + + value_table += stride; + } +} + +void DecompressionState::DecompressToBufferWidth2_16(int8_t* buffer) { + MicroProfiler* profiler = + static_cast(micro_context_->external_context()); + ScopedMicroProfiler scoped_profiler(__func__, profiler); + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 2]); + + while (count >= 16) { + count -= 16; + uint32_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 6) & 0x03]); + value |= static_cast(value_table[(index >> 4) & 0x03]) << 8; + value |= static_cast(value_table[(index >> 2) & 0x03]) << 16; + value |= static_cast(value_table[index & 0x03]) << 24; + value |= static_cast(value_table[(index >> 14) & 0x03]) << 32; + value |= static_cast(value_table[(index >> 12) & 0x03]) << 40; + value |= static_cast(value_table[(index >> 10) & 0x03]) << 48; + value |= static_cast(value_table[(index >> 8) & 0x03]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 22) & 0x03]); + value2 |= static_cast(value_table[(index >> 20) & 0x03]) << 8; + value2 |= static_cast(value_table[(index >> 18) & 0x03]) + << 16; + value2 |= static_cast(value_table[(index >> 16) & 0x03]) + << 24; + value2 |= static_cast(value_table[(index >> 30) & 0x03]) + << 32; + value2 |= static_cast(value_table[(index >> 28) & 0x03]) + << 40; + value2 |= static_cast(value_table[(index >> 26) & 0x03]) + << 48; + value2 |= static_cast(value_table[(index >> 24) & 0x03]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 2; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) { + MicroProfiler* profiler = + static_cast(micro_context_->external_context()); + ScopedMicroProfiler scoped_profiler(__func__, profiler); + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x1F)) { + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 32 + if (count >= 32) { + const uint32_t* indices = reinterpret_cast( + 
&compressed_indices_[(current_offset >> 5) * 12]); + + while (count >= 32) { + count -= 32; + uint32_t index0 = *indices++; + uint32_t index1 = *indices++; + uint32_t index2 = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index0 >> 5) & 0x07]); + value |= static_cast(value_table[(index0 >> 2) & 0x07]) << 8; + value |= + static_cast( + value_table[((index0 << 1) & 0b110) | ((index0 >> 15) & 0b1)]) + << 16; + value |= static_cast(value_table[(index0 >> 12) & 0x07]) + << 24; + value |= static_cast(value_table[(index0 >> 9) & 0x07]) << 32; + value |= + static_cast( + value_table[((index0 >> 6) & 0b100) | ((index0 >> 22) & 0b11)]) + << 40; + value |= static_cast(value_table[(index0 >> 19) & 0x07]) + << 48; + value |= static_cast(value_table[(index0 >> 16) & 0x07]) + << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index0 >> 29) & 0x07]); + value2 |= static_cast(value_table[(index0 >> 26) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index0 >> 23) & 0b110) | ((index1 >> 7) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index1 >> 4) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index1 >> 1) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index1 << 2) & 0b100) | ((index1 >> 14) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index1 >> 11) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index1 >> 8) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + value = static_cast(value_table[(index1 >> 21) & 0x07]); + value |= static_cast(value_table[(index1 >> 18) & 0x07]) << 8; + value |= + static_cast( + value_table[((index1 >> 15) & 0b110) | ((index1 >> 31) & 0b1)]) + << 16; + value |= static_cast(value_table[(index1 >> 28) & 0x07]) + << 24; + value |= static_cast(value_table[(index1 >> 25) & 0x07]) + << 32; + value |= + static_cast( + value_table[((index1 >> 22) & 0b100) | ((index2 >> 6) & 0b11)]) + << 40; + value |= static_cast(value_table[(index2 >> 3) & 0x07]) << 48; + value |= static_cast(value_table[(index2 >> 0) & 0x07]) << 56; + + *reinterpret_cast(buffer + 16) = value; + + value2 = static_cast(value_table[(index2 >> 13) & 0x07]); + value2 |= static_cast(value_table[(index2 >> 10) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index2 >> 7) & 0b110) | ((index2 >> 23) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index2 >> 20) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index2 >> 17) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index2 >> 14) & 0b100) | ((index2 >> 30) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index2 >> 27) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index2 >> 24) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 24) = value2; + + buffer += 32; + current_offset += 32; + } + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +// TODO(ddavis-2015): templating GetNextTableIndexWidth makes this method +// more than 2x faster, but with a large code size increase (and BSS segment +// increase) +template +void DecompressionState::DecompressToBufferWidthAny(T* buffer) { + const char* func_name_p = nullptr; + MicroProfiler* profiler = + static_cast(micro_context_->external_context()); + if (profiler != nullptr) { + static char func_name[35]; + MicroSnprintf(func_name, sizeof(func_name), "%s_%u_%s", __func__, 
+ compressed_bit_width_, + TfLiteTypeGetName(typeToTfLiteType())); + func_name_p = func_name; + } + ScopedMicroProfiler scoped_profiler(func_name_p, profiler); + + if (comp_data_.data.lut_data->use_alternate_axis) { + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + size_t current_offset = 0; + size_t count = count_indices_; + + while (count > 0) { + const T* value_table = + static_cast(comp_data_.data.lut_data->value_table); + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + value_table += stride; + } + count -= num_channels_; + } + } else { + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const T* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + while (count-- > 0) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + } + value_table += stride; + } + } +} + +template void DecompressionState::DecompressToBufferWidthAny(bool*); +template void DecompressionState::DecompressToBufferWidthAny(float*); +template void DecompressionState::DecompressToBufferWidthAny(int8_t*); +template void DecompressionState::DecompressToBufferWidthAny(int16_t*); +template void DecompressionState::DecompressToBufferWidthAny(int32_t*); +template void DecompressionState::DecompressToBufferWidthAny(int64_t*); + +// TODO(ddavis-2015): untested +inline size_t DecompressionState::GetNextTableIndexWidth7( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 7; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 1; + case 1: + return ((indices[0] & 0b1) << 6) | (indices[1] >> 2); + case 2: + return ((indices[1] & 0b11) << 4) | (indices[2] >> 3); + case 3: + return ((indices[2] & 0b111) << 4) | (indices[3] >> 4); + case 4: + return ((indices[3] & 0x0F) << 3) | (indices[4] >> 5); + case 5: + return ((indices[4] & 0x1F) << 2) | (indices[5] >> 6); + case 6: + return ((indices[5] & 0x3F) << 1) | (indices[6] >> 7); + case 7: + return indices[6] & 0x7F; + } + // NOTREACHED + return 0; +} + +// TODO(ddavis-2015): untested +inline size_t DecompressionState::GetNextTableIndexWidth6( + const size_t current_offset) 
{ + const size_t current_byte_index = (current_offset >> 2) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b11) { + case 0: + return indices[0] >> 2; + case 1: + return ((indices[0] & 0b11) << 4) | (indices[1] >> 4); + case 2: + return ((indices[1] & 0x0F) << 4) | (indices[2] >> 6); + case 3: + return indices[2] & 0x3F; + } + // NOTREACHED + return 0; +} + +// TODO(ddavis-2015): untested +inline size_t DecompressionState::GetNextTableIndexWidth5( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 5; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 3; + case 1: + return ((indices[0] & 0b111) << 2) | (indices[1] >> 6); + case 2: + return (indices[1] >> 1) & 0x1F; + case 3: + return ((indices[1] & 0b1) << 4) | (indices[2] >> 4); + case 4: + return ((indices[2] & 0x0F) << 1) | (indices[3] >> 7); + case 5: + return (indices[3] >> 2) & 0x1F; + case 6: + return ((indices[3] & 0b11) << 3) | (indices[4] >> 5); + case 7: + return indices[4] & 0x1F; + } + // NOTREACHED + return 0; +} + +inline size_t DecompressionState::GetNextTableIndexWidth4( + const size_t current_offset) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 1] & 0x0F; + } else { + return compressed_indices_[current_offset >> 1] >> 4; + } +} + +inline size_t DecompressionState::GetNextTableIndexWidth3( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 5; + case 1: + return (indices[0] >> 2) & 0b111; + case 2: + return ((indices[0] & 0b11) << 1) | (indices[1] >> 7); + case 3: + return (indices[1] >> 4) & 0b111; + case 4: + return (indices[1] >> 1) & 0b111; + case 5: + return ((indices[1] & 0b1) << 2) | (indices[2] >> 6); + case 6: + return (indices[2] >> 3) & 0b111; + case 7: + return indices[2] & 0b111; + } + // NOTREACHED + return 0; +} + +inline size_t DecompressionState::GetNextTableIndexWidth2( + const size_t current_offset) { + if (current_offset & 0b10) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 2] & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 2) & 0x03; + } + } else { + if (current_offset & 1) { + return (compressed_indices_[current_offset >> 2] >> 4) & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 6) & 0x03; + } + } +} + +inline size_t DecompressionState::GetNextTableIndexWidth1( + const size_t current_offset) { + const size_t shift = ~current_offset & 0b111; + return (compressed_indices_[current_offset >> 3] >> shift) & 0b1; +} + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/xtensa/decompress.cc b/tensorflow/lite/micro/kernels/xtensa/decompress.cc new file mode 100644 index 00000000000..3c1b75fc763 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decompress.cc @@ -0,0 +1,245 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/kernels/decompress.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#endif // HIFI5 + +namespace tflite { +namespace { + +#ifdef HIFI5 + +struct DecompressionStateXtensa : DecompressionState { + DecompressionStateXtensa() = delete; + + DecompressionStateXtensa(const DecompressionState& other) + : DecompressionState(other) {} + + void DecompressToBufferWidth4_Xtensa(int8_t* buffer); + + template + void DecompressToBufferWidthAny_Xtensa(int8_t* buffer); +}; + +void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + MicroProfiler* profiler = + static_cast(micro_context_->external_context()); + ScopedMicroProfiler scoped_profiler(__func__, profiler); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valignx2 align_store = AE_ZALIGN128(); + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + + value_table += stride; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 1)); /* Loading 48 bits for decoding 16 weight values */ + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + 
AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); + } +} + +template +void DecompressionStateXtensa::DecompressToBufferWidthAny_Xtensa( + int8_t* buffer) { + const char* func_name_p = nullptr; + MicroProfiler* profiler = + static_cast(micro_context_->external_context()); + if (profiler != nullptr) { + static char func_name[42]; + MicroSnprintf(func_name, sizeof(func_name), "%s_%u_%s", __func__, + compressed_bit_width_, + TfLiteTypeGetName(typeToTfLiteType())); + func_name_p = func_name; + } + ScopedMicroProfiler scoped_profiler(func_name_p, profiler); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int elements_per_channel_t = elements_per_channel_; + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LBI_DBI_IP((unsigned short*)p_stream, index, N); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + } + + value_table += stride; + } +} + +#endif // HIFI5 + +} // namespace + +#ifdef HIFI5 + +template +T* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + dsx.DecompressToBufferWidth4_Xtensa(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + dsx.DecompressToBufferWidthAny_Xtensa<3>(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + dsx.DecompressToBufferWidthAny_Xtensa<2>(static_cast(buffer)); + } else { + if (std::is_same::value && + !comp_data_.data.lut_data->use_alternate_axis) { + switch (compressed_bit_width_) { + case 1: + dsx.DecompressToBufferWidthAny_Xtensa<1>( + static_cast(buffer)); + break; + case 4: + dsx.DecompressToBufferWidthAny_Xtensa<4>( + static_cast(buffer)); + break; + case 5: + dsx.DecompressToBufferWidthAny_Xtensa<5>( + static_cast(buffer)); + break; + case 6: + dsx.DecompressToBufferWidthAny_Xtensa<6>( + static_cast(buffer)); + break; + case 7: + dsx.DecompressToBufferWidthAny_Xtensa<7>( + static_cast(buffer)); + break; + } + } else { + DecompressToBufferWidthAny(static_cast(buffer)); + } + } + + return static_cast(buffer); +} + +#else // HIFI5 + +template +T* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth4_16(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + 
DecompressToBufferWidth3_32(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth2_16(static_cast(buffer)); + } else { + DecompressToBufferWidthAny(static_cast(buffer)); + } + + return static_cast(buffer); +} + +#endif // HIFI5 + +template bool* DecompressionState::DecompressToBuffer(void*); +template float* DecompressionState::DecompressToBuffer(void*); +template int8_t* DecompressionState::DecompressToBuffer(void*); +template int16_t* DecompressionState::DecompressToBuffer(void*); +template int32_t* DecompressionState::DecompressToBuffer(void*); +template int64_t* DecompressionState::DecompressToBuffer(void*); + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc index 6e576a9ae79..eb557e83166 100644 --- a/tensorflow/lite/micro/micro_context.cc +++ b/tensorflow/lite/micro/micro_context.cc @@ -17,18 +17,12 @@ limitations under the License. #include #include -#include #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/kernels/decompress.h" #include "tensorflow/lite/micro/micro_common.h" #include "tensorflow/lite/micro/micro_log.h" -#include "tensorflow/lite/micro/micro_profiler.h" #include "tensorflow/lite/micro/micro_utils.h" -#include "tensorflow/lite/portable_type_to_tflitetype.h" - -#ifdef HIFI5 -#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" -#endif // HIFI5 namespace tflite { namespace { @@ -43,737 +37,6 @@ int GetTensorIndex(int index, int max_size, const int* tensor_indices) { return -1; } -#ifdef USE_TFLM_COMPRESSION - -struct DecompressionState { - DecompressionState() = delete; - - DecompressionState(const uint8_t* compressed_indices, - const size_t count_indices, - const CompressionTensorData& comp_data, - const size_t num_channels, MicroContext* micro_context) - : compressed_indices_(compressed_indices), - count_indices_(count_indices), - comp_data_(comp_data), - num_channels_(num_channels), - micro_context_(micro_context) {} - - template - T* DecompressToBuffer(void* buffer); - -#ifdef HIFI5 - void DecompressToBufferWidth4_Xtensa(int8_t* buffer); - - template - void DecompressToBufferWidthAny_Xtensa(int8_t* buffer); -#endif // HIFI5 - - void DecompressToBufferWidth4_16(int8_t* buffer); - void DecompressToBufferWidth3_32(int8_t* buffer); - void DecompressToBufferWidth2_16(int8_t* buffer); - - template - void DecompressToBufferWidthAny(T* buffer); - - inline size_t GetNextTableIndexWidth7(const size_t current_offset); - inline size_t GetNextTableIndexWidth6(const size_t current_offset); - inline size_t GetNextTableIndexWidth5(const size_t current_offset); - inline size_t GetNextTableIndexWidth4(const size_t current_offset); - inline size_t GetNextTableIndexWidth3(const size_t current_offset); - inline size_t GetNextTableIndexWidth2(const size_t current_offset); - inline size_t GetNextTableIndexWidth1(const size_t current_offset); - - private: - const uint8_t* compressed_indices_; - const size_t count_indices_; - const CompressionTensorData& comp_data_; - const size_t num_channels_; - const size_t compressed_bit_width_ = - comp_data_.data.lut_data->compressed_bit_width; - const size_t elements_per_channel_ = - comp_data_.data.lut_data->use_alternate_axis - ? 
1 - : count_indices_ / num_channels_; - MicroContext* micro_context_; -}; - -void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) { - MicroProfiler* profiler = - static_cast(micro_context_->external_context()); - ScopedMicroProfiler scoped_profiler(__func__, profiler); - - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - const uint8_t* value_table = - static_cast(comp_data_.data.lut_data->value_table); - const size_t max_count = elements_per_channel_; - size_t current_offset = 0; - - for (size_t channel = 0; channel < num_channels_; channel++) { - size_t count = max_count; - - // process elements at start of channel up to next uint64_t alignment of - // compressed_indices_ - while (count > 0 && (current_offset & 0x0F)) { - const size_t index = GetNextTableIndexWidth4(current_offset++); - *buffer++ = value_table[index]; - count -= 1; - } - - // process elements in current channel in groups of 16 - if (count >= 16) { - const uint64_t* indices = reinterpret_cast( - &compressed_indices_[current_offset >> 1]); - - while (count >= 16) { - count -= 16; - uint64_t index = *indices++; - uint64_t value, value2; - - value = static_cast(value_table[(index >> 4) & 0x0F]); - value |= static_cast(value_table[index & 0x0F]) << 8; - value |= static_cast(value_table[(index >> 12) & 0x0F]) << 16; - value |= static_cast(value_table[(index >> 8) & 0x0F]) << 24; - value |= static_cast(value_table[(index >> 20) & 0x0F]) << 32; - value |= static_cast(value_table[(index >> 16) & 0x0F]) << 40; - value |= static_cast(value_table[(index >> 28) & 0x0F]) << 48; - value |= static_cast(value_table[(index >> 24) & 0x0F]) << 56; - - *reinterpret_cast(buffer) = value; - - value2 = static_cast(value_table[(index >> 36) & 0x0F]); - value2 |= static_cast(value_table[(index >> 32) & 0x0F]) << 8; - value2 |= static_cast(value_table[(index >> 44) & 0x0F]) - << 16; - value2 |= static_cast(value_table[(index >> 40) & 0x0F]) - << 24; - value2 |= static_cast(value_table[(index >> 52) & 0x0F]) - << 32; - value2 |= static_cast(value_table[(index >> 48) & 0x0F]) - << 40; - value2 |= static_cast(value_table[(index >> 60) & 0x0F]) - << 48; - value2 |= static_cast(value_table[(index >> 56) & 0x0F]) - << 56; - - *reinterpret_cast(buffer + 8) = value2; - - buffer += 16; - } - - current_offset = - (reinterpret_cast(indices) - compressed_indices_) - << 1; - } - - // process remaining elements in current channel - while (count > 0) { - count -= 1; - const size_t index = GetNextTableIndexWidth4(current_offset++); - *buffer++ = value_table[index]; - } - - value_table += stride; - } -} - -void DecompressionState::DecompressToBufferWidth2_16(int8_t* buffer) { - MicroProfiler* profiler = - static_cast(micro_context_->external_context()); - ScopedMicroProfiler scoped_profiler(__func__, profiler); - - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - const uint8_t* value_table = - static_cast(comp_data_.data.lut_data->value_table); - const size_t max_count = elements_per_channel_; - size_t current_offset = 0; - - for (size_t channel = 0; channel < num_channels_; channel++) { - size_t count = max_count; - - // process elements at start of channel up to next uint32_t alignment of - // compressed_indices_ - while (count > 0 && (current_offset & 0x0F)) { - const size_t index = GetNextTableIndexWidth2(current_offset++); - *buffer++ = value_table[index]; - count -= 1; - } - - // process elements in current channel in groups of 16 - if (count >= 16) { - const uint32_t* indices = 
reinterpret_cast( - &compressed_indices_[current_offset >> 2]); - - while (count >= 16) { - count -= 16; - uint32_t index = *indices++; - uint64_t value, value2; - - value = static_cast(value_table[(index >> 6) & 0x03]); - value |= static_cast(value_table[(index >> 4) & 0x03]) << 8; - value |= static_cast(value_table[(index >> 2) & 0x03]) << 16; - value |= static_cast(value_table[index & 0x03]) << 24; - value |= static_cast(value_table[(index >> 14) & 0x03]) << 32; - value |= static_cast(value_table[(index >> 12) & 0x03]) << 40; - value |= static_cast(value_table[(index >> 10) & 0x03]) << 48; - value |= static_cast(value_table[(index >> 8) & 0x03]) << 56; - - *reinterpret_cast(buffer) = value; - - value2 = static_cast(value_table[(index >> 22) & 0x03]); - value2 |= static_cast(value_table[(index >> 20) & 0x03]) << 8; - value2 |= static_cast(value_table[(index >> 18) & 0x03]) - << 16; - value2 |= static_cast(value_table[(index >> 16) & 0x03]) - << 24; - value2 |= static_cast(value_table[(index >> 30) & 0x03]) - << 32; - value2 |= static_cast(value_table[(index >> 28) & 0x03]) - << 40; - value2 |= static_cast(value_table[(index >> 26) & 0x03]) - << 48; - value2 |= static_cast(value_table[(index >> 24) & 0x03]) - << 56; - - *reinterpret_cast(buffer + 8) = value2; - - buffer += 16; - } - - current_offset = - (reinterpret_cast(indices) - compressed_indices_) - << 2; - } - - // process remaining elements in current channel - while (count > 0) { - count -= 1; - const size_t index = GetNextTableIndexWidth2(current_offset++); - *buffer++ = value_table[index]; - } - - value_table += stride; - } -} - -void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) { - MicroProfiler* profiler = - static_cast(micro_context_->external_context()); - ScopedMicroProfiler scoped_profiler(__func__, profiler); - - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - const uint8_t* value_table = - static_cast(comp_data_.data.lut_data->value_table); - const size_t max_count = elements_per_channel_; - size_t current_offset = 0; - - for (size_t channel = 0; channel < num_channels_; channel++) { - size_t count = max_count; - - // process elements at start of channel up to next uint32_t alignment of - // compressed_indices_ - while (count > 0 && (current_offset & 0x1F)) { - const size_t index = GetNextTableIndexWidth3(current_offset++); - *buffer++ = value_table[index]; - count -= 1; - } - - // process elements in current channel in groups of 32 - if (count >= 32) { - const uint32_t* indices = reinterpret_cast( - &compressed_indices_[(current_offset >> 5) * 12]); - - while (count >= 32) { - count -= 32; - uint32_t index0 = *indices++; - uint32_t index1 = *indices++; - uint32_t index2 = *indices++; - uint64_t value, value2; - - value = static_cast(value_table[(index0 >> 5) & 0x07]); - value |= static_cast(value_table[(index0 >> 2) & 0x07]) << 8; - value |= - static_cast( - value_table[((index0 << 1) & 0b110) | ((index0 >> 15) & 0b1)]) - << 16; - value |= static_cast(value_table[(index0 >> 12) & 0x07]) - << 24; - value |= static_cast(value_table[(index0 >> 9) & 0x07]) << 32; - value |= - static_cast( - value_table[((index0 >> 6) & 0b100) | ((index0 >> 22) & 0b11)]) - << 40; - value |= static_cast(value_table[(index0 >> 19) & 0x07]) - << 48; - value |= static_cast(value_table[(index0 >> 16) & 0x07]) - << 56; - - *reinterpret_cast(buffer) = value; - - value2 = static_cast(value_table[(index0 >> 29) & 0x07]); - value2 |= static_cast(value_table[(index0 >> 26) & 0x07]) - << 8; - value2 |= - 
static_cast( - value_table[((index0 >> 23) & 0b110) | ((index1 >> 7) & 0b1)]) - << 16; - value2 |= static_cast(value_table[(index1 >> 4) & 0x07]) - << 24; - value2 |= static_cast(value_table[(index1 >> 1) & 0x07]) - << 32; - value2 |= - static_cast( - value_table[((index1 << 2) & 0b100) | ((index1 >> 14) & 0b11)]) - << 40; - value2 |= static_cast(value_table[(index1 >> 11) & 0x07]) - << 48; - value2 |= static_cast(value_table[(index1 >> 8) & 0x07]) - << 56; - - *reinterpret_cast(buffer + 8) = value2; - - value = static_cast(value_table[(index1 >> 21) & 0x07]); - value |= static_cast(value_table[(index1 >> 18) & 0x07]) << 8; - value |= - static_cast( - value_table[((index1 >> 15) & 0b110) | ((index1 >> 31) & 0b1)]) - << 16; - value |= static_cast(value_table[(index1 >> 28) & 0x07]) - << 24; - value |= static_cast(value_table[(index1 >> 25) & 0x07]) - << 32; - value |= - static_cast( - value_table[((index1 >> 22) & 0b100) | ((index2 >> 6) & 0b11)]) - << 40; - value |= static_cast(value_table[(index2 >> 3) & 0x07]) << 48; - value |= static_cast(value_table[(index2 >> 0) & 0x07]) << 56; - - *reinterpret_cast(buffer + 16) = value; - - value2 = static_cast(value_table[(index2 >> 13) & 0x07]); - value2 |= static_cast(value_table[(index2 >> 10) & 0x07]) - << 8; - value2 |= - static_cast( - value_table[((index2 >> 7) & 0b110) | ((index2 >> 23) & 0b1)]) - << 16; - value2 |= static_cast(value_table[(index2 >> 20) & 0x07]) - << 24; - value2 |= static_cast(value_table[(index2 >> 17) & 0x07]) - << 32; - value2 |= - static_cast( - value_table[((index2 >> 14) & 0b100) | ((index2 >> 30) & 0b11)]) - << 40; - value2 |= static_cast(value_table[(index2 >> 27) & 0x07]) - << 48; - value2 |= static_cast(value_table[(index2 >> 24) & 0x07]) - << 56; - - *reinterpret_cast(buffer + 24) = value2; - - buffer += 32; - current_offset += 32; - } - } - - // process remaining elements in current channel - while (count > 0) { - count -= 1; - const size_t index = GetNextTableIndexWidth3(current_offset++); - *buffer++ = value_table[index]; - } - - value_table += stride; - } -} - -// TODO(ddavis-2015): templating GetNextTableIndexWidth makes this method -// more than 2x faster, but with a large code size increase (and BSS segment -// increase) -template -void DecompressionState::DecompressToBufferWidthAny(T* buffer) { - static char func_name[40]; - const char* func_name_p = __func__; - MicroProfiler* profiler = - static_cast(micro_context_->external_context()); - if (profiler != nullptr) { - MicroSnprintf(func_name, sizeof(func_name), "%s_%u_%s", __func__, - compressed_bit_width_, - TfLiteTypeGetName(typeToTfLiteType())); - func_name_p = func_name; - } - ScopedMicroProfiler scoped_profiler(func_name_p, profiler); - - if (comp_data_.data.lut_data->use_alternate_axis) { - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - size_t current_offset = 0; - size_t count = count_indices_; - - while (count > 0) { - const T* value_table = - static_cast(comp_data_.data.lut_data->value_table); - for (size_t channel = 0; channel < num_channels_; channel++) { - size_t index; - switch (compressed_bit_width_) { - case 1: - index = GetNextTableIndexWidth1(current_offset); - break; - case 2: - index = GetNextTableIndexWidth2(current_offset); - break; - case 3: - index = GetNextTableIndexWidth3(current_offset); - break; - case 4: - index = GetNextTableIndexWidth4(current_offset); - break; - case 5: - index = GetNextTableIndexWidth5(current_offset); - break; - case 6: - index = GetNextTableIndexWidth6(current_offset); - 
break; - case 7: - index = GetNextTableIndexWidth7(current_offset); - break; - } - current_offset++; - *buffer++ = value_table[index]; - value_table += stride; - } - count -= num_channels_; - } - } else { - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - const T* value_table = - static_cast(comp_data_.data.lut_data->value_table); - const size_t max_count = elements_per_channel_; - size_t current_offset = 0; - - for (size_t channel = 0; channel < num_channels_; channel++) { - size_t count = max_count; - - while (count-- > 0) { - size_t index; - switch (compressed_bit_width_) { - case 1: - index = GetNextTableIndexWidth1(current_offset); - break; - case 2: - index = GetNextTableIndexWidth2(current_offset); - break; - case 3: - index = GetNextTableIndexWidth3(current_offset); - break; - case 4: - index = GetNextTableIndexWidth4(current_offset); - break; - case 5: - index = GetNextTableIndexWidth5(current_offset); - break; - case 6: - index = GetNextTableIndexWidth6(current_offset); - break; - case 7: - index = GetNextTableIndexWidth7(current_offset); - break; - } - current_offset++; - *buffer++ = value_table[index]; - } - value_table += stride; - } - } -} - -#ifdef HIFI5 - -void DecompressionState::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { - MicroProfiler* profiler = - static_cast(micro_context_->external_context()); - ScopedMicroProfiler scoped_profiler(__func__, profiler); - - char shuffle_pattern_1[8] = {0x08, 0x19, 0x2A, 0x3B, 0x4C, 0x5D, 0x6E, 0x7F}; - ae_int8x8 d_shuffle_t = *(ae_int8x8*)&shuffle_pattern_1[0]; - - char shuffle_pattern_2[8] = {0xFB, 0x73, 0xEA, 0x62, 0xD9, 0x51, 0xC8, 0x40}; - ae_int8x8 d_d_shuffle_t2 = *(ae_int8x8*)&shuffle_pattern_2[0]; - - ae_int8x8 d_out1, d_out2; - ae_int8x8 d_value_0, d_value_1; - ae_int8x8 d_index; - - ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; - ae_int8* p_out_tmp = (ae_int8*)buffer; - - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - const uint8_t* value_table = - static_cast(comp_data_.data.lut_data->value_table); - -#ifdef notdef - MicroPrintf("indices %p buffer %p value_table %p stride %u", - compressed_indices_, buffer, value_table, stride); -#endif - - for (size_t i = 0; i < num_channels_; i++) { - ae_int8x8 d_value_0_t = *(ae_int8x8*)&value_table[0]; - ae_int8x8 d_value_1_t = *(ae_int8x8*)&value_table[8]; - - AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, d_shuffle_t); - - for (size_t j = 0; j < elements_per_channel_; j += 16) { - AE_L8X8_IP(d_index, pIn_tmp, 8); - AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); - AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_d_shuffle_t2); - AE_S8X8X2_IP(d_out1, d_out2, (ae_int8x16*)p_out_tmp, 16); - } - - value_table += stride; - } -} - -template -void DecompressionState::DecompressToBufferWidthAny_Xtensa(int8_t* buffer) { - static char func_name[80]; - const char* func_name_p = __func__; - MicroProfiler* profiler = - static_cast(micro_context_->external_context()); - if (profiler != nullptr) { - MicroSnprintf(func_name, sizeof(func_name), "%s_%u_%s", __func__, - compressed_bit_width_, - TfLiteTypeGetName(typeToTfLiteType())); - func_name_p = func_name; - } - ScopedMicroProfiler scoped_profiler(func_name_p, profiler); - - const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; - const uint8_t* value_table = - static_cast(comp_data_.data.lut_data->value_table); - - short* p_stream = (short*)compressed_indices_; - uint32_t index; - ae_int8* p_out_tmp = (ae_int8*)buffer; - - 
WUR_AE_BITPTR(0); - WUR_AE_BITHEAD(0); - - AE_DBI_IP((const unsigned short*)p_stream, 16); - AE_DBI_IP((const unsigned short*)p_stream, 16); - - for (size_t i = 0; i < num_channels_; i++) { - for (size_t j = 0; j < elements_per_channel_; j++) { - AE_LBI_DBI_IP((unsigned short*)p_stream, index, N); - ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); - AE_S8_0_IP(d_tmp, p_out_tmp, 1); - } - - value_table += stride; - } -} - -#endif // HIFI5 - -template -T* DecompressionState::DecompressToBuffer(void* buffer) { - TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); - TFLITE_DCHECK(compressed_bit_width_ > 0); - - if (std::is_same::value && - comp_data_.data.lut_data->compressed_bit_width == 4 && - !comp_data_.data.lut_data->use_alternate_axis) { -#ifdef HIFI5 - if (!(elements_per_channel_ & 0x0F) && - comp_data_.data.lut_data->value_table_channel_stride == 16) { - DecompressToBufferWidth4_Xtensa(static_cast(buffer)); - } else { - //DecompressToBufferWidth4_16(static_cast(buffer)); - DecompressToBufferWidthAny_Xtensa<4>(static_cast(buffer)); - } -#else // HIFI5 - DecompressToBufferWidth4_16(static_cast(buffer)); -#endif // HIFI5 - } else if (std::is_same::value && - comp_data_.data.lut_data->compressed_bit_width == 2 && - !comp_data_.data.lut_data->use_alternate_axis) { -#ifdef HIFI5 - DecompressToBufferWidthAny_Xtensa<2>(static_cast(buffer)); -#else // HIFI5 - DecompressToBufferWidth2_16(static_cast(buffer)); -#endif // HIFI5 - } else if (std::is_same::value && - comp_data_.data.lut_data->compressed_bit_width == 3 && - !comp_data_.data.lut_data->use_alternate_axis) { -#ifdef HIFI5 - DecompressToBufferWidthAny_Xtensa<3>(static_cast(buffer)); -#else // HIFI5 - DecompressToBufferWidth3_32(static_cast(buffer)); -#endif // HIFI5 - } else { -#ifdef HIFI5 - if (std::is_same::value && - !comp_data_.data.lut_data->use_alternate_axis) { - switch (compressed_bit_width_) { - case 1: - DecompressToBufferWidthAny_Xtensa<1>(static_cast(buffer)); - break; - case 4: - DecompressToBufferWidthAny_Xtensa<4>(static_cast(buffer)); - break; - case 5: - DecompressToBufferWidthAny_Xtensa<5>(static_cast(buffer)); - break; - case 6: - DecompressToBufferWidthAny_Xtensa<6>(static_cast(buffer)); - break; - case 7: - DecompressToBufferWidthAny_Xtensa<7>(static_cast(buffer)); - break; - } - } else { - DecompressToBufferWidthAny(static_cast(buffer)); - } -#else // HIFI5 - DecompressToBufferWidthAny(static_cast(buffer)); -#endif // HIFI5 - } - - return static_cast(buffer); -} - -// TODO(ddavis-2015): untested -inline size_t DecompressionState::GetNextTableIndexWidth7( - const size_t current_offset) { - const size_t current_byte_index = (current_offset >> 3) * 7; - const uint8_t* indices = &compressed_indices_[current_byte_index]; - switch (current_offset & 0b111) { - case 0: - return indices[0] >> 1; - case 1: - return ((indices[0] & 0b1) << 6) | (indices[1] >> 2); - case 2: - return ((indices[1] & 0b11) << 4) | (indices[2] >> 3); - case 3: - return ((indices[2] & 0b111) << 4) | (indices[3] >> 4); - case 4: - return ((indices[3] & 0x0F) << 3) | (indices[4] >> 5); - case 5: - return ((indices[4] & 0x1F) << 2) | (indices[5] >> 6); - case 6: - return ((indices[5] & 0x3F) << 1) | (indices[6] >> 7); - case 7: - return indices[6] & 0x7F; - } - // NOTREACHED - return 0; -} - -// TODO(ddavis-2015): untested -inline size_t DecompressionState::GetNextTableIndexWidth6( - const size_t current_offset) { - const size_t current_byte_index = (current_offset >> 2) * 3; - const uint8_t* indices = 
&compressed_indices_[current_byte_index]; - switch (current_offset & 0b11) { - case 0: - return indices[0] >> 2; - case 1: - return ((indices[0] & 0b11) << 4) | (indices[1] >> 4); - case 2: - return ((indices[1] & 0x0F) << 4) | (indices[2] >> 6); - case 3: - return indices[2] & 0x3F; - } - // NOTREACHED - return 0; -} - -// TODO(ddavis-2015): untested -inline size_t DecompressionState::GetNextTableIndexWidth5( - const size_t current_offset) { - const size_t current_byte_index = (current_offset >> 3) * 5; - const uint8_t* indices = &compressed_indices_[current_byte_index]; - switch (current_offset & 0b111) { - case 0: - return indices[0] >> 3; - case 1: - return ((indices[0] & 0b111) << 2) | (indices[1] >> 6); - case 2: - return (indices[1] >> 1) & 0x1F; - case 3: - return ((indices[1] & 0b1) << 4) | (indices[2] >> 4); - case 4: - return ((indices[2] & 0x0F) << 1) | (indices[3] >> 7); - case 5: - return (indices[3] >> 2) & 0x1F; - case 6: - return ((indices[3] & 0b11) << 3) | (indices[4] >> 5); - case 7: - return indices[4] & 0x1F; - } - // NOTREACHED - return 0; -} - -inline size_t DecompressionState::GetNextTableIndexWidth4( - const size_t current_offset) { - if (current_offset & 1) { - return compressed_indices_[current_offset >> 1] & 0x0F; - } else { - return compressed_indices_[current_offset >> 1] >> 4; - } -} - -inline size_t DecompressionState::GetNextTableIndexWidth3( - const size_t current_offset) { - const size_t current_byte_index = (current_offset >> 3) * 3; - const uint8_t* indices = &compressed_indices_[current_byte_index]; - switch (current_offset & 0b111) { - case 0: - return indices[0] >> 5; - case 1: - return (indices[0] >> 2) & 0b111; - case 2: - return ((indices[0] & 0b11) << 1) | (indices[1] >> 7); - case 3: - return (indices[1] >> 4) & 0b111; - case 4: - return (indices[1] >> 1) & 0b111; - case 5: - return ((indices[1] & 0b1) << 2) | (indices[2] >> 6); - case 6: - return (indices[2] >> 3) & 0b111; - case 7: - return indices[2] & 0b111; - } - // NOTREACHED - return 0; -} - -inline size_t DecompressionState::GetNextTableIndexWidth2( - const size_t current_offset) { - if (current_offset & 0b10) { - if (current_offset & 1) { - return compressed_indices_[current_offset >> 2] & 0x03; - } else { - return (compressed_indices_[current_offset >> 2] >> 2) & 0x03; - } - } else { - if (current_offset & 1) { - return (compressed_indices_[current_offset >> 2] >> 4) & 0x03; - } else { - return (compressed_indices_[current_offset >> 2] >> 6) & 0x03; - } - } -} - -inline size_t DecompressionState::GetNextTableIndexWidth1( - const size_t current_offset) { - const size_t shift = ~current_offset & 0b111; - return (compressed_indices_[current_offset >> 3] >> shift) & 0b1; -} - -#endif // USE_TFLM_COMPRESSION - } // namespace TfLiteTensor* MicroContext::AllocateTempInputTensor(const TfLiteNode* node, diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 4c4de90f54f..d83738da508 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -376,6 +376,8 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_common.cc \ 
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_common.cc \
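
Note (not part of the patch): the sketch below illustrates how the refactored DecompressionState API declared in decompress.h above might be driven from kernel code. Only DecompressionState, CompressionTensorData, MicroContext, and DecompressToBuffer come from this diff; the helper name ExpandCompressedInt8Weights and the scratch-buffer handling are hypothetical.

// Illustration only -- not part of this patch.
#ifdef USE_TFLM_COMPRESSION

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/micro/kernels/decompress.h"

namespace tflite {

// Hypothetical helper: expand a LUT-compressed int8 tensor into a
// caller-provided scratch buffer. On HIFI5 builds the Xtensa
// specializations in kernels/xtensa/decompress.cc run; otherwise the
// reference code in kernels/decompress.cc and decompress_common.cc is used.
inline int8_t* ExpandCompressedInt8Weights(
    MicroContext* micro_context, const uint8_t* compressed_indices,
    size_t count_indices, const CompressionTensorData& comp_data,
    size_t num_channels, void* scratch_buffer) {
  DecompressionState ds(compressed_indices, count_indices, comp_data,
                        num_channels, micro_context);
  // DecompressToBuffer<T> selects a width-specialized int8 fast path for
  // bit widths 2/3/4 and falls back to DecompressToBufferWidthAny<T>.
  return ds.DecompressToBuffer<int8_t>(scratch_buffer);
}

}  // namespace tflite

#endif  // USE_TFLM_COMPRESSION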
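
A second standalone sketch (also not part of the patch): it decodes a few 4-bit indices the same way GetNextTableIndexWidth4() does, two lookup-table indices per byte with the high nibble first, to make the packing assumed by the width-4 reference path concrete. The table contents and packed bytes are made-up example data.

// Illustration only: even offsets read the high nibble, odd offsets read
// the low nibble of compressed_indices[offset >> 1], matching
// GetNextTableIndexWidth4().
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical per-channel value table (up to 16 entries for 4-bit LUTs).
  const int8_t value_table[16] = {-8, -7, -6, -5, -4, -3, -2, -1,
                                  0,  1,  2,  3,  4,  5,  6,  7};
  // Two bytes hold four indices: 0x3, 0x1, 0x4, 0xF.
  const uint8_t compressed_indices[2] = {0x31, 0x4F};

  for (size_t offset = 0; offset < 4; ++offset) {
    const uint8_t byte = compressed_indices[offset >> 1];
    const size_t index = (offset & 1) ? (byte & 0x0F) : (byte >> 4);
    std::printf("element %zu -> value_table[%zu] = %d\n", offset, index,
                static_cast<int>(value_table[index]));
  }
  return 0;
}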