From 307c6342f8ecf46dec7354a9d50bfa8ef8ab0040 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Fri, 6 Dec 2024 14:07:24 -0600 Subject: [PATCH] feat(compression): implement tensor compression feature Implement the tensor compression feature. BUG=part of #2636 --- tensorflow/lite/micro/BUILD | 20 + .../non_persistent_arena_buffer_allocator.h | 4 +- .../persistent_arena_buffer_allocator.h | 3 +- .../single_arena_buffer_allocator.h | 4 +- tensorflow/lite/micro/compression.h | 68 ++ .../lite/micro/compression/model_facade.py | 27 + tensorflow/lite/micro/docs/compression.md | 2 +- .../micro_speech/micro_speech_test.cc | 2 +- tensorflow/lite/micro/fake_micro_context.cc | 79 +- tensorflow/lite/micro/fake_micro_context.h | 36 +- tensorflow/lite/micro/kernels/BUILD | 41 + tensorflow/lite/micro/kernels/Makefile.inc | 6 +- .../lite/micro/kernels/assign_variable.cc | 51 +- .../lite/micro/kernels/concatenation.cc | 112 +-- .../lite/micro/kernels/concatenation_test.cc | 227 +++++- tensorflow/lite/micro/kernels/conv.cc | 49 +- tensorflow/lite/micro/kernels/conv.h | 10 +- tensorflow/lite/micro/kernels/conv_common.cc | 19 +- tensorflow/lite/micro/kernels/conv_test.cc | 299 +++++++- tensorflow/lite/micro/kernels/conv_test.h | 191 ++++- .../lite/micro/kernels/conv_test_common.cc | 103 +-- tensorflow/lite/micro/kernels/decompress.cc | 61 ++ tensorflow/lite/micro/kernels/decompress.h | 89 +++ .../lite/micro/kernels/decompress_common.cc | 539 +++++++++++++ .../lite/micro/kernels/decompress_test.cc | 414 ++++++++++ .../lite/micro/kernels/depthwise_conv.cc | 41 +- .../micro/kernels/depthwise_conv_common.cc | 23 +- .../lite/micro/kernels/depthwise_conv_test.cc | 388 +++++++++- .../lite/micro/kernels/fully_connected.cc | 70 +- .../lite/micro/kernels/fully_connected.h | 8 + .../micro/kernels/fully_connected_test.cc | 309 +++++++- .../lite/micro/kernels/kernel_runner.cc | 17 +- tensorflow/lite/micro/kernels/kernel_runner.h | 9 +- tensorflow/lite/micro/kernels/kernel_util.h | 41 +- .../lite/micro/kernels/transpose_conv.cc | 127 +++- .../lite/micro/kernels/transpose_conv.h | 15 +- .../lite/micro/kernels/transpose_conv_test.cc | 520 ++++++++++++- tensorflow/lite/micro/kernels/xtensa/conv.cc | 22 +- .../lite/micro/kernels/xtensa/conv_hifi.cc | 44 +- .../kernels/xtensa/conv_int16_reference.cc | 29 + .../kernels/xtensa/conv_int8_reference.cc | 21 + .../lite/micro/kernels/xtensa/conv_vision.cc | 72 +- .../lite/micro/kernels/xtensa/decompress.cc | 711 ++++++++++++++++++ .../micro/kernels/xtensa/depthwise_conv.cc | 34 +- .../kernels/xtensa/depthwise_conv_hifi.cc | 37 +- .../kernels/xtensa/depthwise_conv_vision.cc | 73 +- .../micro/kernels/xtensa/fully_connected.cc | 32 +- .../xtensa/fully_connected_common_xtensa.cc | 19 +- .../kernels/xtensa/fully_connected_int8.cc | 30 +- .../kernels/xtensa/fully_connected_vision.cc | 74 +- .../micro/kernels/xtensa/transpose_conv.cc | 122 ++- .../lite/micro/memory_arena_threshold_test.cc | 12 +- .../memory_planner/greedy_memory_planner.cc | 22 +- tensorflow/lite/micro/micro_allocator.cc | 272 ++++++- tensorflow/lite/micro/micro_allocator.h | 20 +- tensorflow/lite/micro/micro_context.cc | 67 +- tensorflow/lite/micro/micro_context.h | 74 +- tensorflow/lite/micro/micro_interpreter.cc | 15 + tensorflow/lite/micro/micro_interpreter.h | 25 + .../lite/micro/micro_interpreter_context.cc | 173 ++++- .../lite/micro/micro_interpreter_context.h | 65 +- .../lite/micro/micro_interpreter_graph.cc | 12 + .../lite/micro/micro_interpreter_test.cc | 213 +++++- .../lite/micro/micro_mutable_op_resolver.h | 4 +- 
tensorflow/lite/micro/micro_profiler.cc | 19 +- tensorflow/lite/micro/micro_profiler.h | 6 +- .../lite/micro/micro_resource_variable.cc | 11 +- .../lite/micro/micro_resource_variable.h | 6 +- .../micro/micro_resource_variable_test.cc | 7 +- tensorflow/lite/micro/micro_utils.h | 13 +- .../lite/micro/recording_micro_allocator.cc | 37 +- .../lite/micro/recording_micro_allocator.h | 17 +- .../micro/recording_micro_allocator_test.cc | 68 +- .../lite/micro/test_helper_custom_ops.cc | 187 ++++- .../lite/micro/test_helper_custom_ops.h | 19 +- tensorflow/lite/micro/test_helpers.cc | 232 +++--- tensorflow/lite/micro/test_helpers.h | 252 ++++++- .../micro/tools/benchmarking/Makefile.inc | 9 + .../tools/benchmarking/collect_meta_data.sh | 2 +- .../benchmarking/generic_model_benchmark.cc | 173 ++++- .../lite/micro/tools/benchmarking/metrics.cc | 30 +- .../benchmarking/show_meta_data.cc.template | 7 + tensorflow/lite/micro/tools/make/Makefile | 2 + .../tools/make/targets/xtensa_makefile.inc | 6 + 84 files changed, 6876 insertions(+), 545 deletions(-) create mode 100644 tensorflow/lite/micro/compression.h create mode 100644 tensorflow/lite/micro/kernels/decompress.cc create mode 100644 tensorflow/lite/micro/kernels/decompress.h create mode 100644 tensorflow/lite/micro/kernels/decompress_common.cc create mode 100644 tensorflow/lite/micro/kernels/decompress_test.cc create mode 100644 tensorflow/lite/micro/kernels/xtensa/decompress.cc diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 64a49731163..39e591815ba 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -27,6 +27,16 @@ tflm_cc_library( ], ) +tflm_cc_library( + name = "compression", + hdrs = [ + "compression.h", + ], + deps = [ + "//tensorflow/lite/c:common", + ], +) + tflm_cc_library( # TODO(b/187093492): Rename to micro_interpreter. 
name = "micro_framework", @@ -62,10 +72,14 @@ tflm_cc_library( "micro_context.h", ], deps = [ + ":compression", ":micro_common", ":micro_graph", ":micro_log", + ":micro_profiler", + "//tensorflow/lite:type_to_tflitetype", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro/kernels:decompress", ], ) @@ -135,6 +149,7 @@ tflm_cc_library( ":memory_helpers", ":micro_allocator", ":micro_common", + ":micro_context", ":micro_graph", ":micro_log", ":micro_profiler", @@ -162,6 +177,7 @@ tflm_cc_library( tflm_cc_library( name = "micro_allocator", srcs = [ + "compression.h", "micro_allocation_info.cc", "micro_allocator.cc", ], @@ -170,6 +186,7 @@ tflm_cc_library( "micro_allocator.h", ], deps = [ + ":compression", ":flatbuffer_utils", ":memory_helpers", ":micro_arena_constants", @@ -182,6 +199,7 @@ tflm_cc_library( "//tensorflow/lite/micro/arena_allocator:non_persistent_arena_buffer_allocator", "//tensorflow/lite/micro/arena_allocator:persistent_arena_buffer_allocator", "//tensorflow/lite/micro/arena_allocator:simple_memory_allocator", + "//tensorflow/lite/micro/compression:metadata_saved", "//tensorflow/lite/micro/memory_planner:greedy_memory_planner", "//tensorflow/lite/micro/memory_planner:linear_memory_planner", "//tensorflow/lite/micro/memory_planner:micro_memory_planner", @@ -235,7 +253,9 @@ tflm_cc_library( "test_helpers.h", ], deps = [ + ":compression", ":memory_helpers", + ":micro_log", ":micro_utils", ":op_resolvers", "//tensorflow/lite:type_to_tflitetype", diff --git a/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h b/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h index ebd376466b6..69b049404c0 100644 --- a/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h +++ b/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h @@ -74,8 +74,6 @@ class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator { // takes in account any temporary allocations. size_t GetAvailableMemory(size_t alignment) const override; - TF_LITE_REMOVE_VIRTUAL_DELETE - private: // The memory arena that this allocator manages. uint8_t* const buffer_head_; @@ -97,6 +95,8 @@ class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator { // Count of outstanding temp buffers. int temp_buffer_count_ = 0; bool resizable_buffer_allocated_ = false; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; } // namespace tflite diff --git a/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h b/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h index 2c8e3dca53b..a86d425d7c6 100644 --- a/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h +++ b/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h @@ -39,7 +39,6 @@ class PersistentArenaBufferAllocator : public IPersistentBufferAllocator { // Returns the size of all persistent allocations in bytes. size_t GetPersistentUsedBytes() const override; - TF_LITE_REMOVE_VIRTUAL_DELETE private: // The memory arena that this allocator manages. uint8_t* const buffer_head_; @@ -51,6 +50,8 @@ class PersistentArenaBufferAllocator : public IPersistentBufferAllocator { // So in essence, the allocated region grows from the bottom and emulates // SingleArenaBufferAllocator's persistent part. 
uint8_t* tail_temp_; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; } // namespace tflite diff --git a/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h b/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h index a2e39588963..771c2deb436 100644 --- a/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h +++ b/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h @@ -110,8 +110,6 @@ class SingleArenaBufferAllocator : public INonPersistentBufferAllocator, // account any temporary allocations. size_t GetUsedBytes() const; - TF_LITE_REMOVE_VIRTUAL_DELETE - protected: // Returns a pointer to the current end of the head buffer. uint8_t* head() const; @@ -137,6 +135,8 @@ class SingleArenaBufferAllocator : public INonPersistentBufferAllocator, intptr_t temp_buffer_ptr_check_sum_ = 0; // Count of outstanding temp buffers. int temp_buffer_count_ = 0; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; } // namespace tflite diff --git a/tensorflow/lite/micro/compression.h b/tensorflow/lite/micro/compression.h new file mode 100644 index 00000000000..f944a8a37b9 --- /dev/null +++ b/tensorflow/lite/micro/compression.h @@ -0,0 +1,68 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/c/common.h" + +namespace tflite { + +// +// Compressed tensors +// + +static constexpr const char* kCompressionMetadataString = + "COMPRESSION_METADATA"; + +enum class CompressionScheme : uint8_t { + kBinQuant, +}; + +struct LookupTableData { + static constexpr size_t kMaxBitWidth = 7; + static constexpr size_t kMaxValueTableChannelStride = 128; + + const void* value_table; // Pointer into FlatBuffer Values. + uint8_t value_table_channel_stride; // elements per channel + uint8_t compressed_bit_width : 3; // 1 to 7 bits + bool is_per_channel_quantized : 1; // tensor is per-channel quantized + bool use_alternate_axis : 1; // shape default channel: + // 0 = first, 1 = last + uint8_t reserved : 3; +}; + +union CompressionData { + LookupTableData* lut_data; +}; + +struct CompressionTensorData { + CompressionScheme scheme; + CompressionData data; +}; + +struct CompressedTensorList { + // Sparsely populated array with the same number of elements as there are + // tensors in the Subgraph. An alternative would include a tensor index in + // the struct for each and walk the list on look up. This could be slow. 
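+  // Each element is indexed by the tensor index within the Subgraph and is
+  // nullptr when the corresponding tensor is not compressed.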
+ const CompressionTensorData** tensors; +}; + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION +#endif // TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_ diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py index 2e58d8080f1..6a8afc8bac8 100644 --- a/tensorflow/lite/micro/compression/model_facade.py +++ b/tensorflow/lite/micro/compression/model_facade.py @@ -100,10 +100,37 @@ def __init__(self, operator, index, subgraph): def opcode(self) -> tflite.OperatorCodeT: return self.subgraph.model.operatorCodes[self.operator.opcodeIndex] + @property + def builtin_opcode(self) -> int: + result: int = self.opcode.deprecatedBuiltinCode + if result == tflite.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES: + result = self.opcode.builtinCode + return result + @property def inputs(self): return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) + @property + def outputs(self): + return _IndirectIterator(self.operator.outputs, self.subgraph.tensors) + + @property + def inputs_indices(self): + return self.operator.inputs + + @property + def outputs_indices(self): + return self.operator.outputs + + @property + def builtin_options_type(self) -> int: + return self.operator.builtinOptionsType + + @property + def builtin_options(self): + return self.operator.builtinOptions + _NP_DTYPES = { tflite.TensorType.FLOAT16: np.dtype("inputs->size) { + int index = node->inputs->data[tensor_idx]; + if (index >= 0 && compressed_tensors_->tensors[index] != nullptr) { + return true; + } + } + + return false; +} + +// Only available during Prepare. The kernel is responsible for storing the +// scratch buffer handle. +int FakeMicroContext::AllocateDecompressionScratchBuffer(const TfLiteNode* node, + int tensor_idx) { + if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) { + return -1; + } + int index = node->inputs->data[tensor_idx]; + if (index < 0 || compressed_tensors_->tensors[index] == nullptr) { + return -1; + } + TfLiteTensor* tensor = &tensors_[index]; + int scratch_index = -1; + TfLiteStatus result = + RequestScratchBufferInArena(tensor->bytes, &scratch_index); + if (result != kTfLiteOk) { + return -1; + } + + return scratch_index; +} + +// Available during Prepare & Eval. Returns nullptr if tensor is not +// compressed. +const CompressionTensorData* FakeMicroContext::GetTensorCompressionData( + const TfLiteNode* node, int tensor_idx) { + if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) { + return nullptr; + } + + int index = node->inputs->data[tensor_idx]; + if (index < 0) { + return nullptr; + } + + return compressed_tensors_->tensors[index]; +} + +#endif // USE_TFLM_COMPRESSION + } // namespace tflite diff --git a/tensorflow/lite/micro/fake_micro_context.h b/tensorflow/lite/micro/fake_micro_context.h index 46d8a9b1ec4..7cf9c682e5c 100644 --- a/tensorflow/lite/micro/fake_micro_context.h +++ b/tensorflow/lite/micro/fake_micro_context.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -30,7 +30,12 @@ class FakeMicroContext : public MicroContext { ~FakeMicroContext() = default; FakeMicroContext(TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator, - MicroGraph* micro_graph); + MicroGraph* micro_graph +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* compressed_tensors = nullptr +#endif // USE_TFLM_COMPRESSION + ); void* AllocatePersistentBuffer(size_t bytes) override; TfLiteStatus RequestScratchBufferInArena(size_t bytes, @@ -50,6 +55,24 @@ class FakeMicroContext : public MicroContext { void* external_context() override; MicroGraph& graph() override; +#ifdef USE_TFLM_COMPRESSION + + // Available during Prepare & Eval. Returns false if tensor is not + // compressed. + bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) override; + + // Only available during Prepare. The kernel is responsible for storing the + // scratch buffer handle. + int AllocateDecompressionScratchBuffer(const TfLiteNode* node, + int tensor_idx) override; + + // Available during Prepare & Eval. Returns nullptr if tensor is not + // compressed. + const CompressionTensorData* GetTensorCompressionData( + const TfLiteNode* node, int tensor_idx) override; + +#endif // USE_TFLM_COMPRESSION + private: static constexpr int kNumScratchBuffers_ = 12; @@ -62,6 +85,15 @@ class FakeMicroContext : public MicroContext { SingleArenaBufferAllocator* allocator_; +#ifdef USE_TFLM_COMPRESSION + + // + // Compression + // + const CompressedTensorList* compressed_tensors_; + +#endif // USE_TFLM_COMPRESSION + TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 8e416c28f09..4ba18075b3c 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -79,6 +79,29 @@ tflm_cc_library( ], ) +tflm_cc_library( + name = "decompress", + srcs = [ + "decompress.cc", + "decompress_common.cc", + ], + hdrs = [ + "decompress.h", + ], + visibility = [ + ":kernel_friends", + ":tflite_micro", + ], + deps = [ + "//tensorflow/lite:type_to_tflitetype", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/micro:compression", + "//tensorflow/lite/micro:micro_common", + "//tensorflow/lite/micro:micro_log", + "//tensorflow/lite/micro:micro_profiler", + ], +) + tflm_cc_library( name = "detection_postprocess_flexbuffers_generated_data", srcs = [ @@ -613,6 +636,24 @@ tflm_cc_test( ], ) +tflm_cc_test( + name = "decompress_test", + srcs = [ + "decompress_test.cc", + ], + target_compatible_with = select({ + "//conditions:default": ["@platforms//:incompatible"], + "//:with_compression_enabled": [], + }), + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_arena_constants", + "//tensorflow/lite/micro:micro_log", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + tflm_cc_test( name = "depth_to_space_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index 0bd846bc679..f4456242fef 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -180,6 +180,10 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like_test.cc +ifeq ($(ENABLE_COMPRESSION), yes) +MICROLITE_KERNEL_SIMPLE_TEST_SRCS += $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_test.cc +endif + # Generate simple kernel test targets in a common way $(foreach TEST_TARGET,$(MICROLITE_KERNEL_SIMPLE_TEST_SRCS),\ $(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET)))) diff --git a/tensorflow/lite/micro/kernels/assign_variable.cc b/tensorflow/lite/micro/kernels/assign_variable.cc index bd99bd1aa0c..9374279e9af 100644 --- a/tensorflow/lite/micro/kernels/assign_variable.cc +++ b/tensorflow/lite/micro/kernels/assign_variable.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/micro/micro_graph.h" #include "tensorflow/lite/micro/micro_log.h" #include "tensorflow/lite/micro/micro_resource_variable.h" +#include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -35,6 +36,20 @@ namespace { constexpr int kInputVariableId = 0; constexpr int kInputValue = 1; +#ifdef USE_TFLM_COMPRESSION + +struct OpData { + // scratch buffer for compressed input tensor + int scratch_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpData)); +} + +#endif // USE_TFLM_COMPRESSION + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0); @@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, input_value)); } +#ifdef USE_TFLM_COMPRESSION + + TFLITE_DCHECK(node->user_data != nullptr); + OpData* data = static_cast(node->user_data); + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + data->scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kInputValue); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input_value); return kTfLiteOk; } @@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { "ResourceVariables and pass it to the interpreter."); return kTfLiteError; } + +#ifdef USE_TFLM_COMPRESSION + OpData* data = static_cast(node->user_data); + const CompressionTensorData* comp_td = + micro_context->GetTensorCompressionData(node, kInputValue); + const void* buffer = tflite::micro::GetTensorData( + micro_context, input_value, comp_td, data->scratch_index); +#else // USE_TFLM_COMPRESSION + const void* buffer = tflite::micro::GetTensorData(input_value); +#endif // USE_TFLM_COMPRESSION + TF_LITE_ENSURE_OK(context, - resources->Assign(input_id->data.i32[0], input_value)); + resources->Assign(input_id->data.i32[0], + EvalTensorBytes(input_value), buffer)); return kTfLiteOk; } } // namespace. 
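+// When compression is enabled, an Init() function is registered so that
+// OpData (which holds the decompression scratch buffer index assigned in
+// Prepare()) can be allocated in persistent memory. Without compression, no
+// per-op data is needed and the init function remains nullptr.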
+#ifdef USE_TFLM_COMPRESSION + +TFLMRegistration Register_ASSIGN_VARIABLE() { + return tflite::micro::RegisterOp(Init, Prepare, Eval); + +#else // USE_TFLM_COMPRESSION + TFLMRegistration Register_ASSIGN_VARIABLE() { return tflite::micro::RegisterOp(nullptr, Prepare, Eval); + +#endif // USE_TFLM_COMPRESSION } } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/concatenation.cc b/tensorflow/lite/micro/kernels/concatenation.cc index 57d63a916a1..151d3b47ed5 100644 --- a/tensorflow/lite/micro/kernels/concatenation.cc +++ b/tensorflow/lite/micro/kernels/concatenation.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0; struct OpData { ConcatenationParams params; + +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int scratch_indices[kMaxInputNum]; + +#endif // USE_TFLM_COMPRESSION }; // Handles negative axis index, coerces to positive index value. @@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) { inline void GetAllInputTensorShapes(const TfLiteContext* context, const TfLiteNode* node, RuntimeShape all_shapes[kMaxInputNum]) { - TFLITE_DCHECK(context != nullptr); - TFLITE_DCHECK(node != nullptr); for (int i = 0; i < node->inputs->size; ++i) { const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i); RuntimeShape shape = tflite::micro::GetTensorShape(t); @@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num, template inline void GetAllInputTensorData(const TfLiteContext* context, const TfLiteNode* node, - T* all_data[kMaxInputNum]) { - TFLITE_DCHECK(context != nullptr); - TFLITE_DCHECK(node != nullptr); + const T* all_data[kMaxInputNum]) { +#ifdef USE_TFLM_COMPRESSION + const OpData* data = static_cast(node->user_data); + MicroContext* micro_context = GetMicroContext(context); +#endif // USE_TFLM_COMPRESSION + for (int i = 0; i < node->inputs->size; ++i) { const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i); +#ifdef USE_TFLM_COMPRESSION + const CompressionTensorData* comp_td = + micro_context->GetTensorCompressionData(node, i); + all_data[i] = tflite::micro::GetTensorData(micro_context, t, comp_td, + data->scratch_indices[i]); +#else // USE_TFLM_COMPRESSION all_data[i] = tflite::micro::GetTensorData(t); +#endif // USE_TFLM_COMPRESSION } } @@ -88,6 +103,10 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) { RuntimeShape inputs_shape[kMaxInputNum]; const RuntimeShape* inputs_shape_ptr[kMaxInputNum]; const data_type* inputs_data[kMaxInputNum]; + TFLITE_DCHECK(context != nullptr); + TFLITE_DCHECK(node != nullptr); + TFLITE_DCHECK(node->user_data != nullptr); + const OpData* data = static_cast(node->user_data); GetAllInputTensorShapes(context, node, inputs_shape); GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr); GetAllInputTensorData(context, node, inputs_data); @@ -95,9 +114,6 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kOutputTensor); - TFLITE_DCHECK(node->user_data != nullptr); - const OpData* data = static_cast(node->user_data); - reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data, 
tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); @@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteType output_type = output_tensor->type; micro_context->DeallocateTempTfLiteTensor(input_tensor); - micro_context->DeallocateTempTfLiteTensor(output_tensor); // Check activation and input type TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); @@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) { input_type == kTfLiteInt64 || input_type == kTfLiteBool); // Output type must match input type - TF_LITE_ENSURE_EQ(context, output_type, input_type); + TF_LITE_ENSURE_TYPES_EQ(context, output_type, input_type); // This implementation does not support large number of input tensors const int num_inputs = NumInputs(node); TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum); - // Shapes with dimensions >4 are not yet supported with static allocation. + // Calculate OpData. + TFLITE_DCHECK(node->user_data != nullptr); + OpData* data = static_cast(node->user_data); + + // Shapes with dimensions > kMaxSmallSize are not yet supported with static + // allocation. for (int i = 0; i < num_inputs; ++i) { TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i); TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, input_type); int num_dimensions = NumDimensions(input); if (num_dimensions > RuntimeShape::kMaxSmallSize) { @@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) { RuntimeShape::kMaxSmallSize, num_dimensions); return kTfLiteError; } + + if (input_type == kTfLiteInt8) { + // Make sure there is no re-scaling needed for Int8 quantized kernel. This + // is a restriction we introduced to Int8 kernels. + TF_LITE_ENSURE_EQ(context, static_cast(input->params.scale), + static_cast(output_tensor->params.scale)); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output_tensor->params.zero_point); + } else if (input_type == kTfLiteInt16) { + // Make sure that all Int16 inputs have a null zero-point. + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + } + +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + data->scratch_indices[i] = + micro_context->AllocateDecompressionScratchBuffer(node, i); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input); } - // Calculate OpData. - TFLITE_DCHECK(node->user_data != nullptr); - OpData* data = static_cast(node->user_data); - - TfLiteTensor* output = - micro_context->AllocateTempOutputTensor(node, kOutputTensor); - TF_LITE_ENSURE(context, output != nullptr); + if (input_type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, output_tensor->params.zero_point, 0); + } switch (output_type) { // Already know in/outtypes are same. 
case kTfLiteBool: case kTfLiteFloat32: + case kTfLiteInt8: case kTfLiteInt16: case kTfLiteInt32: case kTfLiteInt64: { - data->params.axis = CalculatePositiveAxis(params->axis, output); - data->params.inputs_count = node->inputs->size; - break; - } - case kTfLiteInt8: { - data->params.axis = CalculatePositiveAxis(params->axis, output); + data->params.axis = CalculatePositiveAxis(params->axis, output_tensor); data->params.inputs_count = node->inputs->size; - - float* input_scales = - reinterpret_cast(context->AllocatePersistentBuffer( - context, node->inputs->size * sizeof(float))); - - int32_t* input_zero_points = - reinterpret_cast(context->AllocatePersistentBuffer( - context, node->inputs->size * sizeof(int32_t))); - - // Allocate persistent scale and zeropoint buffers. - // Store input scale and zero point values in OpParams: - for (int i = 0; i < node->inputs->size; ++i) { - TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i); - TF_LITE_ENSURE(context, t != nullptr); - input_scales[i] = t->params.scale; - input_zero_points[i] = t->params.zero_point; - micro_context->DeallocateTempTfLiteTensor(t); - } - - data->params.input_scale = input_scales; - data->params.input_zeropoint = input_zero_points; - data->params.output_zeropoint = output->params.zero_point; - data->params.output_scale = output->params.scale; break; } default: - MicroPrintf("Op Concatenation does not currently support Type '%s'.", + MicroPrintf("Op Concatenation does not currently support type '%s'.", TfLiteTypeGetName(output_type)); return kTfLiteError; } - micro_context->DeallocateTempTfLiteTensor(output); + micro_context->DeallocateTempTfLiteTensor(output_tensor); return kTfLiteOk; } diff --git a/tensorflow/lite/micro/kernels/concatenation_test.cc b/tensorflow/lite/micro/kernels/concatenation_test.cc index ddbc74d4aa4..c7e698007ea 100644 --- a/tensorflow/lite/micro/kernels/concatenation_test.cc +++ b/tensorflow/lite/micro/kernels/concatenation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" @@ -56,9 +57,14 @@ void TestConcatenateOneInput(int* input1_dims_data, const T* input1_data, } template -void TestConcatenateTwoInputs(int* input1_dims_data, const T* input1_data, - int* input2_dims_data, const T* input2_data, - int axis, int* output_dims_data, T* output_data) { +void TestConcatenateTwoInputs( + int* input1_dims_data, const T* input1_data, int* input2_dims_data, + const T* input2_data, int axis, int* output_dims_data, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo (*comp_info)[2] = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data); TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); @@ -70,6 +76,21 @@ void TestConcatenateTwoInputs(int* input1_dims_data, const T* input1_data, CreateTensor(input2_data, input2_dims), CreateTensor(output_data, output_dims)}; +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + const CompressedTensorList* comp_list_p = nullptr; + + if (comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ(tcl.AddInput((*comp_info)[0], tensors[0], 0), + kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(tcl.AddInput((*comp_info)[1], tensors[1], 1), + kTfLiteOk); + comp_list_p = tcl.GetCompressedTensorList(); + } + +#endif // USE_TFLM_COMPRESSION + int inputs_array_data[] = {2, 0, 1}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 2}; @@ -83,7 +104,12 @@ void TestConcatenateTwoInputs(int* input1_dims_data, const T* input1_data, const TFLMRegistration registration = Register_CONCATENATION(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, outputs_array, - reinterpret_cast(&builtin_data)); + reinterpret_cast(&builtin_data) +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); @@ -92,9 +118,19 @@ void TestConcatenateTwoInputs(int* input1_dims_data, const T* input1_data, void TestConcatenateTwoFloatInputs( int* input1_dims_data, const float* input1_data, int* input2_dims_data, const float* input2_data, int axis, int* output_dims_data, - const float* expected_output_data, float* output_data) { + const float* expected_output_data, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo (*comp_info)[2] = nullptr +#endif // USE_TFLM_COMPRESSION +) { TestConcatenateTwoInputs(input1_dims_data, input1_data, input2_dims_data, - input2_data, axis, output_dims_data, output_data); + input2_data, axis, output_dims_data, output_data +#ifdef USE_TFLM_COMPRESSION + , + comp_info +#endif // USE_TFLM_COMPRESSION + ); TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(output_dims_data); const int output_dims_count = ElementCount(*dims); @@ -148,6 +184,68 @@ void TestConcatenateQuantizedTwoInputs( } } +#ifdef USE_TFLM_COMPRESSION + +template +void TestConcatenateQuantizedTwoInputsCompressed( + int* input1_dims_data, const uint8_t* input1_data, int* input2_dims_data, + const uint8_t* input2_data, const float input_scale, + const int input_zero_point, int axis, int* output_dims_data, + const T* expected_output_data, const float output_scale, + const int output_zero_point, T* output_data, + 
const TestCompressionInfo (&comp_info)[2]) { + TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data); + TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + constexpr int input_size = 2; + constexpr int output_size = 1; + constexpr int tensors_size = input_size + output_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input1_data, input1_dims, input_scale, + input_zero_point, false, typeToTfLiteType()), + CreateQuantizedTensor(input2_data, input2_dims, input_scale, + input_zero_point, false, typeToTfLiteType()), + CreateQuantizedTensor(output_data, output_dims, output_scale, + output_zero_point)}; + +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + const CompressedTensorList* comp_list_p = nullptr; + + TF_LITE_MICRO_EXPECT_EQ(tcl.AddInput(comp_info[0], tensors[0], 0), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(tcl.AddInput(comp_info[1], tensors[1], 1), kTfLiteOk); + comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + + int inputs_array_data[] = {2, 0, 1}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 2}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + TfLiteConcatenationParams builtin_data = { + .axis = axis, + .activation = kTfLiteActNone // Only activation supported in this impl + }; + + const TFLMRegistration registration = Register_CONCATENATION(); + micro::KernelRunner runner( + registration, tensors, tensors_size, inputs_array, outputs_array, + reinterpret_cast(&builtin_data), nullptr, comp_list_p); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + const int output_dims_count = ElementCount(*output_dims); + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } +} + +#endif // USE_TFLM_COMPRESSION + } // namespace } // namespace testing } // namespace tflite @@ -237,6 +335,43 @@ TF_LITE_MICRO_TEST(TwoInputsAllAxesCombinations) { output_shape_axis1, output_value_axis1, output_data); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(TwoInputsFloatCompressed) { + int input_shape[] = {2, 2, 3}; + const float input1_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + const float input2_value[] = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + // Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t inputs_compressed[] = {0x05, 0x39, 0x40}; + constexpr int kBitWidth = 3; + + // expected output when concatenating on axis 0 + int output_shape_axis0[] = {2, 4, 3}; + const float output_value_axis0[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + float output_data[std::extent::value]; + + tflite::testing::TestCompressionInfo comp_info[2] = {}; + comp_info[0].scheme = tflite::CompressionScheme::kBinQuant; + comp_info[0].value_table = input1_value; + comp_info[0].value_table_stride = std::extent::value; + comp_info[0].bit_width = kBitWidth; + comp_info[1].scheme = tflite::CompressionScheme::kBinQuant; + comp_info[1].value_table = input2_value; + comp_info[1].value_table_stride = std::extent::value; + comp_info[1].bit_width = kBitWidth; + + // Axis = 0 + tflite::testing::TestConcatenateTwoFloatInputs( + input_shape, reinterpret_cast(inputs_compressed), + input_shape, reinterpret_cast(inputs_compressed), + /* axis */ 0, output_shape_axis0, output_value_axis0, 
output_data, + &comp_info); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(TwoInputsQuantizedInt8) { const int axis = 2; int input_shape[] = {3, 2, 1, 2}; @@ -260,6 +395,45 @@ TF_LITE_MICRO_TEST(TwoInputsQuantizedInt8) { output_zero_point, output_data); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(TwoInputsQuantizedInt8Compressed) { + const int axis = 2; + int input_shape[] = {3, 2, 1, 2}; + int output_shape[] = {3, 2, 1, 4}; + + const float input_scale = 0.1f; + const int input_zero_point = 0; + const float output_scale = 0.1f; + const int output_zero_point = 0; + + const int8_t input1_values[] = {1, 2, 3, 4}; + const int8_t input2_values[] = {5, 6, 7, 8}; + const int8_t output_value[] = {1, 2, 5, 6, 3, 4, 7, 8}; + // Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t input_compressed[] = {0x1B}; + constexpr int kBitWidth = 2; + + int8_t output_data[std::extent::value]; + + tflite::testing::TestCompressionInfo comp_info[2] = {}; + comp_info[0].scheme = tflite::CompressionScheme::kBinQuant; + comp_info[0].value_table = input1_values; + comp_info[0].value_table_stride = std::extent::value; + comp_info[0].bit_width = kBitWidth; + comp_info[1].scheme = tflite::CompressionScheme::kBinQuant; + comp_info[1].value_table = input2_values; + comp_info[1].value_table_stride = std::extent::value; + comp_info[1].bit_width = kBitWidth; + + tflite::testing::TestConcatenateQuantizedTwoInputsCompressed( + input_shape, input_compressed, input_shape, input_compressed, input_scale, + input_zero_point, axis, output_shape, output_value, output_scale, + output_zero_point, output_data, comp_info); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(TwoInputsQuantizedInt16) { const int axis = 2; int input_shape[] = {3, 2, 1, 2}; @@ -283,6 +457,45 @@ TF_LITE_MICRO_TEST(TwoInputsQuantizedInt16) { output_zero_point, output_data); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(TwoInputsQuantizedInt16Compressed) { + const int axis = 2; + int input_shape[] = {3, 2, 1, 2}; + int output_shape[] = {3, 2, 1, 4}; + + const float input_scale = 0.1f; + const int input_zero_point = 0; + const float output_scale = 0.1f; + const int output_zero_point = 0; + + const int16_t input1_values[] = {1, 2, 3, 4}; + const int16_t input2_values[] = {5, 6, 7, 8}; + const int16_t output_value[] = {1, 2, 5, 6, 3, 4, 7, 8}; + // Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t input_compressed[] = {0x1B}; + constexpr int kBitWidth = 2; + + int16_t output_data[std::extent::value]; + + tflite::testing::TestCompressionInfo comp_info[2] = {}; + comp_info[0].scheme = tflite::CompressionScheme::kBinQuant; + comp_info[0].value_table = input1_values; + comp_info[0].value_table_stride = std::extent::value; + comp_info[0].bit_width = kBitWidth; + comp_info[1].scheme = tflite::CompressionScheme::kBinQuant; + comp_info[1].value_table = input2_values; + comp_info[1].value_table_stride = std::extent::value; + comp_info[1].bit_width = kBitWidth; + + tflite::testing::TestConcatenateQuantizedTwoInputsCompressed( + input_shape, input_compressed, input_shape, input_compressed, input_scale, + input_zero_point, axis, output_shape, output_value, output_scale, + output_zero_point, output_data, comp_info); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(ThreeDimensionalTwoInputsDifferentShapes) { const int axis = 1; diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc index 0df35fce4eb..3e4fb62318d 100644 --- 
a/tensorflow/lite/micro/kernels/conv.cc +++ b/tensorflow/lite/micro/kernels/conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,15 +45,35 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); const auto& data = *(static_cast(node->user_data)); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: { tflite::reference_ops::Conv( ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); @@ -67,9 +87,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } else if (bias->type == kTfLiteInt64) { @@ -79,9 +108,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } else { @@ -119,9 +157,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + 
data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h index 0c8073f48f0..0090053e03c 100644 --- a/tensorflow/lite/micro/kernels/conv.h +++ b/tensorflow/lite/micro/kernels/conv.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,6 +49,14 @@ struct OpDataConv { // A buffer used to store unpacked filter values. This is used if the source // tensor is of n-bit precision that cannot be easily processed by kernels. int filter_buffer_index; + +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int weights_scratch_index; + int bias_scratch_index; + +#endif // USE_TFLM_COMPRESSION }; extern const int kConvInputTensor; diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc index 51c7a6ff2d6..9f0f2f79588 100644 --- a/tensorflow/lite/micro/kernels/conv_common.cc +++ b/tensorflow/lite/micro/kernels/conv_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -209,6 +209,23 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) { &data->filter_buffer_index); } +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kConvWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, + kConvWeightsTensor); + data->bias_scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(filter); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(output); diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc index 0fb9411a3f0..48eddeb9958 100644 --- a/tensorflow/lite/micro/kernels/conv_test.cc +++ b/tensorflow/lite/micro/kernels/conv_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/lite/micro/kernels/conv_test.h" +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" @@ -46,6 +48,90 @@ static int kOutputShape[] = {4, 2, 1, 2, 3}; static const float kGoldenData[kOutputElements] = {18, 2, 5, 18, 2, 5, 17, 4, 3, 37, 4, 3}; +#ifdef USE_TFLM_COMPRESSION + +// compressed filter data for kBinQuant scheme, matches kFilterData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterData[] = { + 0x05, 0x38, 0x20, 0x90, 0x00, +}; +constexpr float kBinQuantFilterValueTable[] = { + 1, 2, 3, 4, -1, +}; +constexpr size_t kBinQuantFilterValueTableElements = + std::extent::value; +constexpr int kBinQuantFilterBitWidth = 3; +// compressed bias data for kBinQuant scheme, matches kBiasData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18}; +constexpr int kBinQuantBiasBitWidth = 2; + +// Common inputs and outputs for quantized compressed tensor tests. +// Values from TfLite conv_test.cc SimplePerChannelTest. +static int kInputShapeQ1[] = {4, 1, 2, 3, 2}; +static const float kInputDataQ1[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ1 = std::extent::value; + +constexpr int kNumChannelsQ1 = 2; +static int kFilterShapeQ1[] = {4, 2, 2, 2, 2}; +// Original filter data: +// static constexpr float kFilterDataQ1[] = { +// // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] +// 1, 2, // out channel = 0, y = 0, x = 0 +// 3, 4, // out channel = 0, y = 0, x = 1 +// 3, 4, // out channel = 0, y = 1, x = 0 +// 5, 6, // out channel = 0, y = 1, x = 1 +// 7, 8, // out channel = 1, y = 0, x = 0 +// 5, 6, // out channel = 1, y = 0, x = 1 +// 3, 4, // out channel = 1, y = 1, x = 0 +// 1, 2, // out channel = 1, y = 1, x = 1 +// }; + +static int kBiasShapeQ1[] = {1, 2}; +static const float kBiasDataQ1[] = {3, -2}; +constexpr size_t kBiasElementsQ1 = std::extent::value; + +static int kOutputShapeQ1[] = {4, 1, 1, 2, 2}; +static const float kGoldenDataQ1[] = {31, 64, -57, -46}; +constexpr int kOutputElementsQ1 = std::extent::value; +static const float kGoldenDataQ1_16[] = {31, 63.99804688, -57, -46}; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = { + 0x05, 0x34, 0xE5, 0xDE, 0x54, 0xC1, +}; +constexpr float kBinQuantFilterValueTableQ1[] = { + 1, 2, 3, 4, 5, 6, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, +}; +constexpr size_t kBinQuantFilterValueTableElementsQ1 = + std::extent::value; +constexpr int kBinQuantFilterBitWidthQ1 = 3; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +static TfLiteConvParams common_conv_params_q1 = { + kTfLitePaddingValid, // padding + 1, // stride_width + 1, // stride_height + kTfLiteActNone, // activation + 1, // dilation_width_factor + 1, // dilation_height_factor + kTfLiteNoType // quantized_bias_type +}; + +#endif // USE_TFLM_COMPRESSION + static TfLiteConvParams 
common_conv_params = { kTfLitePaddingValid, // padding 2, // stride_width @@ -122,6 +208,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelCompressed) { + const float input_scale = 0.5f; + const float output_scale = 0.5f; + const int input_zero_point = -1; + const int output_zero_point = -1; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestFloat) { float output_data[tflite::testing::kOutputElements]; @@ -136,6 +281,40 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) { + tflite::testing::TestCompressionInfo filter_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = tflite::testing::kBinQuantFilterValueTable; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElements; + filter_comp_info.bit_width = 
tflite::testing::kBinQuantFilterBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::kBiasData; + bias_comp_info.value_table_stride = tflite::testing::kBiasElements; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + float output_data[tflite::testing::kOutputElements]; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvFloat( + tflite::testing::kInputShape, tflite::testing::kInputData, + tflite::testing::kFilterShape, + reinterpret_cast(tflite::testing::kBinQuantFilterData), + tflite::testing::kBiasShape, + reinterpret_cast(tflite::testing::kBinQuantBiasData), + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + &tflite::testing::common_conv_params, tflite::Register_CONV_2D(), + output_data, &filter_comp_info, &bias_comp_info)); +} + +#endif + TF_LITE_MICRO_TEST(InputAndFilterSameWidthHeight) { const int output_dims_count = 2; float output_data[output_dims_count]; @@ -246,6 +425,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBias) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBiasCompressed) { + const float input_scale = 128.0f / 65536; + const float output_scale = 128.0f / 65536; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int64_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, 
input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBias) { const int output_dims_count = 12; int16_t output_data[output_dims_count]; @@ -276,6 +514,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBias) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBiasCompressed) { + const float input_scale = 128.0f / 65536; + const float output_scale = 128.0f / 65536; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) { const int output_dims_count = 24; int8_t output_data[output_dims_count]; diff --git a/tensorflow/lite/micro/kernels/conv_test.h b/tensorflow/lite/micro/kernels/conv_test.h index c655f043bcc..642f4c76d7a 
100644 --- a/tensorflow/lite/micro/kernels/conv_test.h +++ b/tensorflow/lite/micro/kernels/conv_test.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/conv.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/test_helpers.h" @@ -26,35 +27,101 @@ limitations under the License. namespace tflite { namespace testing { -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data); +constexpr int kConvMaxTensors = 4; +constexpr int kConvMaxInputTensors = 3; +template TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, int8_t* output_data); - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const float* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - float* output_data, float tolerance = 1e-5); - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const int8_t* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - int8_t* output_data, float tolerance = 1e-5); - -TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - int* output_dims_data, - const float* expected_output_data, - TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data); + int output_length, const TfLiteConvParams* conv_params, + TFLMRegistration registration, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* comp_list_p = nullptr +#endif // USE_TFLM_COMPRESSION +) { + // TODO(b/358165875): support optional bias tensor + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, conv_params +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); + + const char* init_data = reinterpret_cast(conv_params); + TfLiteStatus status = runner.InitAndPrepare(init_data); + if (status != kTfLiteOk) { + return status; + } + return runner.Invoke(); +} + +template +TfLiteStatus ValidateConvGoldens( + TfLiteTensor* tensors, int tensors_size, const T* expected_output_data, + int output_length, const TfLiteConvParams* conv_params, + TFLMRegistration registration, T* output_data, float tolerance = 1e-5 +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, 
tensors[kConvWeightsTensor], + kConvWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kConvBiasTensor], + kConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + + TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length, + conv_params, registration, output_data +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); + if (status != kTfLiteOk) { + return status; + } + for (int i = 0; i < output_length; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], + tolerance); + } + return kTfLiteOk; +} + +TfLiteStatus TestConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + int* output_dims_data, const float* expected_output_data, + TfLiteConvParams* conv_params, TFLMRegistration registration, + float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +); TfLiteStatus TestConvQuantizedPerChannel( int* input_dims_data, const float* input_data, int8_t* input_quantized, @@ -88,6 +155,74 @@ TfLiteStatus TestConvQuantizedPerChannel( float output_scale, int output_zero_point, TfLiteConvParams* conv_params, TFLMRegistration registration, int16_t* output_data); +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestConvQuantizedPerChannelCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, TFLMRegistration registration, + const TestCompressionQuantizedInfo* filter_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + // TODO(b/358165875): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ? 
true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, kConvQuantizedDimension, + false /* is_variable */, typeToTfLiteType()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + constexpr int tensors_size = kConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + filter_tensor, + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateConvGoldens(tensors, tensors_size, expected_output_quantized, + output_dims_count, conv_params, registration, + output_quantized, 1.0e-5f /* tolerance */, + filter_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc index a0f733b8e42..3825e05373c 100644 --- a/tensorflow/lite/micro/kernels/conv_test_common.cc +++ b/tensorflow/lite/micro/kernels/conv_test_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,88 +18,18 @@ limitations under the License. 
namespace tflite { namespace testing { -template -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, T* output_data) { - int inputs_array_data[] = {3, 0, 1, 2}; - TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); - int outputs_array_data[] = {1, 3}; - TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - - micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, conv_params); - - const char* init_data = reinterpret_cast(conv_params); - TfLiteStatus status = runner.InitAndPrepare(init_data); - if (status != kTfLiteOk) { - return status; - } - return runner.Invoke(); -} - -template -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const T* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, T* output_data, - float tolerance) { - TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length, - conv_params, registration, output_data); - if (status != kTfLiteOk) { - return status; - } - for (int i = 0; i < output_length; ++i) { - TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], - tolerance); - } - return kTfLiteOk; -} - -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data) { - return InvokeConv(tensors, tensors_size, output_length, conv_params, - registration, output_data); -} - -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, int8_t* output_data) { - return InvokeConv(tensors, tensors_size, output_length, conv_params, - registration, output_data); -} - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const float* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - float* output_data, float tolerance) { - return ValidateConvGoldens(tensors, tensors_size, expected_output_data, - output_length, conv_params, registration, - output_data, tolerance); -} - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const int8_t* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - int8_t* output_data, float tolerance) { - return ValidateConvGoldens( - tensors, tensors_size, expected_output_data, output_length, conv_params, - registration, output_data, tolerance); -} - -TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - int* output_dims_data, - const float* expected_output_data, - TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data) { +TfLiteStatus TestConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + int* output_dims_data, const float* expected_output_data, + TfLiteConvParams* conv_params, TFLMRegistration registration, + float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info, + const TestCompressionInfo* bias_comp_info +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = 
IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -117,7 +47,12 @@ TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, return ValidateConvGoldens(tensors, tensors_size, expected_output_data, output_dims_count, conv_params, registration, - output_data); + output_data +#ifdef USE_TFLM_COMPRESSION + , + 1e-5f, filter_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } template diff --git a/tensorflow/lite/micro/kernels/decompress.cc b/tensorflow/lite/micro/kernels/decompress.cc new file mode 100644 index 00000000000..2d6d8670870 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decompress.cc @@ -0,0 +1,61 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/kernels/decompress.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/micro_common.h" + +namespace tflite { + +template +T* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth4_16(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth3_32(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth2_16(static_cast(buffer)); + } else { + DecompressToBufferWidthAny(static_cast(buffer)); + } + + return static_cast(buffer); +} + +template bool* DecompressionState::DecompressToBuffer(void*); +template float* DecompressionState::DecompressToBuffer(void*); +template int8_t* DecompressionState::DecompressToBuffer(void*); +template int16_t* DecompressionState::DecompressToBuffer(void*); +template int32_t* DecompressionState::DecompressToBuffer(void*); +template int64_t* DecompressionState::DecompressToBuffer(void*); + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/decompress.h b/tensorflow/lite/micro/kernels/decompress.h new file mode 100644 index 00000000000..d1d4e98d943 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decompress.h @@ -0,0 +1,89 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_ + +#include + +#include "tensorflow/lite/micro/compression.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +#ifdef USE_TFLM_COMPRESSION + +struct DecompressionState { + DecompressionState() = delete; + + DecompressionState(const uint8_t* compressed_indices, + const size_t count_indices, + const CompressionTensorData& comp_data, + const size_t num_channels, + MicroProfilerInterface* profiler = nullptr) + : compressed_indices_(compressed_indices), + count_indices_(count_indices), + comp_data_(comp_data), + num_channels_(num_channels), + micro_profiler_(profiler) {} + + DecompressionState(const DecompressionState& other) + : compressed_indices_(other.compressed_indices_), + count_indices_(other.count_indices_), + comp_data_(other.comp_data_), + num_channels_(other.num_channels_), + micro_profiler_(other.micro_profiler_) {} + + template + T* DecompressToBuffer(void* buffer); + + protected: + // optimized C++ for INT8, use_alt_axis == false + void DecompressToBufferWidth4_16(int8_t* buffer); + void DecompressToBufferWidth3_32(int8_t* buffer); + void DecompressToBufferWidth2_16(int8_t* buffer); + + // generic C++ for any bit width and value table type + template + void DecompressToBufferWidthAny(T* buffer); + + // Optimized C++ table index fetch + inline size_t GetNextTableIndexWidth7(const size_t current_offset); + inline size_t GetNextTableIndexWidth6(const size_t current_offset); + inline size_t GetNextTableIndexWidth5(const size_t current_offset); + inline size_t GetNextTableIndexWidth4(const size_t current_offset); + inline size_t GetNextTableIndexWidth3(const size_t current_offset); + inline size_t GetNextTableIndexWidth2(const size_t current_offset); + inline size_t GetNextTableIndexWidth1(const size_t current_offset); + + protected: + const uint8_t* compressed_indices_; + const size_t count_indices_; + const CompressionTensorData& comp_data_; + const size_t num_channels_; + const size_t compressed_bit_width_ = + comp_data_.data.lut_data->compressed_bit_width; + const size_t elements_per_channel_ = + comp_data_.data.lut_data->use_alternate_axis + ? 1 + : count_indices_ / num_channels_; + MicroProfilerInterface* micro_profiler_; +}; + +#endif // USE_TFLM_COMPRESSION + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_ diff --git a/tensorflow/lite/micro/kernels/decompress_common.cc b/tensorflow/lite/micro/kernels/decompress_common.cc new file mode 100644 index 00000000000..5c40af83bf3 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decompress_common.cc @@ -0,0 +1,539 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/kernels/decompress.h" +#include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint64_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint64_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 1]); + + while (count >= 16) { + count -= 16; + uint64_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 4) & 0x0F]); + value |= static_cast(value_table[index & 0x0F]) << 8; + value |= static_cast(value_table[(index >> 12) & 0x0F]) << 16; + value |= static_cast(value_table[(index >> 8) & 0x0F]) << 24; + value |= static_cast(value_table[(index >> 20) & 0x0F]) << 32; + value |= static_cast(value_table[(index >> 16) & 0x0F]) << 40; + value |= static_cast(value_table[(index >> 28) & 0x0F]) << 48; + value |= static_cast(value_table[(index >> 24) & 0x0F]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 36) & 0x0F]); + value2 |= static_cast(value_table[(index >> 32) & 0x0F]) << 8; + value2 |= static_cast(value_table[(index >> 44) & 0x0F]) + << 16; + value2 |= static_cast(value_table[(index >> 40) & 0x0F]) + << 24; + value2 |= static_cast(value_table[(index >> 52) & 0x0F]) + << 32; + value2 |= static_cast(value_table[(index >> 48) & 0x0F]) + << 40; + value2 |= static_cast(value_table[(index >> 60) & 0x0F]) + << 48; + value2 |= static_cast(value_table[(index >> 56) & 0x0F]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 1; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecompressionState::DecompressToBufferWidth2_16(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* value_table = + 
static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 2]); + + while (count >= 16) { + count -= 16; + uint32_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 6) & 0x03]); + value |= static_cast(value_table[(index >> 4) & 0x03]) << 8; + value |= static_cast(value_table[(index >> 2) & 0x03]) << 16; + value |= static_cast(value_table[index & 0x03]) << 24; + value |= static_cast(value_table[(index >> 14) & 0x03]) << 32; + value |= static_cast(value_table[(index >> 12) & 0x03]) << 40; + value |= static_cast(value_table[(index >> 10) & 0x03]) << 48; + value |= static_cast(value_table[(index >> 8) & 0x03]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 22) & 0x03]); + value2 |= static_cast(value_table[(index >> 20) & 0x03]) << 8; + value2 |= static_cast(value_table[(index >> 18) & 0x03]) + << 16; + value2 |= static_cast(value_table[(index >> 16) & 0x03]) + << 24; + value2 |= static_cast(value_table[(index >> 30) & 0x03]) + << 32; + value2 |= static_cast(value_table[(index >> 28) & 0x03]) + << 40; + value2 |= static_cast(value_table[(index >> 26) & 0x03]) + << 48; + value2 |= static_cast(value_table[(index >> 24) & 0x03]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 2; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x1F)) { + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 32 + if (count >= 32) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[(current_offset >> 5) * 12]); + + while (count >= 32) { + count -= 32; + uint32_t index0 = *indices++; + uint32_t index1 = *indices++; + uint32_t index2 = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index0 >> 5) & 0x07]); + value |= static_cast(value_table[(index0 >> 2) & 0x07]) << 8; + value |= + static_cast( + value_table[((index0 << 1) & 0b110) | ((index0 >> 15) & 0b1)]) + << 16; + value |= 
static_cast(value_table[(index0 >> 12) & 0x07]) + << 24; + value |= static_cast(value_table[(index0 >> 9) & 0x07]) << 32; + value |= + static_cast( + value_table[((index0 >> 6) & 0b100) | ((index0 >> 22) & 0b11)]) + << 40; + value |= static_cast(value_table[(index0 >> 19) & 0x07]) + << 48; + value |= static_cast(value_table[(index0 >> 16) & 0x07]) + << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index0 >> 29) & 0x07]); + value2 |= static_cast(value_table[(index0 >> 26) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index0 >> 23) & 0b110) | ((index1 >> 7) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index1 >> 4) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index1 >> 1) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index1 << 2) & 0b100) | ((index1 >> 14) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index1 >> 11) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index1 >> 8) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + value = static_cast(value_table[(index1 >> 21) & 0x07]); + value |= static_cast(value_table[(index1 >> 18) & 0x07]) << 8; + value |= + static_cast( + value_table[((index1 >> 15) & 0b110) | ((index1 >> 31) & 0b1)]) + << 16; + value |= static_cast(value_table[(index1 >> 28) & 0x07]) + << 24; + value |= static_cast(value_table[(index1 >> 25) & 0x07]) + << 32; + value |= + static_cast( + value_table[((index1 >> 22) & 0b100) | ((index2 >> 6) & 0b11)]) + << 40; + value |= static_cast(value_table[(index2 >> 3) & 0x07]) << 48; + value |= static_cast(value_table[(index2 >> 0) & 0x07]) << 56; + + *reinterpret_cast(buffer + 16) = value; + + value2 = static_cast(value_table[(index2 >> 13) & 0x07]); + value2 |= static_cast(value_table[(index2 >> 10) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index2 >> 7) & 0b110) | ((index2 >> 23) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index2 >> 20) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index2 >> 17) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index2 >> 14) & 0b100) | ((index2 >> 30) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index2 >> 27) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index2 >> 24) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 24) = value2; + + buffer += 32; + current_offset += 32; + } + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +// TODO(ddavis-2015): templating GetNextTableIndexWidth makes this method +// more than 2x faster, but with a large code size increase +template +void DecompressionState::DecompressToBufferWidthAny(T* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + if (comp_data_.data.lut_data->use_alternate_axis) { + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + size_t current_offset = 0; + size_t count = count_indices_; + + while (count > 0) { + const T* value_table = + static_cast(comp_data_.data.lut_data->value_table); + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = 
GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + value_table += stride; + } + count -= num_channels_; + } + } else { + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const T* value_table = + static_cast(comp_data_.data.lut_data->value_table); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + while (count-- > 0) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + } + value_table += stride; + } + } +} + +template void DecompressionState::DecompressToBufferWidthAny(bool*); +template void DecompressionState::DecompressToBufferWidthAny(float*); +template void DecompressionState::DecompressToBufferWidthAny(int8_t*); +template void DecompressionState::DecompressToBufferWidthAny(int16_t*); +template void DecompressionState::DecompressToBufferWidthAny(int32_t*); +template void DecompressionState::DecompressToBufferWidthAny(int64_t*); + +inline size_t DecompressionState::GetNextTableIndexWidth7( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 7; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 1; + case 1: + return ((indices[0] & 0b1) << 6) | (indices[1] >> 2); + case 2: + return ((indices[1] & 0b11) << 5) | (indices[2] >> 3); + case 3: + return ((indices[2] & 0b111) << 4) | (indices[3] >> 4); + case 4: + return ((indices[3] & 0x0F) << 3) | (indices[4] >> 5); + case 5: + return ((indices[4] & 0x1F) << 2) | (indices[5] >> 6); + case 6: + return ((indices[5] & 0x3F) << 1) | (indices[6] >> 7); + case 7: + return indices[6] & 0x7F; + } + // NOTREACHED + return 0; +} + +inline size_t DecompressionState::GetNextTableIndexWidth6( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 2) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b11) { + case 0: + return indices[0] >> 2; + case 1: + return ((indices[0] & 0b11) << 4) | (indices[1] >> 4); + case 2: + return ((indices[1] & 0x0F) << 2) | (indices[2] >> 6); + case 3: + return indices[2] & 0x3F; + } + // NOTREACHED + return 0; +} + +inline size_t DecompressionState::GetNextTableIndexWidth5( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 5; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 3; + case 1: + return ((indices[0] & 0b111) << 2) | (indices[1] >> 6); + case 2: + return (indices[1] >> 1) & 0x1F; + case 3: + return ((indices[1] & 0b1) << 4) 
| (indices[2] >> 4); + case 4: + return ((indices[2] & 0x0F) << 1) | (indices[3] >> 7); + case 5: + return (indices[3] >> 2) & 0x1F; + case 6: + return ((indices[3] & 0b11) << 3) | (indices[4] >> 5); + case 7: + return indices[4] & 0x1F; + } + // NOTREACHED + return 0; +} + +inline size_t DecompressionState::GetNextTableIndexWidth4( + const size_t current_offset) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 1] & 0x0F; + } else { + return compressed_indices_[current_offset >> 1] >> 4; + } +} + +inline size_t DecompressionState::GetNextTableIndexWidth3( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 5; + case 1: + return (indices[0] >> 2) & 0b111; + case 2: + return ((indices[0] & 0b11) << 1) | (indices[1] >> 7); + case 3: + return (indices[1] >> 4) & 0b111; + case 4: + return (indices[1] >> 1) & 0b111; + case 5: + return ((indices[1] & 0b1) << 2) | (indices[2] >> 6); + case 6: + return (indices[2] >> 3) & 0b111; + case 7: + return indices[2] & 0b111; + } + // NOTREACHED + return 0; +} + +inline size_t DecompressionState::GetNextTableIndexWidth2( + const size_t current_offset) { + if (current_offset & 0b10) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 2] & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 2) & 0x03; + } + } else { + if (current_offset & 1) { + return (compressed_indices_[current_offset >> 2] >> 4) & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 6) & 0x03; + } + } +} + +inline size_t DecompressionState::GetNextTableIndexWidth1( + const size_t current_offset) { + const size_t shift = ~current_offset & 0b111; + return (compressed_indices_[current_offset >> 3] >> shift) & 0b1; +} + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/decompress_test.cc b/tensorflow/lite/micro/kernels/decompress_test.cc new file mode 100644 index 00000000000..cdd2a633545 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decompress_test.cc @@ -0,0 +1,414 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/kernels/decompress.h" + +#include +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro//micro_log.h" +#include "tensorflow/lite/micro/micro_arena_constants.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +template +struct TestingInfo { + T* output; + T* goldens; + uint8_t* compressed; + T* value_table; + + size_t bit_width; + size_t total_elements; + size_t total_value_table_elements; + size_t channel_count; + bool use_alt_axis; +}; + +template +struct TestingData { + static constexpr size_t kBitWidth = 7; + static constexpr size_t kChannels = 2; + static constexpr size_t kElementsPerChannel = 256; + + static constexpr size_t kTotalElements = kElementsPerChannel * kChannels; + static constexpr size_t kCompressedBytes = + ((kTotalElements * kBitWidth) + 7) / 8; + static constexpr size_t kValueTableSize = (1 << kBitWidth) * kChannels; + + alignas(MicroArenaBufferAlignment()) T output[kTotalElements]; + alignas(MicroArenaBufferAlignment()) uint8_t compressed[kCompressedBytes]; + alignas(MicroArenaBufferAlignment()) T value_table[kValueTableSize]; + T goldens[kTotalElements]; +}; + +TestingData TestingData_Bool; +TestingData TestingData_Float32; +TestingData TestingData_Int8; +TestingData TestingData_Int16; +TestingData TestingData_Int32; +TestingData TestingData_Int64; + +template +TestingData* GetTestingData(); + +template <> +TestingData* GetTestingData() { + return &TestingData_Bool; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Float32; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int8; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int16; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int32; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int64; +} + +template +void FillValueTable(const size_t total_elements, T* value_table) { + T fill_value = -1; + for (size_t i = 0; i < total_elements; i++) { + value_table[i] = fill_value; + fill_value -= 1; + } +} + +template <> +void FillValueTable(const size_t total_elements, float* value_table) { + float fill_value = -1.1f; + for (size_t i = 0; i < total_elements; i++) { + value_table[i] = fill_value; + fill_value -= 1.0f; + } +} + +template <> +void FillValueTable(const size_t total_elements, bool* value_table) { + bool fill_value = true; + for (size_t i = 0; i < total_elements; i++) { + value_table[i] = fill_value; + fill_value = !fill_value; + } +} + +template +void FillGoldens(const size_t total_elements, T* goldens, + const size_t value_table_elements, const T* value_table, + const size_t channels, const bool use_alt_axis) { + if (use_alt_axis) { + const size_t value_table_stride = value_table_elements / channels; + const size_t element_groups = total_elements / channels; + size_t value_table_index = 0; // index within current channel + + for (size_t group = 0; group < element_groups; group++) { + for (size_t channel = 0; channel < channels; channel++) { + goldens[(group * channels) + channel] = + value_table[(channel * value_table_stride) + value_table_index]; + } + if (++value_table_index == value_table_stride) { + value_table_index = 0; + } + } + } else { + const size_t value_table_stride = value_table_elements 
/ channels; + const size_t elements_per_channel = total_elements / channels; + size_t value_table_index = 0; // index within current channel + + for (size_t channel = 0; channel < channels; channel++) { + for (size_t i = 0; i < elements_per_channel; i++) { + goldens[(channel * elements_per_channel) + i] = + value_table[(channel * value_table_stride) + value_table_index]; + if (++value_table_index == value_table_stride) { + value_table_index = 0; + } + } + value_table_index = 0; + } + } +} + +// returns index within channel +template +size_t FindValueTableIndex(const T value, const T* value_table, + const size_t value_table_stride) { + for (size_t i = 0; i < value_table_stride; i++) { + if (value == value_table[i]) { + return i; + } + } + return 0; +} + +template +void FillCompressed(uint8_t* compressed, const size_t total_golden_elements, + const T* goldens, const size_t value_table_stride, + const T* value_table, const size_t channels, + const bool use_alt_axis, const size_t bit_width) { + uint16_t bits = 0; + size_t bits_accumulated = 0; + + if (use_alt_axis) { + size_t golden_element_groups = total_golden_elements / channels; + + for (size_t group = 0; group < golden_element_groups; group++) { + for (size_t channel = 0; channel < channels; channel++) { + size_t value_table_index = FindValueTableIndex( + goldens[(group * channels) + channel], + &value_table[channel * value_table_stride], value_table_stride); + bits |= value_table_index << (16 - bits_accumulated - bit_width); + bits_accumulated += bit_width; + if (bits_accumulated > 8) { + *compressed++ = static_cast(bits >> 8); + bits <<= 8; + bits_accumulated -= 8; + } + } + } + } else { + size_t golden_elements_per_channel = total_golden_elements / channels; + + for (size_t channel = 0; channel < channels; channel++) { + for (size_t i = 0; i < golden_elements_per_channel; i++) { + size_t value_table_index = FindValueTableIndex( + goldens[(channel * golden_elements_per_channel) + i], value_table, + value_table_stride); + bits |= value_table_index << (16 - bits_accumulated - bit_width); + bits_accumulated += bit_width; + if (bits_accumulated > 8) { + *compressed++ = static_cast(bits >> 8); + bits <<= 8; + bits_accumulated -= 8; + } + } + value_table += value_table_stride; + } + } + + if (bits_accumulated > 0) { + *compressed = static_cast(bits >> 8); + } +} + +template +void GenerateData(TestingInfo& info) { + FillValueTable(info.total_value_table_elements, info.value_table); + FillGoldens(info.total_elements, info.goldens, + info.total_value_table_elements, info.value_table, + info.channel_count, info.use_alt_axis); + FillCompressed(info.compressed, info.total_elements, info.goldens, + info.total_value_table_elements / info.channel_count, + info.value_table, info.channel_count, info.use_alt_axis, + info.bit_width); +} + +template +void TestDataSetup(TestingInfo* info, TestingData* data) { + info->output = data->output; + info->goldens = data->goldens; + info->compressed = data->compressed; + info->value_table = data->value_table; +} + +template +TfLiteStatus TestDecompression(TestingInfo* info) { + GenerateData(*info); + + CompressionTensorData ctd = {}; + LookupTableData lut_data = {}; + ctd.scheme = CompressionScheme::kBinQuant; + ctd.data.lut_data = &lut_data; + lut_data.compressed_bit_width = info->bit_width; + lut_data.is_per_channel_quantized = info->channel_count > 1 ? 
true : false; + lut_data.use_alternate_axis = info->use_alt_axis; + lut_data.value_table = info->value_table; + lut_data.value_table_channel_stride = + info->total_value_table_elements / info->channel_count; + + DecompressionState ds(info->compressed, info->total_elements, ctd, + info->channel_count); + + std::fill_n(info->output, info->total_elements, static_cast(~0ULL)); + ds.DecompressToBuffer(info->output); + + bool saved_fail_state = micro_test::did_test_fail; + micro_test::did_test_fail = false; + for (size_t i = 0; i < info->total_elements; i++) { + TF_LITE_MICRO_EXPECT_EQ(info->goldens[i], info->output[i]); + if (micro_test::did_test_fail) { + return kTfLiteError; + } + } + micro_test::did_test_fail = saved_fail_state; + return kTfLiteOk; +} + +template +TfLiteStatus TestValueTable2n(TestingInfo& info) { + if (std::is_same::value) { + info.total_value_table_elements = 2 * info.channel_count; + } else { + info.total_value_table_elements = + (1 << info.bit_width) * info.channel_count; + } + info.total_value_table_elements = + std::min(info.total_value_table_elements, info.total_elements); + info.total_value_table_elements = std::min(info.total_value_table_elements, + TestingData::kValueTableSize); + + MicroPrintf(" Testing value table 2^n: %d", + info.total_value_table_elements); + return TestDecompression(&info); +} + +template +TfLiteStatus TestValueTable2nMinus1(TestingInfo& info) { + if (std::is_same::value) { + info.total_value_table_elements = 1 * info.channel_count; + } else { + info.total_value_table_elements = + ((1 << info.bit_width) - 1) * info.channel_count; + } + info.total_value_table_elements = + std::min(info.total_value_table_elements, info.total_elements); + info.total_value_table_elements = std::min(info.total_value_table_elements, + TestingData::kValueTableSize); + + MicroPrintf(" Testing value table 2^n-1: %d", + info.total_value_table_elements); + return TestDecompression(&info); +} + +template +void TestElementCount(TestingInfo& info) { + static constexpr std::initializer_list elements_per_channel{ + 1, 2, + 3, 4, + 5, 7, + 8, 9, + 15, 16, + 17, 31, + 32, 33, + 63, 64, + 65, 127, + 128, 129, + 255, TestingData::kElementsPerChannel}; + + MicroPrintf(" Testing element count: %d thru %d", + elements_per_channel.begin()[0], elements_per_channel.end()[-1]); + + for (size_t i = 0; i < elements_per_channel.size(); i++) { + info.total_elements = elements_per_channel.begin()[i] * info.channel_count; + + TfLiteStatus s; + s = TestValueTable2n(info); + if (s == kTfLiteError) { + MicroPrintf(" Failed element count: %d", info.total_elements); + } + s = TestValueTable2nMinus1(info); + if (s == kTfLiteError) { + MicroPrintf(" Failed element count: %d", info.total_elements); + } + } +} + +template +void TestSingleChannel(TestingInfo& info) { + info.channel_count = 1; + + MicroPrintf(" Testing single channel"); + TestElementCount(info); +} + +template +void TestMultiChannel(TestingInfo& info) { + info.channel_count = TestingData::kChannels; + + MicroPrintf(" Testing multiple channels: %d", info.channel_count); + TestElementCount(info); +} + +template +void TestBitWidth(TestingInfo& info) { + info.use_alt_axis = false; + + MicroPrintf(" Testing bit width %d", info.bit_width); + TestSingleChannel(info); + TestMultiChannel(info); +} + +template +void TestBitWidthAltAxis(TestingInfo& info) { + info.use_alt_axis = true; + + MicroPrintf(" Testing alt-axis bit width %d", info.bit_width); + TestSingleChannel(info); + TestMultiChannel(info); +} + +template +void TestAllBitWidths() { + 
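// Exercises, for element type T, every bit width from 1 through kBitWidth (7)
// in both the standard and alternate-axis layouts, with single- and
// multi-channel tensors, value tables of 2^n and 2^n - 1 entries (capped for
// bool and by the table capacity), and a range of per-channel element counts.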
TestingInfo info = {}; + TestDataSetup(&info, GetTestingData()); + + for (size_t bw = 1; bw <= TestingData::kBitWidth; bw++) { + info.bit_width = bw; + + TestBitWidth(info); + TestBitWidthAltAxis(info); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestBool) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestFloat) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt8) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt16) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt32) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt64) { tflite::testing::TestAllBitWidths(); } + +TF_LITE_MICRO_TESTS_END + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.cc b/tensorflow/lite/micro/kernels/depthwise_conv.cc index fa55a705606..4d6cb4c4979 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,6 +52,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + switch (input->type) { // Already know in/out types are same. 
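      // With USE_TFLM_COMPRESSION defined, each case below resolves the filter
      // and bias through the four-argument tflite::micro::GetTensorData(
      // micro_context, tensor, comp_td, scratch_index) overload. When comp_td
      // is non-null, this presumably decompresses the tensor into the scratch
      // buffer reserved during Prepare() and returns that buffer; when it is
      // null, it behaves like the plain GetTensorData(tensor) calls in the
      // #else branches.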
case kTfLiteFloat32: { tflite::reference_ops::DepthwiseConv( @@ -59,9 +71,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc index 52804de3315..0813d2b028e 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv( micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); - micro_context->DeallocateTempTfLiteTensor(bias); + if (has_bias) { + micro_context->DeallocateTempTfLiteTensor(bias); + } micro_context->DeallocateTempTfLiteTensor(output); return kTfLiteOk; @@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) { context, node, params, input_width, input_height, filter_width, filter_height, output_width, output_height, input->type, data)); +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. 
+ // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kDepthwiseConvWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc index b50b40ae6d6..adedcaeb04e 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc @@ -1,5 +1,5 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -32,17 +34,99 @@ constexpr int kOutputTensorIndex = 3; constexpr int kMaxFilterChannels = 64; constexpr int kMaxBiasChannels = 64; +#ifdef USE_TFLM_COMPRESSION + +constexpr size_t kDepthwiseConvMaxTensors = 4; +constexpr size_t kDepthwiseConvMaxInputTensors = 3; + +// Common inputs and outputs (quantized multi channel). +// data from TfLite test: +// PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift +static int kInputShapeQ1[] = {4, 1, 2, 3, 2}; +static constexpr float kInputDataQ1[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ1 = std::extent::value; + +constexpr int kNumChannelsQ1 = 4; +static int kFilterShapeQ1[] = {4, 1, 2, 2, 4}; +static constexpr float kFilterDataQ1[] = { + // This is a compact value table. 
Original data is: + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + // depth multiplier = 2 + // 1, 2, 3, 4, y = 0, x = 0 + // 3, 4, 5, 6, y = 0, x = 1 + // 7, 8, 5, 6, y = 1, x = 0 + // 3, 4, 1, 2, y = 1, x = 1 + 1, 3, 7, 8, 2, 4, 1, 3, 5, 2, 4, 6, +}; +constexpr size_t kFilterElementsQ1 = + std::extent::value; + +static int kBiasShapeQ1[] = {1, 4}; +static constexpr float kBiasDataQ1[] = {3, -2, 4, 6}; +constexpr size_t kBiasElementsQ1 = std::extent::value; + +static int kOutputShapeQ1[] = {4, 1, 1, 2, 4}; +static constexpr float kGoldenDataQ1[] = {43, 48, 21, 22, 3, -4, -30, -36}; +constexpr int kOutputElementsQ1 = std::extent::value; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = {0x15, 0x6A, 0x8A, + 0x60}; +constexpr int kBinQuantFilterBitWidthQ1 = 2; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +#endif // USE_TFLM_COMPRESSION + // Creates a DepthwiseConv opeerator, calls it with the provided input tensors // and some defaults parameters, and compares the output with // expected_output_data. // // The tensors parameter contains both the input tensors as well as a // preallocated output tensor into which the output is stored. -template +template TfLiteStatus ValidateDepthwiseConvGoldens( const T* expected_output_data, int output_length, TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size, - TfLiteTensor* tensors) { + TfLiteTensor* tensors +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, tensors[kDepthwiseConvWeightsTensor], + kDepthwiseConvWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kDepthwiseConvBiasTensor], + kDepthwiseConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + int inputs_array_data[] = {3, 0, 1, 2}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 3}; @@ -50,8 +134,12 @@ TfLiteStatus ValidateDepthwiseConvGoldens( const TFLMRegistration registration = Register_DEPTHWISE_CONV_2D(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, - reinterpret_cast(conv_params)); + outputs_array, reinterpret_cast(conv_params) +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); int input_depth = tensors[0].dims->data[3]; int output_depth = tensors[1].dims->data[3]; @@ -183,18 +271,93 @@ void TestDepthwiseConvQuantizedPerChannel( output_scale, output_zero_point, conv_params, filter_packed_type); } +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestDepthwiseConvQuantizedCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* 
expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + TfLiteDepthwiseConvParams* conv_params, const unsigned int tolerance, + const TestCompressionQuantizedInfo* filter_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + // TODO(b/360169306): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ? true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kDepthwiseConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + // Value tables are always in channel order, therefore do not use the + // quantized dimension. + SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data, 0 /* see comment above */); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, + 0 /* quantized dimension for bias tensor */, false /* is_variable */, + typeToTfLiteType()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + constexpr int tensors_size = kDepthwiseConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + filter_tensor, + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateDepthwiseConvGoldens( + expected_output_quantized, output_dims_count, conv_params, tolerance, + tensors_size, tensors, filter_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + +// TODO(ddavis-2015): is this still valid? // Xtensa kernels do not support float activations., and the corresponding tests // are disabled. As a result, helper functions that are only needed for float // kernel tests also need to be ifdef'd out to avoid build errors due to unused // functions. 
#if !defined(XTENSA) -void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - const float* expected_output_data, - int* output_dims_data, - TfLiteDepthwiseConvParams* conv_params, - float* output_data) { +void TestDepthwiseConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + const float* expected_output_data, int* output_dims_data, + TfLiteDepthwiseConvParams* conv_params, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -212,7 +375,12 @@ void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data, }; ValidateDepthwiseConvGoldens(expected_output_data, output_dims_count, - conv_params, 1e-5, tensors_size, tensors); + conv_params, 1e-5, tensors_size, tensors +#ifdef USE_TFLM_COMPRESSION + , + filter_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } #endif // !defined(XTENSA) @@ -253,6 +421,60 @@ TF_LITE_MICRO_TEST(SimpleTest) { bias_values, golden, output_shape, &conv_params, output_data); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestCompressed) { + int input_shape[] = {4, 1, 3, 2, 2}; + const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}; + int filter_shape[] = {4, 1, 2, 2, 4}; + // Filter values: + // {1, 2, 3, 4, -9, 10, -11, 12, 5, 6, 7, 8, 13, -14, 15, -16} + // Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t kBinQuantFilterData[] = {0x01, 0x23, 0xF8, 0xE9, + 0x45, 0x67, 0xAD, 0xBC}; + const float kBinQuantFilterValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8, + 10, 12, 13, 15, -16, -14, -11, -9}; + int bias_shape[] = {4, 1, 1, 1, 4}; + const float bias_values[] = {1, 2, 3, 4}; + // Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t kBinQuantBiasData[] = {0x1B}; + const float golden[] = { + 71, -34, 99, -20, 91, -26, 127, -4, + }; + int output_shape[] = {4, 1, 2, 1, 4}; + const int output_dims_count = std::extent::value; + float output_data[output_dims_count]; + + tflite::testing::TestCompressionInfo filter_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = kBinQuantFilterValueTable; + filter_comp_info.value_table_stride = + std::extent::value; + filter_comp_info.bit_width = 4; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_values; + bias_comp_info.value_table_stride = std::extent::value; + bias_comp_info.bit_width = 2; + + TfLiteDepthwiseConvParams conv_params; + conv_params.activation = kTfLiteActNone; + conv_params.dilation_width_factor = 1; + conv_params.dilation_height_factor = 1; + conv_params.stride_height = 1; + conv_params.stride_width = 1; + + tflite::testing::TestDepthwiseConvFloat( + input_shape, input_values, filter_shape, + reinterpret_cast(kBinQuantFilterData), bias_shape, + reinterpret_cast(kBinQuantBiasData), golden, output_shape, + &conv_params, output_data, &filter_comp_info, &bias_comp_info); +} + +#endif 
// USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestRelu) { int input_shape[] = {4, 1, 3, 2, 2}; const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}; @@ -1068,4 +1290,144 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16InputInt8Filter) { bias_quantized, output_shape, golden, golden_quantized, output_data, output_scale, output_zero_point, &conv_params); } + +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt8Compressed) { + // data from TfLite test: + // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift + const float input_scale = 0.5f; + const float output_scale = 0.5f; + const int input_zero_point = -1; + const int output_zero_point = -1; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ1, 0, 0, 0, 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kFilterElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kFilterDataQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TfLiteDepthwiseConvParams conv_params = {}; + conv_params.activation = kTfLiteActNone; + conv_params.dilation_width_factor = 1; + conv_params.dilation_height_factor = 1; + conv_params.stride_height = 1; + conv_params.stride_width = 1; + + // tolerance of 3 is approx. 2.0f + // TODO(ddavis-2015): why does the tolerance differ from TfLite test??? 
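+  // The filter and bias value tables (filter_quantized, bias_quantized) are
+  // filled in by TestDepthwiseConvQuantizedCompressed via
+  // SymmetricPerChannelQuantize; the kBinQuant buffers above hold only the
+  // packed value-table indices.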
+ TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestDepthwiseConvQuantizedCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &conv_params, 3, &filter_comp_info, &bias_comp_info)); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16Compressed) { + // data from TfLite test: + // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift + const float input_scale = + tflite::testing::SymmetricScaleFromMinMax(-4.0f, 4.0f); + const float output_scale = + tflite::testing::SymmetricScaleFromMinMax(-63.5f, 64.0f); + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ1, 0, 0, 0, 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kFilterElementsQ1]; + int64_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kFilterDataQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TfLiteDepthwiseConvParams conv_params = {}; + conv_params.activation = kTfLiteActNone; + conv_params.dilation_width_factor = 1; + conv_params.dilation_height_factor = 1; + conv_params.stride_height = 1; + conv_params.stride_width = 1; + + // tolerance of 512 is approx. 
1.0f + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestDepthwiseConvQuantizedCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &conv_params, 512, &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index a1d22cfbd8e..b044a4bbab2 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -60,7 +60,7 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) { (input->type == kTfLiteInt8 && (filter->type != kTfLiteInt8 && filter->type != kTfLiteInt4)) || (input->type == kTfLiteInt16 && filter->type != kTfLiteInt8)) { - MicroPrintf("Input type: %s with filter type : %s not supported.", + MicroPrintf("Input type: %s with filter type: %s not supported.", TfLiteTypeGetName(input->type), TfLiteTypeGetName(filter->type)); return kTfLiteError; @@ -79,6 +79,23 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) { context, params->activation, input->type, input, filter, bias, output, data)); +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); if (bias != nullptr) { @@ -102,8 +119,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); - TFLITE_DCHECK(node->user_data != nullptr); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + + TFLITE_DCHECK(node->user_data != nullptr); const auto& data = *(static_cast(node->user_data)); @@ -115,9 +143,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), 
tflite::micro::GetTensorData(output)); break; @@ -152,9 +189,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)) : tflite::reference_integer_ops::FullyConnected( @@ -162,9 +209,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -186,9 +243,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h index 670488ab618..64213f0fb63 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.h +++ b/tensorflow/lite/micro/kernels/fully_connected.h @@ -50,6 +50,14 @@ struct OpDataFullyConnected { int32_t* per_channel_output_shift; bool is_per_channel; #endif + +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int weights_scratch_index; + int bias_scratch_index; + +#endif // USE_TFLM_COMPRESSION }; extern const int kFullyConnectedInputTensor; diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 2ad132055b8..2cf3427c874 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
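The compressed buffers added in the next hunk (kBinQuantWeightData, kBinQuantBiasData) hold bit-packed indices into a value table rather than the weight values themselves (4-bit indices for the weights, 2-bit for the bias). A minimal standalone sketch of the decode, using local copies of those constants and assuming the most-significant-bits-first packing implied by them (the names and helper below are illustrative only, not a TFLM API):

#include <cstdint>
#include <cstdio>

// Local copies of kBinQuantWeightValueTable and the first five bytes of
// kBinQuantWeightData from the test data in the hunk below (assumed layout).
constexpr float kValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
constexpr uint8_t kPacked[] = {0x01, 0x23, 0x45, 0x67, 0x89};

int main() {
  // Each byte packs two 4-bit indices, most-significant nibble first.
  for (int i = 0; i < 10; ++i) {
    const uint8_t byte = kPacked[i / 2];
    const int index = (i % 2 == 0) ? (byte >> 4) : (byte & 0x0F);
    std::printf("%g ", kValueTable[index]);  // prints: 1 2 3 4 5 6 7 8 9 10
  }
  std::printf("\n");
  return 0;
}

Decoding all thirty packed indices this way reproduces simple_weights_data, which matches how the decompression path reconstructs tensor data from a (possibly per-channel) value table at run time instead of storing the values verbatim.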
@@ -42,6 +42,29 @@ const float simple_weights_data[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }; +int simple_bias_dims[] = {1, 3}; +const float simple_bias_data[] = {1, 2, 3}; + +#ifdef USE_TFLM_COMPRESSION + +// compressed filter data for kBinQuant scheme +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantWeightData[] = { + 0x01, 0x23, 0x45, 0x67, 0x89, 0x01, 0x23, 0x45, + 0x67, 0x89, 0x01, 0x23, 0x45, 0x67, 0x89}; +constexpr float kBinQuantWeightValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; +constexpr size_t kBinQuantWeightValueTableElements = + std::extent::value; +constexpr int kBinQuantWeightBitWidth = 4; +// compressed bias data for kBinQuant scheme +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18}; +constexpr int kBinQuantBiasBitWidth = 2; +constexpr size_t simple_bias_size = + std::extent::value; + +#endif // USE_TFLM_COMPRESSION + // TODO(b/258710417): INT4 isn't currently supported on Hexagon. #if !defined(HEXAGON) const float simple_int4_weights_data[] = { @@ -53,8 +76,6 @@ const float simple_golden_null_bias_int4_weights[] = { -28, -28, -28, 0, 0, 0, }; #endif -int simple_bias_dims[] = {1, 3}; -const float simple_bias_data[] = {1, 2, 3}; const float simple_golden[] = { 24, 25, 26, 58, 59, 60, }; @@ -241,11 +262,19 @@ const float representative_64x16_golden[] = { const int representative_64x16_output_size = 16; int representative_64x16_output_dims[] = {2, 1, 16}; -template +constexpr int kMaxTensors = 4; + +template TfLiteStatus ValidateFullyConnectedGoldens( TfLiteTensor* tensors, const int tensors_size, bool null_bias, const TfLiteFusedActivation activation, const float tolerance, - const int output_len, const T* golden, T* output_data) { + const int output_len, const T* golden, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* weight_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteFullyConnectedParams builtin_data = { activation, kTfLiteFullyConnectedWeightsFormatDefault, false, false, kTfLiteNoType}; @@ -272,10 +301,37 @@ TfLiteStatus ValidateFullyConnectedGoldens( TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + + if (weight_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*weight_comp_info, tensors[kFullyConnectedWeightsTensor], + kFullyConnectedWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kFullyConnectedBiasTensor], + kFullyConnectedBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + const TFLMRegistration registration = Register_FULLY_CONNECTED(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, outputs_array, - reinterpret_cast(&builtin_data)); + reinterpret_cast(&builtin_data), nullptr +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); TfLiteStatus status = runner.InitAndPrepare(); if (status != kTfLiteOk) { @@ -293,11 +349,18 @@ TfLiteStatus ValidateFullyConnectedGoldens( return kTfLiteOk; } +template TfLiteStatus TestFullyConnectedFloat( int* input_dims_data, const float* input_data, int* 
weights_dims_data, const float* weights_data, int* bias_dims_data, const float* bias_data, const float* golden, int* output_dims_data, - TfLiteFusedActivation activation, float* output_data) { + TfLiteFusedActivation activation, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* weight_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -305,16 +368,15 @@ TfLiteStatus TestFullyConnectedFloat( const int output_dims_count = ElementCount(*output_dims); bool null_bias = bias_data == nullptr ? true : false; - constexpr int array_size = 4; // Avoid variable length array warning. - const int inputs_size = bias_data == nullptr ? 2 : 3; + const int inputs_size = null_bias ? 2 : 3; constexpr int outputs_size = 1; const int tensors_size = inputs_size + outputs_size; - TfLiteTensor tensors[array_size]; + TfLiteTensor tensors[kMaxTensors]; tensors[0] = CreateTensor(input_data, input_dims); tensors[1] = CreateTensor(weights_data, weights_dims); - if (bias_data == nullptr) { + if (null_bias) { tensors[2] = CreateTensor(output_data, output_dims); } else { tensors[2] = CreateTensor(bias_data, bias_dims); @@ -323,7 +385,12 @@ TfLiteStatus TestFullyConnectedFloat( return ValidateFullyConnectedGoldens(tensors, tensors_size, null_bias, activation, 1e-4f, output_dims_count, - golden, output_data); + golden, output_data +#ifdef USE_TFLM_COMPRESSION + , + weight_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } template @@ -345,7 +412,7 @@ TfLiteStatus TestFullyConnectedQuantized( bool null_bias = bias_data == nullptr ? true : false; constexpr int array_size = 4; // Avoid variable length array warning. - const int inputs_size = bias_data == nullptr ? 2 : 3; + const int inputs_size = null_bias ? 
2 : 3; constexpr int outputs_size = 1; const int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[array_size]; @@ -355,7 +422,7 @@ TfLiteStatus TestFullyConnectedQuantized( tensors[1] = CreateQuantizedTensor( weights_data, weights_quantized, weights_dims, weights_scale, weights_zero_point, false, weights_packed_type); - if (bias_data == nullptr) { + if (null_bias) { tensors[2] = CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point); } else { @@ -373,6 +440,71 @@ TfLiteStatus TestFullyConnectedQuantized( golden_quantized, output_data); } +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestFullyConnectedQuantizedCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteFusedActivation activation, + const TestCompressionQuantizedInfo* weight_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* weight_dims = IntArrayFromInts(weight_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* weight_scales = + FloatArrayFromFloats(weight_comp_info->scales); + TfLiteIntArray* weight_zero_points = + IntArrayFromInts(weight_comp_info->zero_points); + + TfLiteTensor weight_tensor = CreateQuantizedTensor( + weight_comp_info->compressed, weight_dims, weight_scales->data[0], + weight_zero_points->data[0], false, kTfLiteInt8); + SymmetricQuantize(weight_comp_info->data, weight_comp_info->value_table, + weight_comp_info->value_table_stride, + weight_scales->data[0]); + + TfLiteTensor bias_tensor = {}; + if (bias_comp_info != nullptr) { + bias_tensor = CreateQuantizedTensor(bias_comp_info->compressed, bias_dims, + input_scale * weight_scales->data[0], 0, + false, typeToTfLiteType()); + SymmetricQuantize(bias_comp_info->data, bias_comp_info->value_table, + bias_comp_info->value_table_stride, + bias_tensor.params.scale); + } + + TfLiteTensor output_tensor = CreateQuantizedTensor( + output_quantized, output_dims, output_scale, output_zero_point); + + const int tensors_size = + (bias_comp_info == nullptr) ? 
kMaxTensors - 1 : kMaxTensors; + TfLiteTensor tensors[kMaxTensors] = {}; + tensors[0] = CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point); + tensors[1] = weight_tensor; + if (bias_comp_info == nullptr) { + tensors[2] = output_tensor; + } else { + tensors[2] = bias_tensor; + tensors[3] = output_tensor; + } + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateFullyConnectedGoldens( + tensors, tensors_size, bias_comp_info == nullptr, activation, 0.0f, + output_dims_count, expected_output_quantized, output_quantized, + weight_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace } // namespace testing } // namespace tflite @@ -393,6 +525,40 @@ TF_LITE_MICRO_TEST(SimpleTest) { kTfLiteOk); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestCompressed) { + float output_data[tflite::testing::simple_output_size]; + + tflite::testing::TestCompressionInfo weight_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + weight_comp_info.value_table = tflite::testing::kBinQuantWeightValueTable; + weight_comp_info.value_table_stride = + tflite::testing::kBinQuantWeightValueTableElements; + weight_comp_info.bit_width = tflite::testing::kBinQuantWeightBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::simple_bias_data; + bias_comp_info.value_table_stride = tflite::testing::simple_bias_size; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedFloat( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, + tflite::testing::simple_weights_dims, + reinterpret_cast(tflite::testing::kBinQuantWeightData), + tflite::testing::simple_bias_dims, + reinterpret_cast(tflite::testing::kBinQuantBiasData), + tflite::testing::simple_golden, tflite::testing::simple_output_dims, + kTfLiteActNone, output_data, &weight_comp_info, &bias_comp_info), + kTfLiteOk); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestNullBias) { float output_data[tflite::testing::simple_output_size]; TF_LITE_MICRO_EXPECT_EQ( @@ -434,6 +600,58 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) { kTfLiteOk); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Compressed) { + const float input_scale = 1.0f; + const int input_zero_point = -1; + constexpr float weights_scale[] = {1, 1.0f}; + constexpr int weights_zero_point[] = {1, 0}; + const float output_scale = 0.5f; + const int output_zero_point = -1; + + int8_t input_quantized[tflite::testing::simple_input_size]; + int8_t weights_quantized[tflite::testing::kBinQuantWeightValueTableElements]; + int32_t bias_quantized[tflite::testing::simple_output_size]; + int8_t golden_quantized[tflite::testing::simple_output_size]; + int8_t output_data[tflite::testing::simple_output_size]; + + tflite::testing::TestCompressionQuantizedInfo weight_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + weight_comp_info.value_table = weights_quantized; + weight_comp_info.value_table_stride = + tflite::testing::kBinQuantWeightValueTableElements; + weight_comp_info.bit_width = 
tflite::testing::kBinQuantWeightBitWidth; + weight_comp_info.compressed = tflite::testing::kBinQuantWeightData; + weight_comp_info.data = tflite::testing::kBinQuantWeightValueTable; + weight_comp_info.dims_data = tflite::testing::simple_weights_dims; + weight_comp_info.scales = weights_scale; + weight_comp_info.zero_points = weights_zero_point; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = tflite::testing::simple_bias_size; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasData; + bias_comp_info.data = tflite::testing::simple_bias_data; + bias_comp_info.dims_data = tflite::testing::simple_bias_dims; + // bias scales and bias zero_points are not used + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedQuantizedCompressed( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, input_quantized, input_scale, + input_zero_point, tflite::testing::simple_output_dims, + tflite::testing::simple_golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone, &weight_comp_info, + &bias_comp_info), + kTfLiteOk); +} + +#endif // USE_TFLM_COMPRESSION + #if !defined(HEXAGON) TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { const float input_scale = 128.0 / 65536; @@ -443,7 +661,6 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { const float output_scale = 128.0 / 65536; const int output_zero_point = 0; - const float simple_golden[] = {24, 25, 26, 58, 59, 60}; int16_t input_quantized[tflite::testing::simple_input_size]; int8_t weights_quantized[tflite::testing::simple_weights_size]; int64_t bias_quantized[tflite::testing::simple_output_size]; @@ -457,12 +674,66 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { input_zero_point, tflite::testing::simple_weights_dims, tflite::testing::simple_weights_data, weights_quantized, weights_scale, weights_zero_point, tflite::testing::simple_bias_dims, - tflite::testing::simple_bias_data, bias_quantized, simple_golden, - golden_quantized, tflite::testing::simple_output_dims, output_scale, - output_zero_point, kTfLiteActNone, output_data), + tflite::testing::simple_bias_data, bias_quantized, + tflite::testing::simple_golden, golden_quantized, + tflite::testing::simple_output_dims, output_scale, output_zero_point, + kTfLiteActNone, output_data), kTfLiteOk); } -#endif + +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16Compressed) { + const float input_scale = 128.0 / 65536; + const int input_zero_point = 0; + constexpr float weights_scale[] = {1, 1.0f}; + constexpr int weights_zero_point[] = {1, 0}; + const float output_scale = 128.0 / 65536; + const int output_zero_point = 0; + + int16_t input_quantized[tflite::testing::simple_input_size]; + int8_t weights_quantized[tflite::testing::kBinQuantWeightValueTableElements]; + int64_t bias_quantized[tflite::testing::simple_output_size]; + int16_t golden_quantized[tflite::testing::simple_output_size]; + int16_t output_data[tflite::testing::simple_output_size]; + + tflite::testing::TestCompressionQuantizedInfo weight_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + weight_comp_info.value_table = weights_quantized; + weight_comp_info.value_table_stride = + tflite::testing::kBinQuantWeightValueTableElements; + weight_comp_info.bit_width = 
tflite::testing::kBinQuantWeightBitWidth; + weight_comp_info.compressed = tflite::testing::kBinQuantWeightData; + weight_comp_info.data = tflite::testing::kBinQuantWeightValueTable; + weight_comp_info.dims_data = tflite::testing::simple_weights_dims; + weight_comp_info.scales = weights_scale; + weight_comp_info.zero_points = weights_zero_point; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = tflite::testing::simple_bias_size; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasData; + bias_comp_info.data = tflite::testing::simple_bias_data; + bias_comp_info.dims_data = tflite::testing::simple_bias_dims; + // bias scales and bias zero_points are not used + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedQuantizedCompressed( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, input_quantized, input_scale, + input_zero_point, tflite::testing::simple_output_dims, + tflite::testing::simple_golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone, &weight_comp_info, + &bias_comp_info), + kTfLiteOk); +} + +#endif // USE_TFLM_COMPRESSION + +#endif // !defined(HEXAGON) TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) { const float input_scale = 1.0f; diff --git a/tensorflow/lite/micro/kernels/kernel_runner.cc b/tensorflow/lite/micro/kernels/kernel_runner.cc index 602778d7c50..79824efe5de 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.cc +++ b/tensorflow/lite/micro/kernels/kernel_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h" #include "tensorflow/lite/micro/micro_arena_constants.h" #include "tensorflow/lite/micro/micro_log.h" -#include "tensorflow/lite/micro/test_helpers.h" namespace tflite { namespace micro { @@ -38,12 +37,22 @@ KernelRunner::KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* inputs, TfLiteIntArray* outputs, const void* builtin_data, - TfLiteIntArray* intermediates) + TfLiteIntArray* intermediates +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* compressed_tensors +#endif // USE_TFLM_COMPRESSION + ) : registration_(registration), allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_, kKernelRunnerBufferSize_)), mock_micro_graph_(allocator_), - fake_micro_context_(tensors, allocator_, &mock_micro_graph_) { + fake_micro_context_(tensors, allocator_, &mock_micro_graph_ +#ifdef USE_TFLM_COMPRESSION + , + compressed_tensors +#endif // USE_TFLM_COMPRESSION + ) { // Prepare TfLiteContext: context_.impl_ = static_cast(&fake_micro_context_); context_.ReportError = MicroContextReportOpError; diff --git a/tensorflow/lite/micro/kernels/kernel_runner.h b/tensorflow/lite/micro/kernels/kernel_runner.h index 25b97c11302..8dbd7f8b015 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.h +++ b/tensorflow/lite/micro/kernels/kernel_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,12 @@ class KernelRunner { KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* inputs, TfLiteIntArray* outputs, const void* builtin_data, - TfLiteIntArray* intermediates = nullptr); + TfLiteIntArray* intermediates = nullptr +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* compressed_tensors = nullptr +#endif // USE_TFLM_COMPRESSION + ); // Calls init and prepare on the kernel (i.e. TFLMRegistration) struct. // Any exceptions will be DebugLog'd and returned as a status code. diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h index f14c927133d..5f8eb2e4ec3 100644 --- a/tensorflow/lite/micro/kernels/kernel_util.h +++ b/tensorflow/lite/micro/kernels/kernel_util.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,13 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/micro/micro_context.h" +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/micro_arena_constants.h" +#include "tensorflow/lite/micro/micro_utils.h" + +#endif // USE_TFLM_COMPRESSION + namespace tflite { namespace micro { @@ -91,6 +98,38 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) { : reinterpret_cast(tensor->data.raw); } +#ifdef USE_TFLM_COMPRESSION + +// Overloads existing GetTensorData. If not compressed, this will return +// tensor->data. +template +const T* GetTensorData(MicroContext* micro_context, + const TfLiteEvalTensor* tensor, + const CompressionTensorData* compression_data, + int scratch_buffer_handle) { + if (tensor == nullptr) { + return nullptr; + } + if (compression_data == nullptr) { + return reinterpret_cast(tensor->data.data); + } + + void* scratch_buffer = nullptr; + if (scratch_buffer_handle != -1) { + scratch_buffer = micro_context->GetScratchBuffer(scratch_buffer_handle); + } else { + size_t bytes_to_allocate = EvalTensorBytes(tensor); + scratch_buffer = micro_context->AllocateDecompressionMemory( + bytes_to_allocate, MicroArenaBufferAlignment()); + } + TFLITE_DCHECK(scratch_buffer != nullptr); + void* uncompressed_data = micro_context->DecompressTensorToBuffer( + *tensor, *compression_data, scratch_buffer); + return reinterpret_cast(uncompressed_data); +} + +#endif // USE_TFLM_COMPRESSION + // Returns the shape of a TfLiteEvalTensor struct. const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor); diff --git a/tensorflow/lite/micro/kernels/transpose_conv.cc b/tensorflow/lite/micro/kernels/transpose_conv.cc index ea0efae0607..7d65dc3de7c 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv.cc +++ b/tensorflow/lite/micro/kernels/transpose_conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,30 +27,26 @@ limitations under the License. 
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/transpose_conv.h" #include "tensorflow/lite/micro/micro_log.h" namespace tflite { namespace { -// For the TfLite transpose_conv implementation, input tensor 0 corresponds to -// the OutputShapeTensor. However, since TFLM does not support dynamic tensors, -// the TFLM implementation ignores input tensor 0 and the only inputs we care -// about are kFilterTensor, kInputTensor and kBiasTensor. -constexpr int kFilterTensor = 1; -constexpr int kInputTensor = 2; -constexpr int kBiasTensor = 3; -constexpr int kOutputTensor = 0; - -// Conv is quantized along dimension 0: -// https://www.tensorflow.org/lite/performance/quantization_spec -constexpr int kConvQuantizedDimension = 0; - struct OpData { ConvParams params; // A scratch buffer is required for quantized implementations. int scratch_buffer_index; +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int filter_scratch_index; + int bias_scratch_index; + +#endif // USE_TFLM_COMPRESSION + // Index to the converted 64-bit bias buffer from 16-bit bias. This is // required to handle 16x8 transpose convolutions where a 16-bit bias is // provided, whereas the kernel expects 64-bit biases. @@ -102,17 +98,17 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, MicroContext* micro_context = GetMicroContext(context); TfLiteTensor* input = - micro_context->AllocateTempInputTensor(node, kInputTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor); TF_LITE_ENSURE(context, input != nullptr); - TfLiteTensor* filter = - micro_context->AllocateTempInputTensor(node, kFilterTensor); + TfLiteTensor* filter = micro_context->AllocateTempInputTensor( + node, kTransposeConvFilterTensor); TF_LITE_ENSURE(context, filter != nullptr); TfLiteTensor* bias = - micro_context->AllocateTempInputTensor(node, kBiasTensor); - TfLiteTensor* output = - micro_context->AllocateTempOutputTensor(node, kOutputTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvBiasTensor); + TfLiteTensor* output = micro_context->AllocateTempOutputTensor( + node, kTransposeConvOutputTensor); TF_LITE_ENSURE(context, output != nullptr); - int output_channels = filter->dims->data[kConvQuantizedDimension]; + int output_channels = filter->dims->data[kTransposeConvQuantizedDimension]; TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( context, input, filter, bias, output, kTfLiteActNone, @@ -164,13 +160,13 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { MicroContext* micro_context = GetMicroContext(context); TfLiteTensor* output = - micro_context->AllocateTempOutputTensor(node, kOutputTensor); + micro_context->AllocateTempOutputTensor(node, kTransposeConvOutputTensor); TF_LITE_ENSURE(context, output != nullptr); TfLiteTensor* input = - micro_context->AllocateTempInputTensor(node, kInputTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor); TF_LITE_ENSURE(context, input != nullptr); TfLiteTensor* filter = - micro_context->AllocateTempInputTensor(node, kFilterTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvFilterTensor); TF_LITE_ENSURE(context, filter != nullptr); TF_LITE_ENSURE_MSG( @@ -186,7 +182,7 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { const int filter_height = SizeOfDimension(filter, 1); // 
Dynamically allocate per-channel quantization parameters. - const int num_channels = filter->dims->data[kConvQuantizedDimension]; + const int num_channels = filter->dims->data[kTransposeConvQuantizedDimension]; data->per_channel_output_multiplier = static_cast(context->AllocatePersistentBuffer( context, num_channels * sizeof(int32_t))); @@ -223,10 +219,10 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, affine_quantization->scale); TF_LITE_ENSURE(context, affine_quantization->zero_point); - TF_LITE_ENSURE(context, - affine_quantization->scale->size == 1 || - affine_quantization->scale->size == - filter->dims->data[kConvQuantizedDimension]); + TF_LITE_ENSURE( + context, affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kTransposeConvQuantizedDimension]); TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, affine_quantization->zero_point->size); } @@ -244,6 +240,18 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { data->params.stride_width = params->stride_width; data->params.stride_height = params->stride_height; +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + data->filter_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kTransposeConvFilterTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kTransposeConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); @@ -252,15 +260,26 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); + tflite::micro::GetEvalInput(context, node, kTransposeConvInputTensor); const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kFilterTensor); + tflite::micro::GetEvalInput(context, node, kTransposeConvFilterTensor); const TfLiteEvalTensor* bias = (NumInputs(node) == 4) - ? tflite::micro::GetEvalInput(context, node, kBiasTensor) + ? 
tflite::micro::GetEvalInput(context, node, kTransposeConvBiasTensor) : nullptr; TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); + tflite::micro::GetEvalOutput(context, node, kTransposeConvOutputTensor); + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, kTransposeConvFilterTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kTransposeConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); @@ -280,9 +299,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { op_params, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); @@ -296,9 +323,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); @@ -311,16 +346,29 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { auto* bias_converted_buffer = static_cast(context->GetScratchBuffer( context, data.bias_converted_buffer_index)); + const int16_t* const bias_int16_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION + static_cast(bias->data.data); +#endif // USE_TFLM_COMPRESSION for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize(); i++) { - bias_converted_buffer[i] = bias->data.i16[i]; + bias_converted_buffer[i] = bias_int16_data[i]; } reference_integer_ops::TransposeConv( data.params, data.per_channel_output_multiplier, data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(bias), bias_converted_buffer, tflite::micro::GetTensorShape(output), 
tflite::micro::GetTensorData(output), @@ -331,9 +379,18 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); diff --git a/tensorflow/lite/micro/kernels/transpose_conv.h b/tensorflow/lite/micro/kernels/transpose_conv.h index 3a99ccbf847..ec0416e067f 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv.h +++ b/tensorflow/lite/micro/kernels/transpose_conv.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,19 @@ limitations under the License. namespace tflite { +// For the TfLite transpose_conv implementation, input tensor 0 corresponds to +// the OutputShapeTensor. However, since TFLM does not support dynamic tensors, +// the TFLM implementation ignores input tensor 0 and the only inputs we care +// about are kFilterTensor, kInputTensor and kBiasTensor. +constexpr int kTransposeConvFilterTensor = 1; +constexpr int kTransposeConvInputTensor = 2; +constexpr int kTransposeConvBiasTensor = 3; +constexpr int kTransposeConvOutputTensor = 0; + +// Conv is quantized along dimension 0: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kTransposeConvQuantizedDimension = 0; + // This is the most generic TFLMRegistration. The actual supported types // may still be target dependent. The only requirement is that every // implementation (reference or optimized) must define this function. diff --git a/tensorflow/lite/micro/kernels/transpose_conv_test.cc b/tensorflow/lite/micro/kernels/transpose_conv_test.cc index 49d2c90f439..2d5f3a0ba4e 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/micro/kernels/transpose_conv_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/lite/micro/kernels/transpose_conv.h" + +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/micro/kernels/conv_test.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" @@ -47,20 +50,127 @@ static const float kGoldenData[kOutputElements] = { 184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, 2715, 3057, 2442, 1968, 3352, 3652, 2760}; +#ifdef USE_TFLM_COMPRESSION + +constexpr size_t kTransposeConvMaxTensors = 5; +constexpr size_t kTransposeConvMaxInputTensors = 4; + +// compressed filter data for kBinQuant scheme, matches kFilterData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterData[] = { + 0x00, 0x44, 0x32, 0x14, 0xC7, 0x42, 0x54, 0xB6, 0x35, 0xCF, 0x84, 0x40}; +constexpr int kBinQuantFilterBitWidth = 5; +// compressed bias data for kBinQuant scheme, matches kBiasData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x00}; +constexpr int kBinQuantBiasBitWidth = 1; + +// Common inputs and outputs (quantized single channel). +// data from TfLite test: SimpleBiasTestQuantizedPerChannelSingleChannel +static int kInputShapeQ1[] = {4, 1, 4, 4, 1}; +static constexpr float kInputDataQ1[] = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; +constexpr size_t kInputElementsQ1 = std::extent::value; + +constexpr int kNumChannelsQ1 = 1; +static int kFilterShapeQ1[] = {4, 1, 3, 3, 1}; +static constexpr float kFilterDataQ1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; +constexpr size_t kFilterElementsQ1 = + std::extent::value; + +static int kBiasShapeQ1[] = {1, 1}; +static constexpr float kBiasDataQ1[] = {1}; +constexpr size_t kBiasElementsQ1 = std::extent::value; + +static int kOutputShapeQ1[] = {4, 1, 4, 4, 1}; +static constexpr float kGoldenDataQ1[] = { + 30, 62, 84, 76, 100, 194, 238, 200, 208, 372, 418, 330, 264, 446, 486, 366}; +constexpr int kOutputElementsQ1 = std::extent::value; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = {0x01, 0x23, 0x45, 0x67, + 0x80}; +constexpr int kBinQuantFilterBitWidthQ1 = 4; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +// Common inputs and outputs (quantized multi channel). 
+// data from TfLite test: SimpleBiasTestQuantizedPerChannel16x8Bias64 +static int kInputShapeQ2[] = {4, 1, 2, 3, 2}; +static constexpr float kInputDataQ2[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ2 = std::extent::value; + +constexpr int kNumChannelsQ2 = 2; +static int kFilterShapeQ2[] = {4, 2, 2, 2, 2}; +// Original filter data: +// static constexpr float kFilterDataQ2[] = { +// // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] +// 1, 2, // out channel = 0, y = 0, x = 0 +// 3, 4, // out channel = 0, y = 0, x = 1 +// 3, 4, // out channel = 0, y = 1, x = 0 +// 5, 6, // out channel = 0, y = 1, x = 1 +// 7, 8, // out channel = 1, y = 0, x = 0 +// 5, 6, // out channel = 1, y = 0, x = 1 +// 3, 4, // out channel = 1, y = 1, x = 0 +// 1, 2, // out channel = 1, y = 1, x = 1 +// }; + +static int kBiasShapeQ2[] = {1, 2}; +static constexpr float kBiasDataQ2[] = {3, -2}; +constexpr size_t kBiasElementsQ2 = std::extent::value; + +static int kOutputShapeQ2[] = {4, 1, 2, 3, 2}; +static constexpr float kGoldenDataQ2[] = {10, 35, 19, 24, -6, -41, + 30, 64, 51, 40, -29, -64}; +constexpr int kOutputElementsQ2 = std::extent::value; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ2 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ2[] = {0x05, 0x34, 0xE5, + 0xDE, 0x54, 0xC1}; +constexpr float kBinQuantFilterValueTableQ2[] = {1, 2, 3, 4, 5, 6, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8}; +constexpr size_t kBinQuantFilterValueTableElementsQ2 = + std::extent::value; +constexpr int kBinQuantFilterBitWidthQ2 = 3; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ2 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ2[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ2 = 1; + +#endif // USE_TFLM_COMPRESSION + // Transpose conv uses TfLiteConvParams. 
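[Editorial note] The kBinQuant test vectors above can be decoded by hand, which makes the compression format easy to check: the compressed blob is a bitstream of value-table indices packed MSB-first at bit_width bits per element, and each output channel reads from its own value_table_stride-sized slice of the value table. Decoding kBinQuantFilterDataQ2 (3 bits per index) gives indices {0,1,2,3,2,3,4,5} for channel 0 and {6,7,4,5,2,3,0,1} for channel 1, which reproduce the commented-out kFilterDataQ2 through the two per-channel tables; the single-value bias tensors need only 1-bit indices, hence the {0x00} blobs. A minimal host-side decoder illustrating the layout (an editorial sketch against the test data above; DecodeBinQuant is a hypothetical helper, not a TFLM API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Decodes `count` LUT indices of `bit_width` bits each, packed MSB-first,
// and resolves them against a per-channel value table. Assumes the default
// channel-major layout used by these tests (quantized dimension 0).
template <typename T>
std::vector<T> DecodeBinQuant(const uint8_t* compressed, size_t count,
                              int bit_width, const T* value_table,
                              size_t value_table_stride,
                              size_t elements_per_channel) {
  std::vector<T> out(count);
  size_t bit_pos = 0;
  for (size_t i = 0; i < count; ++i) {
    uint32_t index = 0;
    for (int b = 0; b < bit_width; ++b, ++bit_pos) {
      index = (index << 1) |
              ((compressed[bit_pos >> 3] >> (7 - (bit_pos & 7))) & 1u);
    }
    const size_t channel = i / elements_per_channel;
    out[i] = value_table[channel * value_table_stride + index];
  }
  return out;
}

// DecodeBinQuant(kBinQuantFilterDataQ2, 16, 3, kBinQuantFilterValueTableQ2,
//                /*stride=*/8, /*elements_per_channel=*/8)
//   yields {1,2,3,4,3,4,5,6, 7,8,5,6,3,4,1,2}, matching kFilterDataQ2.
// DecodeBinQuant(kBinQuantFilterDataQ1, 9, 4, kFilterDataQ1, 9, 9)
//   yields {1,...,9}; the final four bits of that blob are padding.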
-static TfLiteConvParams common_conv_params = {kTfLitePaddingSame, // padding - 1, // stride_width - 1, // stride_height - kTfLiteActNone, - 1, - 1, - kTfLiteNoType}; +static const TfLiteConvParams common_conv_params = { + kTfLitePaddingSame, // padding + 1, // stride_width + 1, // stride_height + kTfLiteActNone, + 1, + 1, + kTfLiteNoType}; template -TfLiteStatus InvokeTransposeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, - TfLiteConvParams* conv_params, - T* output_data) { +TfLiteStatus InvokeTransposeConv( + TfLiteTensor* tensors, int tensors_size, int output_length, + const TfLiteConvParams* conv_params, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* comp_list_p = nullptr +#endif // USE_TFLM_COMPRESSION +) { + // TODO(b/358151309): support optional bias tensor int inputs_array_data[] = {4, 0, 1, 2, 3}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 4}; @@ -68,7 +178,12 @@ TfLiteStatus InvokeTransposeConv(TfLiteTensor* tensors, int tensors_size, const TFLMRegistration registration = tflite::Register_TRANSPOSE_CONV(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, conv_params); + outputs_array, conv_params +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); const char* init_data = reinterpret_cast(conv_params); TfLiteStatus status = runner.InitAndPrepare(init_data); @@ -78,15 +193,45 @@ TfLiteStatus InvokeTransposeConv(TfLiteTensor* tensors, int tensors_size, return runner.Invoke(); } -template -TfLiteStatus ValidateTransposeConvGoldens(TfLiteTensor* tensors, - int tensors_size, - const T* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - T* output_data, float tolerance) { +template +TfLiteStatus ValidateTransposeConvGoldens( + TfLiteTensor* tensors, int tensors_size, const T* expected_output_data, + int output_length, const TfLiteConvParams* conv_params, T* output_data, + float tolerance = 1e-5f +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, tensors[kTransposeConvFilterTensor], + kTransposeConvFilterTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kTransposeConvBiasTensor], + kTransposeConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + TfLiteStatus status = InvokeTransposeConv( - tensors, tensors_size, output_length, conv_params, output_data); + tensors, tensors_size, output_length, conv_params, output_data +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); if (status != kTfLiteOk) { return status; } @@ -97,11 +242,18 @@ TfLiteStatus ValidateTransposeConvGoldens(TfLiteTensor* tensors, return kTfLiteOk; } +template TfLiteStatus TestTransposeConvFloat( int* input_dims_data, const float* input_data, int* filter_dims_data, const float* filter_data, int* bias_dims_data, const float* bias_data, int* output_dims_data, const float* expected_output_data, - TfLiteConvParams* conv_params, float* output_data) { + const 
TfLiteConvParams* conv_params, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -125,7 +277,12 @@ TfLiteStatus TestTransposeConvFloat( return ValidateTransposeConvGoldens(tensors, tensors_size, expected_output_data, output_dims_count, - conv_params, output_data, 0.001f); + conv_params, output_data +#ifdef USE_TFLM_COMPRESSION + , + 1e-5, filter_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } TfLiteStatus TestTransposeConvQuantized( @@ -135,8 +292,8 @@ TfLiteStatus TestTransposeConvQuantized( int* bias_dims_data, const float* bias_data, int32_t* bias_quantized, float* bias_scales, int* bias_zero_points, int* output_dims_data, const float* expected_output_data, int8_t* expected_output_quantized, - float output_scale, int output_zero_point, TfLiteConvParams* conv_params, - int8_t* output_data) { + float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, int8_t* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -181,8 +338,8 @@ TfLiteStatus TestTransposeConvQuantized( int* bias_dims_data, const float* bias_data, T* bias_quantized, float* bias_scales, int* bias_zero_points, int* output_dims_data, const float* expected_output_data, int16_t* expected_output_quantized, - float output_scale, int output_zero_point, TfLiteConvParams* conv_params, - int16_t* output_data) { + float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, int16_t* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -221,6 +378,80 @@ TfLiteStatus TestTransposeConvQuantized( conv_params, output_data, 4.0f); } +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestTransposeConvQuantizedCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, const unsigned int tolerance, + const TestCompressionQuantizedInfo* filter_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + // TODO(b/358151309): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ? 
true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kTransposeConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, + kTransposeConvQuantizedDimension, false /* is_variable */, + typeToTfLiteType()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + int output_shape_dims_data[] = {1, 0}; + int32_t* output_shape = nullptr; + TfLiteIntArray* output_shape_dims = IntArrayFromInts(output_shape_dims_data); + + constexpr int tensors_size = kTransposeConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(output_shape, output_shape_dims), + filter_tensor, + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateTransposeConvGoldens( + tensors, tensors_size, expected_output_quantized, output_dims_count, + conv_params, output_quantized, tolerance, filter_comp_info, + bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace } // namespace testing } // namespace tflite @@ -240,6 +471,41 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) { &tflite::testing::common_conv_params, output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) { + tflite::testing::TestCompressionInfo filter_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = tflite::testing::kFilterData; + filter_comp_info.value_table_stride = + std::extent::value; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::kBiasData; + bias_comp_info.value_table_stride = + std::extent::value; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + float output_data[tflite::testing::kOutputElements]; + + TF_LITE_MICRO_EXPECT_EQ( + 
kTfLiteOk, + tflite::testing::TestTransposeConvFloat( + tflite::testing::kInputShape, tflite::testing::kInputData, + tflite::testing::kFilterShape, + reinterpret_cast(tflite::testing::kBinQuantFilterData), + tflite::testing::kBiasShape, + reinterpret_cast(tflite::testing::kBinQuantBiasData), + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + &tflite::testing::common_conv_params, output_data, &filter_comp_info, + &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(fusedRELUTest) { float output_data[tflite::testing::kOutputElements]; float golden_data[] = {29, 24, 0, 0, 99, 72, 0, 0, @@ -476,4 +742,202 @@ TF_LITE_MICRO_TEST(HybridModeIsError) { &tflite::testing::common_conv_params, output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelSingleChannelCompressed) { + // data from TfLite test: SimpleBiasTestQuantizedPerChannelSingleChannel + const float input_scale = 16.0f / 255.0f; + const float output_scale = 2.0f; + const int input_zero_point = -128; + const int output_zero_point = -128; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ1, + 9.0f / 127.0f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ1, + 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kFilterElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kFilterDataQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestTransposeConvQuantizedCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params, 0, &filter_comp_info, + &bias_comp_info)); +} + +TF_LITE_MICRO_TEST( + SimpleBiasTestQuantizedPerChannelBias16MultiChannelCompressed) { + // 
data from TfLite test: SimpleBiasTestQuantizedPerChannel16x8Bias64 + const float input_scale = 4.0f / 127.0f; + const float output_scale = 128.0f / 65536.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ2, + 7.0f / 127.0f, + 8.0f / 127.0f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ2, + 0, + 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ2]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ2]; + int16_t bias_quantized[tflite::testing::kBiasElementsQ2]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ2]; + int16_t output_quantized[tflite::testing::kOutputElementsQ2]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ2 / + tflite::testing::kNumChannelsQ2; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ2; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ2; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ2; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ2; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ2 / tflite::testing::kNumChannelsQ2; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ2; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ2; + bias_comp_info.data = tflite::testing::kBiasDataQ2; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ2; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + // The quantized output is compared to the expected output (quantized). + // A tolerance of 81 is approx. 0.1582f which is less than the TfLite + // tolerance of 0.19f. 
+ TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestTransposeConvQuantizedCompressed( + tflite::testing::kInputShapeQ2, tflite::testing::kInputDataQ2, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ2, tflite::testing::kGoldenDataQ2, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params, 81, &filter_comp_info, + &bias_comp_info)); +} + +TF_LITE_MICRO_TEST( + SimpleBiasTestQuantizedPerChannelBias64MultiChannelCompressed) { + // data from TfLite test: SimpleBiasTestQuantizedPerChannel16x8Bias64 + const float input_scale = 4.0f / 127.0f; + const float output_scale = 128.0f / 65536.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ2, + 7.0f / 127.0f, + 8.0f / 127.0f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ2, + 0, + 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ2]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ2]; + int64_t bias_quantized[tflite::testing::kBiasElementsQ2]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ2]; + int16_t output_quantized[tflite::testing::kOutputElementsQ2]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ2 / + tflite::testing::kNumChannelsQ2; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ2; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ2; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ2; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ2; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ2 / tflite::testing::kNumChannelsQ2; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ2; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ2; + bias_comp_info.data = tflite::testing::kBiasDataQ2; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ2; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + // The quantized output is compared to the expected output (quantized). + // A tolerance of 81 is approx. 0.1582f which is less than the TfLite + // tolerance of 0.19f. 
+ TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestTransposeConvQuantizedCompressed( + tflite::testing::kInputShapeQ2, tflite::testing::kInputDataQ2, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ2, tflite::testing::kGoldenDataQ2, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params, 81, &filter_comp_info, + &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc index 384dba9f7ac..5eb7a1bb7d4 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,14 +52,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: { +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION tflite::reference_ops::Conv( ConvParamsFloat(params, op_data.reference_op_data), tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + op_data.reference_op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + op_data.reference_op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc index 1d2d7ec253e..965fce23167 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc @@ -145,9 +145,29 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node, const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int16_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int8_t* filter_data = tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.reference_op_data.weights_scratch_index); + const int64_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.reference_op_data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const 
int8_t* filter_data = tflite::micro::GetTensorData(filter); const int64_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int16_t* output_data = tflite::micro::GetTensorData(output); int output_data_format = 0; @@ -179,7 +199,6 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node, } else { void* p_scratch = static_cast( context->GetScratchBuffer(context, data.scratch_tensor_index)); - for (int batch = 0; batch < batches; ++batch) { int16_t* p_out_temp; p_out_temp = &output_data[batch * out_length]; @@ -243,8 +262,25 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node, const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int8_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int32_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.reference_op_data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const int32_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int8_t* output_data = tflite::micro::GetTensorData(output); const int8_t* filter_data; @@ -257,7 +293,13 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node, tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data); filter_data = unpacked_filter_data; } else { +#ifdef USE_TFLM_COMPRESSION + filter_data = tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.reference_op_data.weights_scratch_index); +#else // USE_TFLM_COMPRESSION filter_data = tflite::micro::GetTensorData(filter); +#endif // USE_TFLM_COMPRESSION } int output_data_format = 0; diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc b/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc index 2492d4b348b..0f583cdaceb 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc @@ -45,6 +45,17 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) { ? 
tflite::micro::GetEvalInput(context, node, kConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + if (bias == nullptr || bias->type == kTfLiteInt32) { reference_integer_ops::ConvPerChannel( ConvParamsQuantized(params, op_data), @@ -52,9 +63,18 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } else if (bias->type == kTfLiteInt64) { @@ -64,9 +84,18 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } else { diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc b/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc index 6ac07bab403..ba746f0ff8f 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc @@ -45,6 +45,17 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) { ? 
tflite::micro::GetEvalInput(context, node, kConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int8_t* filter_data; if (filter->type == kTfLiteInt4) { int8_t* unpacked_filter_data = static_cast( @@ -54,7 +65,12 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data); filter_data = unpacked_filter_data; } else { +#ifdef USE_TFLM_COMPRESSION + filter_data = tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, op_data.weights_scratch_index); +#else // USE_TFLM_COMPRESSION filter_data = tflite::micro::GetTensorData(filter); +#endif // USE_TFLM_COMPRESSION } reference_integer_ops::ConvPerChannel( @@ -64,7 +80,12 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), filter_data, tflite::micro::GetTensorShape(bias), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc index 812ab60ebf2..8a0330907c3 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
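[Editorial note] In the conv_vision hunk that follows, compressed filter and bias tensors are handled differently from the HiFi and reference kernels: they are decompressed once at Prepare time into temporary buffers (needed for coefficient reordering) and those buffers are released before Prepare returns, rather than being decompressed into scratch memory on every Eval. A condensed editorial sketch of that round trip for the filter (null checks and the conv-specific setup omitted; the hunk below is the authoritative code):

#ifdef USE_TFLM_COMPRESSION
  const CompressionTensorData* filter_comp_td =
      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
  uint8_t* filter_data = nullptr;
  if (filter_comp_td != nullptr) {
    // Decompress into a Prepare-lifetime temporary buffer.
    filter_data = micro_context->AllocateTempBuffer(
        NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8),
        sizeof(int8_t));
    const TfLiteEvalTensor* filter_eval =
        tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
    filter_data = static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
        *filter_eval, *filter_comp_td, filter_data));
  } else {
    filter_data = GetTensorData<uint8_t>(&filter_int8);
  }
  // ... xiConvSetContext() / xiConvDoCoeffReorder() consume filter_data ...
  if (filter_comp_td != nullptr) {
    micro_context->DeallocateTempBuffer(filter_data);
  }
#endif  // USE_TFLM_COMPRESSION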
@@ -104,6 +104,58 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { filter_int8 = *filter; } +#ifdef USE_TFLM_COMPRESSION + + uint8_t* filter_data = nullptr; + int32_t* bias_data = nullptr; + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + if (filter_comp_td != nullptr) { + const size_t filter_data_size = + NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8); + filter_data = + micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t)); + if (filter_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* filter_eval = + tflite::micro::GetEvalInput(context, node, kConvWeightsTensor); + filter_data = static_cast(micro_context->DecompressTensorToBuffer( + *filter_eval, *filter_comp_td, filter_data)); + } else { + filter_data = GetTensorData(&filter_int8); + } + + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + if (bias_comp_td != nullptr) { + const size_t bias_data_size = + NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32); + bias_data = reinterpret_cast( + micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t))); + if (bias_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* bias_eval = + tflite::micro::GetEvalInput(context, node, kConvBiasTensor); + bias_data = static_cast(micro_context->DecompressTensorToBuffer( + *bias_eval, *bias_comp_td, bias_data)); + } else { + bias_data = GetTensorData(bias); + } + + if (filter_data == nullptr || bias_data == nullptr) { + return kTfLiteError; + } + +#else // USE_TFLM_COMPRESSION + + uint8_t* filter_data = GetTensorData(&filter_int8); + int32_t* bias_data = GetTensorData(bias); + +#endif // USE_TFLM_COMPRESSION + status = xiConvSetContext( data->p_context, data->context_size, input_depth, input_width, input_height, output_depth, output_width, output_height, filter_width, @@ -112,8 +164,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { data->reference_op_data.output_multiplier, data->reference_op_data.output_shift, data->reference_op_data.output_activation_min, - data->reference_op_data.output_activation_max, - (uint8_t*)GetTensorData(&filter_int8), + data->reference_op_data.output_activation_max, filter_data, data->reference_op_data.padding.width, data->reference_op_data.padding.height); if (status) { @@ -138,9 +189,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { status = xiConvDoCoeffReorder( data->p_context, data->context_size, reinterpret_cast(data->reorder_coefficient_bias), - data->reorder_coefficient_bias_size, - const_cast(GetTensorData(&filter_int8)), - const_cast(GetTensorData(bias))); + data->reorder_coefficient_bias_size, filter_data, bias_data); if (status) { return kTfLiteError; } @@ -149,6 +198,17 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { micro_context->DeallocateTempBuffer(GetTensorData(&filter_int8)); } +#ifdef USE_TFLM_COMPRESSION + + if (filter_comp_td) { + micro_context->DeallocateTempBuffer(filter_data); + } + if (bias_comp_td) { + micro_context->DeallocateTempBuffer(reinterpret_cast(bias_data)); + } + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/kernels/xtensa/decompress.cc b/tensorflow/lite/micro/kernels/xtensa/decompress.cc new file mode 
100644 index 00000000000..13d2ce2dec7 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decompress.cc @@ -0,0 +1,711 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/kernels/decompress.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" +#include "tensorflow/lite/micro/micro_utils.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#endif // HIFI5 + +namespace tflite { +namespace { + +#ifdef HIFI5 + +struct DecompressionStateXtensa : DecompressionState { + DecompressionStateXtensa() = delete; + + DecompressionStateXtensa(const DecompressionState& other) + : DecompressionState(other) {} + + void DecompressToBufferWidth4_Xtensa(int8_t* buffer); + void DecompressToBufferWidth3_Xtensa(int8_t* buffer); + void DecompressToBufferWidth2_Xtensa(int8_t* buffer); + + void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer); + void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer); + void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer); + void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer); +}; + +void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + 
} + + value_table += stride; + if (elements_per_channel_t_rem) { + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 1)); /* Loading 48 bits for decoding 16 weight values */ + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void DecompressionStateXtensa::DecompressToBufferWidth3_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL); + ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL); + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 6); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SA8X8X2_IP(d_out1, 
d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 3); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void DecompressionStateXtensa::DecompressToBufferWidth2_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + + int elements_per_channel_t_by_5 = elements_per_channel_ >> 5; + int elements_per_channel_t_rem = elements_per_channel_ & 0x1F; + int elements_per_channel_t_rem_minus_16 = 0; + if (elements_per_channel_t_rem > 16) { + elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16; + } + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d0, d1, d2, d3, d4, d5; + ae_int8x8 q0, q1, q2, q3; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_5; j++) { + // AE_LA8X8_IP( d_index, align_index, pIn_tmp ); /* Loading 64 bits + // for decoding 32 weight values */ + + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 8); /* Loading 64 bits for decoding 32 weight values */ + d0 = d_index; + d1 = + 
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 2)); /* Loading 48 bits for decoding 16 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem_minus_16); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt8_Xtensa( + int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint8_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + uint32_t index_1, index_2; + uint32_t mask_bits = (1 << compressed_bit_width_) - 1; + + for (int i = 0; i < num_channels_t; i++) { + elements_per_channel_t = 
elements_per_channel_; + /* if output pointer is not 2 byte aligned */ + if ((unsigned int)p_out_tmp & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + elements_per_channel_t = elements_per_channel_t - 1; + } + for (int j = 0; j < (elements_per_channel_t >> 1); j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, 2 * bw); + index_1 = (index >> compressed_bit_width_) & mask_bits; + index_2 = (index)&mask_bits; + ae_int8x8 d_tmp1 = AE_L8_X((const ae_int8*)value_table, index_1); + ae_int8x8 d_tmp2 = AE_L8_X((const ae_int8*)value_table, index_2); + ae_int16x4 d_tmp = + AE_MOVINT16X4_FROMINT8X8(AE_SEL8X8I(d_tmp2, d_tmp1, 21)); + AE_S16_0_IP(d_tmp, (ae_int16*)p_out_tmp, 2); + } + if (elements_per_channel_t & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + } + value_table += stride; + } + } +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt16_Xtensa( + int16_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint16_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int16* __restrict p_out_tmp = (ae_int16*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint16_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + } + + value_table += stride; + } + } +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt32_Xtensa( + int32_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint32_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int32* __restrict p_out_tmp = (ae_int32*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint32_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + 
AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + } + + value_table += stride; + } + } +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt64_Xtensa( + int64_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint64_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int64* __restrict p_out_tmp = (ae_int64*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint64_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + } + + value_table += stride; + } + } +} + +#endif // HIFI5 + +} // namespace + +#ifdef HIFI5 + +template <> +bool* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +int8_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + if (comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + if (!(elements_per_channel_ & 0x01)) { + dsx.DecompressToBufferWidth4_Xtensa(static_cast(buffer)); + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + } else if (comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + if (!(elements_per_channel_ & 0x07)) { + dsx.DecompressToBufferWidth3_Xtensa(static_cast(buffer)); + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + } else if (comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + if (!(elements_per_channel_ & 0x03)) { + 
dsx.DecompressToBufferWidth2_Xtensa(static_cast(buffer)); + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + + return static_cast(buffer); +} + +template <> +int16_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt16_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +int32_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +float* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +int64_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt64_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +#else // HIFI5 + +template +T* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth4_16(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth3_32(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth2_16(static_cast(buffer)); + } else { + DecompressToBufferWidthAny(static_cast(buffer)); + } + + return static_cast(buffer); +} + +template bool* DecompressionState::DecompressToBuffer(void*); +template float* DecompressionState::DecompressToBuffer(void*); +template int8_t* DecompressionState::DecompressToBuffer(void*); +template int16_t* DecompressionState::DecompressToBuffer(void*); +template int32_t* DecompressionState::DecompressToBuffer(void*); +template int64_t* DecompressionState::DecompressToBuffer(void*); + +#endif // HIFI5 + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc index 8536ff79507..088420aee17 100644 --- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc @@ -1,5 +1,5 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
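The HiFi5 intrinsics in the functions above (AE_LB_DB_IP bit reads, AE_L8_X/AE_L16_X/AE_L32_X/AE_L64_X table fetches) all implement the same lookup-table decode that the portable DecompressToBufferWidthAny path performs one element at a time: read a fixed-width index from the packed bit stream, use it to select an entry from the current channel's slice of the value table, and advance the table by value_table_channel_stride at each channel boundary. A minimal scalar sketch of that decode follows; the function name and parameter layout are illustrative only, the MSB-first bit order is an assumption made for clarity, and the alternate-axis (channel-interleaved) ordering is omitted.

    #include <cstddef>
    #include <cstdint>

    // Scalar lookup-table decompression sketch: fixed-width indices packed
    // MSB-first select entries from a per-channel value table.
    template <typename T>
    void LutDecompressSketch(const uint8_t* indices, size_t bit_width,
                             const T* value_table, size_t table_stride,
                             size_t num_channels, size_t elements_per_channel,
                             T* output) {
      size_t bit_pos = 0;  // running bit offset into the packed index stream
      for (size_t c = 0; c < num_channels; ++c) {
        for (size_t e = 0; e < elements_per_channel; ++e) {
          uint32_t index = 0;
          for (size_t b = 0; b < bit_width; ++b, ++bit_pos) {
            index = (index << 1) |
                    ((indices[bit_pos >> 3] >> (7 - (bit_pos & 7))) & 1);
          }
          *output++ = value_table[index];  // copy the decoded value
        }
        value_table += table_stride;  // next channel's slice of the table
      }
    }

The specialized Width4/Width3/Width2 variants selected in the int8 DecompressToBuffer specialization above are unrolled, vectorized versions of this same loop for bit widths whose indices pack evenly into whole bytes per channel.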
@@ -93,6 +93,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor( context, op_data.reference_op_data.filter_buffer_index, filter); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + switch (input->type) { // Already know in/out types are same. case kTfLiteInt8: { switch (filter_int8.type) { @@ -111,9 +123,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, &filter_int8, filter_comp_td, + op_data.reference_op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + op_data.reference_op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(&filter_int8), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); #endif // defined(HIFI3) || defined(HIFI4) || defined(HIFI5) @@ -136,9 +158,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, &filter_int8, filter_comp_td, + op_data.reference_op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + op_data.reference_op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(&filter_int8), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc index 8c2052b23e7..9bff6d997a3 100644 --- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -97,10 +97,22 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + // If dilation is not required use the optimized NN Library kernel. 
// Otherwise call the reference implementation. if ((params.dilation_width_factor == 1) && - (params.dilation_height_factor == 1)) { + (params.dilation_height_factor == 1) && bias != nullptr) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int pad_width = data.reference_op_data.padding.width; @@ -133,8 +145,17 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node, TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); const int8_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int8_t* filter_data = tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, + data.reference_op_data.weights_scratch_index); + const int32_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.reference_op_data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const int8_t* filter_data = tflite::micro::GetTensorData(filter); const int32_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int8_t* output_data = tflite::micro::GetTensorData(output); int32_t input_data_format = 0; @@ -178,9 +199,19 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, + data.reference_op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.reference_op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc index 35fa8cf1c1a..23e18dc8342 100644 --- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc +++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context, TF_LITE_ENSURE(context, filter != nullptr); TfLiteTensor* bias = micro_context->AllocateTempInputTensor(node, kDepthwiseConvBiasTensor); - TF_LITE_ENSURE(context, filter != nullptr); + TF_LITE_ENSURE(context, bias != nullptr); // Dynamically allocate per-channel quantization parameters. 
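The kernel changes in this file and in depthwise_conv.cc above all follow one pattern: Eval looks up the CompressionTensorData for the weights and bias once, then hands it, the MicroContext, and the scratch-buffer index reserved during Prepare to the compression-aware GetTensorData overload, which returns decompressed data for compressed tensors and the tensor's own data otherwise. A condensed sketch of that pattern, using placeholder tensor-index constants and OpData field names (kWeightsTensorIdx, kBiasTensorIdx, op_data.*) rather than the per-kernel ones:

    #ifdef USE_TFLM_COMPRESSION
      MicroContext* micro_context = GetMicroContext(context);
      const CompressionTensorData* weights_comp_td =
          micro_context->GetTensorCompressionData(node, kWeightsTensorIdx);
      const CompressionTensorData* bias_comp_td =
          micro_context->GetTensorCompressionData(node, kBiasTensorIdx);
      // Decompresses into the Prepare-time scratch buffer when compressed,
      // otherwise falls through to the raw tensor data.
      const int8_t* weights_data = tflite::micro::GetTensorData<int8_t>(
          micro_context, filter, weights_comp_td, op_data.weights_scratch_index);
      const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(
          micro_context, bias, bias_comp_td, op_data.bias_scratch_index);
    #else   // USE_TFLM_COMPRESSION
      const int8_t* weights_data = tflite::micro::GetTensorData<int8_t>(filter);
      const int32_t* bias_data =
          tflite::micro::GetOptionalTensorData<int32_t>(bias);
    #endif  // USE_TFLM_COMPRESSION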
const int num_channels = @@ -135,18 +135,81 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context, filter_int8 = *filter; } +#ifdef USE_TFLM_COMPRESSION + + uint8_t* filter_data = nullptr; + int32_t* bias_data = nullptr; + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + if (filter_comp_td != nullptr) { + const size_t filter_data_size = + NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8); + filter_data = + micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t)); + if (filter_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* filter_eval = + tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor); + filter_data = static_cast(micro_context->DecompressTensorToBuffer( + *filter_eval, *filter_comp_td, filter_data)); + } else { + filter_data = GetTensorData(&filter_int8); + } + + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + if (bias_comp_td != nullptr) { + const size_t bias_data_size = + NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32); + bias_data = reinterpret_cast( + micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t))); + if (bias_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* bias_eval = + tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor); + bias_data = static_cast(micro_context->DecompressTensorToBuffer( + *bias_eval, *bias_comp_td, bias_data)); + } else { + bias_data = GetTensorData(bias); + } + + if (filter_data == nullptr || bias_data == nullptr) { + return kTfLiteError; + } + +#else // USE_TFLM_COMPRESSION + + uint8_t* filter_data = GetTensorData(&filter_int8); + int32_t* bias_data = GetTensorData(bias); + +#endif // USE_TFLM_COMPRESSION + status = xiDepthwiseConvDoCoeffReorder( data->p_context, data->context_size, reinterpret_cast(data->reorder_coefficient_bias), - data->reorder_coefficient_bias_size, - const_cast(GetTensorData(&filter_int8)), - const_cast(GetTensorData(bias))); + data->reorder_coefficient_bias_size, filter_data, bias_data); if (status) { return kTfLiteError; } if (filter->type == kTfLiteInt4) { micro_context->DeallocateTempBuffer(GetTensorData(&filter_int8)); } + +#ifdef USE_TFLM_COMPRESSION + + if (filter_comp_td) { + micro_context->DeallocateTempBuffer(filter_data); + } + if (bias_comp_td) { + micro_context->DeallocateTempBuffer(reinterpret_cast(bias_data)); + } + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc index df5458001b7..4a141784d8f 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
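The vision (P6) Prepare path above differs from the Eval-time pattern because coefficient reordering needs the complete filter and bias contents before inference starts, so compressed tensors are decompressed into temporary buffers that are released before Prepare returns. A condensed sketch of that lifecycle for the filter, with illustrative names and error handling omitted for brevity (the bias follows the same shape):

    uint8_t* filter_data = nullptr;
    const CompressionTensorData* filter_comp_td =
        micro_context->GetTensorCompressionData(node, kWeightsTensorIdx);
    if (filter_comp_td != nullptr) {
      const size_t bytes =
          NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8);
      filter_data = micro_context->AllocateTempBuffer(bytes, sizeof(int8_t));
      const TfLiteEvalTensor* filter_eval =
          tflite::micro::GetEvalInput(context, node, kWeightsTensorIdx);
      filter_data = static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
          *filter_eval, *filter_comp_td, filter_data));
    } else {
      filter_data = GetTensorData<uint8_t>(&filter_int8);
    }
    // ... coefficient reordering consumes filter_data here ...
    if (filter_comp_td != nullptr) {
      micro_context->DeallocateTempBuffer(filter_data);
    }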
@@ -45,6 +45,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + TFLITE_DCHECK(node->user_data != nullptr); const auto& data = @@ -58,9 +70,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -93,9 +114,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc index cf87c5ff1ed..91b9f40c907 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -120,6 +120,23 @@ TfLiteStatus XtensaPrepareFullyConnected(TfLiteContext* context, context, params->activation, input->type, input, filter, bias, output, data)); +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. 
+ if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); if (bias != nullptr) { diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc index f850c0c0fca..32dfba2e5a8 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,12 +32,37 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8( const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { #if !defined(VISION_P6) + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int32_t* bias_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION tflite::micro::GetOptionalTensorData(bias); +#endif // USE_TFLM_COMPRESSION + + const int8_t* filter_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, data.weights_scratch_index); +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter); +#endif // USE_TFLM_COMPRESSION // P6 Vision will handle INT4 filters as a reference operation. // For all other architectures, unpack INT4 here. - const int8_t* filter_data = tflite::micro::GetTensorData(filter); if (filter->type == kTfLiteInt4) { int8_t* unpacked_filter_data = static_cast( context->GetScratchBuffer(context, data.filter_buffer_index)); @@ -47,6 +72,7 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8( tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data); filter_data = unpacked_filter_data; } + #endif // !defined(VISION_P6) #if defined(HIFIMINI) diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc index 14bb9a12b15..24fd1258277 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
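The Prepare-time counterpart to the Eval-time access shown earlier is the scratch reservation in XtensaPrepareFullyConnected above: one decompression scratch buffer is requested per potentially compressed tensor, and the returned index is stored in the kernel's OpData for use by GetTensorData during Eval. Judging from the MicroInterpreterContext implementation later in this patch, the call returns -1 when no scratch is needed (the tensor is not compressed, or it fits in alternate decompression memory), so kernels can store the result unconditionally. A minimal sketch with placeholder names:

    #ifdef USE_TFLM_COMPRESSION
      // -1 means "no scratch needed"; otherwise an arena scratch buffer index.
      data->weights_scratch_index =
          micro_context->AllocateDecompressionScratchBuffer(node,
                                                            kWeightsTensorIdx);
      data->bias_scratch_index =
          micro_context->AllocateDecompressionScratchBuffer(node, kBiasTensorIdx);
    #endif  // USE_TFLM_COMPRESSION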
@@ -108,14 +108,66 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, filter_int8 = *filter; } +#ifdef USE_TFLM_COMPRESSION + + uint8_t* filter_data = nullptr; + int32_t* bias_data = nullptr; + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + if (filter_comp_td != nullptr) { + const size_t filter_data_size = + NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8); + filter_data = + micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t)); + if (filter_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* filter_eval = tflite::micro::GetEvalInput( + context, node, kFullyConnectedWeightsTensor); + filter_data = static_cast(micro_context->DecompressTensorToBuffer( + *filter_eval, *filter_comp_td, filter_data)); + } else { + filter_data = GetTensorData(&filter_int8); + } + + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + if (bias_comp_td != nullptr) { + const size_t bias_data_size = + NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32); + bias_data = reinterpret_cast( + micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t))); + if (bias_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* bias_eval = + tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); + bias_data = static_cast(micro_context->DecompressTensorToBuffer( + *bias_eval, *bias_comp_td, bias_data)); + } else { + bias_data = GetTensorData(bias); + } + + if (filter_data == nullptr || (bias != nullptr && bias_data == nullptr)) { + return kTfLiteError; + } + +#else // USE_TFLM_COMPRESSION + + uint8_t* filter_data = GetTensorData(&filter_int8); + int32_t* bias_data = GetTensorData(bias); + +#endif // USE_TFLM_COMPRESSION + status = xiFullyConnectedSetContext( data->p_context, data->context_size, inputDims, outputDims, filterDims, 1, input->params.zero_point, filter->params.zero_point, output->params.zero_point, data->reference_op_data.output_multiplier, data->reference_op_data.output_shift, data->reference_op_data.output_activation_min, - data->reference_op_data.output_activation_max, - (uint8_t*)GetTensorData(&filter_int8)); + data->reference_op_data.output_activation_max, filter_data); if (status) { return kTfLiteError; @@ -139,9 +191,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, status = xiFullyConnectedDoCoeffReorder( data->p_context, data->context_size, reinterpret_cast(data->reorder_coefficient_bias), - data->reorder_coefficient_bias_size, - const_cast(GetTensorData(&filter_int8)), - const_cast(GetTensorData(bias))); + data->reorder_coefficient_bias_size, filter_data, bias_data); if (status) { return kTfLiteError; } @@ -149,6 +199,18 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, if (filter->type == kTfLiteInt4) { micro_context->DeallocateTempBuffer(GetTensorData(&filter_int8)); } + +#ifdef USE_TFLM_COMPRESSION + + if (filter_comp_td) { + micro_context->DeallocateTempBuffer(filter_data); + } + if (bias_comp_td) { + micro_context->DeallocateTempBuffer(reinterpret_cast(bias_data)); + } + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc index 44a9f86049c..d46ef0ba88a 100644 --- 
a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,6 +49,14 @@ struct OpData { // A scratch buffer is required for quantized implementations. int scratch_buffer_index; +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int filter_scratch_index; + int bias_scratch_index; + +#endif // USE_TFLM_COMPRESSION + // TODO(b/192090531): Remove this once all 8x16 transpose conv models use // 64-bit biases. int bias_converted_buffer_index; @@ -268,6 +276,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { data->params.stride_width = params->stride_width; data->params.stride_height = params->stride_height; +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + data->filter_scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kFilterTensor); + data->bias_scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); @@ -286,6 +305,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kOutputTensor); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, kFilterTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kBiasTensor); + +#endif // USE_TFLM_COMPRESSION + TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); @@ -309,9 +339,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_params, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); @@ -321,7 +359,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { int32_t* scratch_buffer = static_cast( context->GetScratchBuffer(context, data.scratch_buffer_index)); #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5) - if (bias->type == kTfLiteInt32) { + if (bias != nullptr && bias->type == kTfLiteInt32) { const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input); const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter); @@ -343,9 +381,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int output_height = 
output_shape.Dims(1); const int output_width = output_shape.Dims(2); const int8_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int8_t* filter_data = tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index); + const int32_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const int8_t* filter_data = tflite::micro::GetTensorData(filter); const int32_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int8_t* output_data = tflite::micro::GetTensorData(output); const int num_elements = output_shape.FlatSize(); @@ -369,9 +414,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); @@ -382,9 +436,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); @@ -396,20 +458,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { context->GetScratchBuffer(context, data.scratch_buffer_index)); // TODO(b/192090531): Remove this once all 8x16 transpose conv models use // 64-bit biases. 
- if (bias->type == kTfLiteInt16) { - std::int64_t* bias_converted_buffer = - static_cast(context->GetScratchBuffer( - context, data.bias_converted_buffer_index)); - for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize(); - i++) { - bias_converted_buffer[i] = bias->data.i16[i]; + if (bias == nullptr || bias->type == kTfLiteInt16) { + std::int64_t* bias_converted_buffer = nullptr; + if (bias != nullptr) { + bias_converted_buffer = + static_cast(context->GetScratchBuffer( + context, data.bias_converted_buffer_index)); + const int16_t* const bias_int16_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION + static_cast(bias->data.data); +#endif // USE_TFLM_COMPRESSION + for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize(); + i++) { + bias_converted_buffer[i] = bias_int16_data[i]; + } } reference_integer_ops::TransposeConv( data.params, data.per_channel_output_multiplier, data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(bias), bias_converted_buffer, tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), @@ -438,9 +516,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int output_width = output_shape.Dims(2); const int16_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int8_t* filter_data = tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index); + const int64_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const int8_t* filter_data = tflite::micro::GetTensorData(filter); const int64_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int16_t* output_data = tflite::micro::GetTensorData(output); const int num_elements = output_shape.FlatSize(); @@ -457,15 +542,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, data.per_channel_output_multiplier, scratch_buffer); } -#else // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5) +#else // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5) reference_integer_ops::TransposeConv( data.params, data.per_channel_output_multiplier, data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc 
b/tensorflow/lite/micro/memory_arena_threshold_test.cc index 2f9242781f8..34c62cda412 100644 --- a/tensorflow/lite/micro/memory_arena_threshold_test.cc +++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc @@ -63,7 +63,6 @@ constexpr int kKeywordModelOnlyTotalSize = 14472; // TODO(b/207157610): replace magic number that depends on OPs constexpr int kKeywordModelOnlyTailSize = 13800; constexpr int kKeywordModelPersistentTfLiteTensorDataSize = 128; -constexpr int kKeywordModelPersistentBufferDataSize = 832; #else // Total size contributed by the keyword model excluding the // RecordingMicroAllocator's overhead. @@ -74,7 +73,6 @@ constexpr int kKeywordModelOnlyTotalSize = 14936; // TODO(b/207157610): replace magic number that depends on OPs constexpr int kKeywordModelOnlyTailSize = 14264; constexpr int kKeywordModelPersistentTfLiteTensorDataSize = 224; -constexpr int kKeywordModelPersistentBufferDataSize = 840; #endif constexpr int kKeywordModelHeadSize = 672; constexpr int kKeywordModelTfLiteTensorVariableBufferDataSize = 10240; @@ -87,6 +85,12 @@ uint8_t test_conv_tensor_arena[kTestConvModelArenaSize]; constexpr int kTestConvModelTensorCount = 15; constexpr int kTestConvModelNodeAndRegistrationCount = 7; +#if defined(USE_TFLM_COMPRESSION) +constexpr int kKeywordModelPersistentBufferDataSize = 920; +#else +constexpr int kKeywordModelPersistentBufferDataSize = 840; +#endif + // NOTE: These values are measured on x86-64: // TODO(b/158651472): Consider auditing these values on non-64 bit systems. #ifdef TF_LITE_STATIC_MEMORY @@ -136,10 +140,6 @@ void EnsureAllocatedSizeThreshold(const char* allocation_type, size_t actual, // 64-bit systems should check floor and ceiling to catch memory savings: TF_LITE_MICRO_EXPECT_NEAR(actual, expected, expected * kAllocationThreshold); - if (actual != expected) { - MicroPrintf("%s threshold failed: %d != %d", allocation_type, actual, - expected); - } } else { // Non-64 bit systems should just expect allocation does not exceed the // ceiling: diff --git a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc index a087b236cc9..62099308c7e 100644 --- a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc +++ b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc @@ -31,7 +31,7 @@ char GetOrdinalCharacter(int i) { } else if (i < 62) { return 'A' + (i - 36); } - return '*'; + return GetOrdinalCharacter(i % 62); } } // namespace @@ -335,9 +335,14 @@ void GreedyMemoryPlanner::PrintMemoryPlan() { CalculateOffsetsIfNeeded(); for (int i = 0; i < buffer_count_; ++i) { - MicroPrintf("%c (id=%d): size=%d, offset=%d, first_used=%d last_used=%d", - GetOrdinalCharacter(i), i, requirements_[i].size, - buffer_offsets_[i], requirements_[i].first_time_used, + char c = '*'; + if (requirements_[i].first_time_used != requirements_[i].last_time_used) { + // not a scratch buffer nor subgraph output tensor + c = GetOrdinalCharacter(i); + } + MicroPrintf("%c (id=%d): size=%d, offset=%d, first_used=%d last_used=%d", c, + i, requirements_[i].size, buffer_offsets_[i], + requirements_[i].first_time_used, requirements_[i].last_time_used); } @@ -379,7 +384,12 @@ void GreedyMemoryPlanner::PrintMemoryPlan() { const int line_end = ((offset + size) * kLineWidth) / max_size; for (int n = line_start; n < line_end; ++n) { if (line[n] == '.') { - line[n] = GetOrdinalCharacter(i); + if (requirements->first_time_used == requirements->last_time_used) { + // scratch buffer or subgraph output tensor + line[n] 
= '*'; + } else { + line[n] = GetOrdinalCharacter(i); + } } else { line[n] = '!'; } @@ -387,7 +397,7 @@ void GreedyMemoryPlanner::PrintMemoryPlan() { } line[kLineWidth] = 0; - MicroPrintf("%s%d: %s (%dk)", t < 10 ? " " : "", t, (const char*)line, + MicroPrintf("%4d: %s (%dk)", t, (const char*)line, (memory_use + 1023) / 1024); } } diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index 930da754bb5..08203285f4c 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,15 @@ limitations under the License. #include "tensorflow/lite/micro/tflite_bridge/flatbuffer_conversions_bridge.h" #include "tensorflow/lite/schema/schema_generated.h" +#ifdef USE_TFLM_COMPRESSION + +#include +#include + +#include "tensorflow/lite/micro/compression/metadata_saved.h" + +#endif // USE_TFLM_COMPRESSION + namespace tflite { namespace { @@ -65,10 +74,10 @@ class MicroBuiltinDataAllocator : public TfLiteBridgeBuiltinDataAllocator { // of the model. } - TF_LITE_REMOVE_VIRTUAL_DELETE - private: IPersistentBufferAllocator* persistent_allocator_; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; MicroMemoryPlanner* CreateMemoryPlanner( @@ -355,6 +364,142 @@ TfLiteStatus InitializeTfLiteEvalTensorFromFlatbuffer( return kTfLiteOk; } +#ifdef USE_TFLM_COMPRESSION + +const tflite::micro::compression::Metadata* GetCompressionMetadata( + const Model& model) { + const auto metadata_vector = model.metadata(); + if (metadata_vector == nullptr) { + return nullptr; + } + auto buffers = model.buffers(); + if (buffers == nullptr) { + return nullptr; + } + const size_t metadata_string_length = std::strlen(kCompressionMetadataString); + for (size_t metadata_index = 0; metadata_index < metadata_vector->size(); + metadata_index++) { + auto metadata = metadata_vector->Get(metadata_index); + if (metadata->name() == nullptr || metadata->name()->size() == 0) { + continue; + } + const char* s = metadata->name()->c_str(); + if ((metadata->name()->size() == metadata_string_length) && + (std::strncmp(s, kCompressionMetadataString, metadata_string_length) == + 0)) { + auto buffer_index = metadata->buffer(); + if (buffer_index == 0 || buffer_index >= buffers->size()) { + MicroPrintf("Compression: Invalid buffer index %u", buffer_index); + continue; + } + auto vp = buffers->Get(buffer_index)->data(); + if (vp == nullptr || vp->data() == nullptr) { + MicroPrintf("Compression: Invalid data for buffer index %u", + buffer_index); + continue; + } + // TODO(ddavis-2015): support multiple compression methods, possibly + // through multiple verification checks. + // Then return a pair. 
+ auto compression_metadata = + tflite::micro::compression::GetSizePrefixedMetadata(vp); + flatbuffers::Verifier verifier(vp->data(), vp->size(), + flatbuffers::Verifier::Options()); + if (!tflite::micro::compression::VerifyMetadataBuffer(verifier)) { + MicroPrintf("Compression: verification failure"); + return nullptr; + } else { + return compression_metadata; + } + } + } + + return nullptr; +} + +TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer( + const Model& model, const size_t subgraph_index, + const tflite::micro::compression::LutTensor& lut_tensor, + CompressionTensorData* ctd) { + // TODO(ddavis-2015): support multiple compression schemes + ctd->scheme = CompressionScheme::kBinQuant; + + const size_t tensor_index = lut_tensor.tensor(); + auto tensors = model.subgraphs()->Get(subgraph_index)->tensors(); + if (tensor_index >= tensors->size()) { + MicroPrintf("Compression: invalid tensor index %u in LutTensor", + tensor_index); + return kTfLiteError; + } + const size_t index_bit_width = lut_tensor.index_bitwidth(); + if (index_bit_width > LookupTableData::kMaxBitWidth) { + MicroPrintf("Compression: invalid bit width %u in LutTensor", + index_bit_width); + return kTfLiteError; + } + ctd->data.lut_data->compressed_bit_width = index_bit_width; + const size_t value_buffer_index = lut_tensor.value_buffer(); + if (value_buffer_index >= model.buffers()->size()) { + MicroPrintf("Compression: invalid value_buffer %u in LutTensor", + value_buffer_index); + return kTfLiteError; + } + auto value_buffer = model.buffers()->Get(value_buffer_index)->data(); + if (value_buffer == nullptr || value_buffer->data() == nullptr) { + MicroPrintf("Compression: invalid value table for value_buffer %u", + value_buffer_index); + return kTfLiteError; + } + ctd->data.lut_data->value_table = value_buffer->data(); + auto tensor = + model.subgraphs()->Get(subgraph_index)->tensors()->Get(tensor_index); + if (tensor->shape() == nullptr) { + MicroPrintf("Compression: scalar tensors not supported"); + return kTfLiteError; + } + TfLiteType tensor_type = kTfLiteNoType; + TfLiteStatus status = ConvertTensorType(tensor->type(), &tensor_type); + if (status != kTfLiteOk) { + MicroPrintf("Compression: failed to convert tensor type"); + return kTfLiteError; + } + size_t tensor_type_size = 0; + status = TfLiteTypeSizeOf(tensor_type, &tensor_type_size); + if (status != kTfLiteOk) { + MicroPrintf("Compression: failed to get tensor type size"); + return kTfLiteError; + } + if (tensor->quantization() != nullptr && + tensor->quantization()->scale() != nullptr && + tensor->quantization()->scale()->size() > 1) { + const size_t num_channels = tensor->quantization()->scale()->size(); + ctd->data.lut_data->is_per_channel_quantized = true; + const TfLiteIntArray* dims = + FlatBufferVectorToTfLiteTypeArray(tensor->shape()); + int32_t quantized_axis = tensor->quantization()->quantized_dimension(); + if (quantized_axis == 0) { + ctd->data.lut_data->use_alternate_axis = false; + } else if (quantized_axis == (dims->size - 1)) { + ctd->data.lut_data->use_alternate_axis = true; + } else { + MicroPrintf("Compression: unsupported quantization axis %u", + quantized_axis); + return kTfLiteError; + } + ctd->data.lut_data->value_table_channel_stride = + (value_buffer->size() / tensor_type_size) / num_channels; + } else { + ctd->data.lut_data->is_per_channel_quantized = false; + ctd->data.lut_data->use_alternate_axis = false; + ctd->data.lut_data->value_table_channel_stride = + value_buffer->size() / tensor_type_size; + } + + return kTfLiteOk; 
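A concrete illustration of the stride computation above, with hypothetical numbers: for a per-channel-quantized int8 filter with 8 channels and 16 lookup values per channel, the value buffer holds 8 * 16 = 128 bytes, tensor_type_size is 1, and value_table_channel_stride = (128 / 1) / 8 = 16 entries. In the per-tensor case the whole buffer is a single table, e.g. a 64-byte int32 value buffer gives a stride of 64 / 4 = 16 entries.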
+} + +#endif // USE_TFLM_COMPRESSION + } // namespace internal size_t MicroAllocator::GetDefaultTailUsage(bool is_memory_planner_given) { @@ -502,7 +647,11 @@ SubgraphAllocations* MicroAllocator::StartModelAllocation(const Model* model) { return nullptr; } - if (AllocateTfLiteEvalTensors(model, output) != kTfLiteOk || + if ( +#ifdef USE_TFLM_COMPRESSION + AllocateCompressedTensorsList(model, output) != kTfLiteOk || +#endif // USE_TFLM_COMPRESSION + AllocateTfLiteEvalTensors(model, output) != kTfLiteOk || AllocateNodeAndRegistrations(model, output) != kTfLiteOk) { return nullptr; } @@ -757,6 +906,121 @@ bool MicroAllocator::IsAllTempDeallocated() { return non_persistent_buffer_allocator_->IsAllTempDeallocated(); } +#ifdef USE_TFLM_COMPRESSION + +TfLiteStatus MicroAllocator::AllocateCompressedTensorsList( + const Model* model, SubgraphAllocations* subgraph_allocations) { + TFLITE_DCHECK(subgraph_allocations != nullptr); + + for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size(); + subgraph_idx++) { + subgraph_allocations[subgraph_idx].compressed.tensors = nullptr; + } + + const tflite::micro::compression::Metadata* compression_metadata = + internal::GetCompressionMetadata(*model); + if (compression_metadata == nullptr) { + // no compression metadata is available + return kTfLiteOk; + } + if (compression_metadata->subgraphs() == nullptr) { + MicroPrintf("Compression: invalid Subgraph vector"); + return kTfLiteError; + } + if (compression_metadata->subgraphs()->size() == 0) { + MicroPrintf("Compression: zero length Subgraph vector"); + return kTfLiteError; + } + + for (size_t subgraph_index = 0; + subgraph_index < compression_metadata->subgraphs()->size(); + subgraph_index++) { + auto subgraph = compression_metadata->subgraphs()->Get(subgraph_index); + + if (subgraph->lut_tensors() == nullptr) { + MicroPrintf("Compression: invalid LutTensor vector"); + return kTfLiteError; + } + if (subgraph->lut_tensors()->size() == 0) { + MicroPrintf("Compression: zero length LutTensor vector"); + return kTfLiteError; + } + + for (size_t lut_tensors_index = 0; + lut_tensors_index < subgraph->lut_tensors()->size(); + lut_tensors_index++) { + auto lut_tensor = subgraph->lut_tensors()->Get(lut_tensors_index); + + CompressionTensorData* ctd = reinterpret_cast( + persistent_buffer_allocator_->AllocatePersistentBuffer( + sizeof(CompressionTensorData), alignof(CompressionTensorData))); + if (ctd == nullptr) { + MicroPrintf( + "Compressions: failed to allocate memory for " + "CompressionTensorData, %d bytes required", + sizeof(CompressionTensorData)); + return kTfLiteError; + } + + LookupTableData* lut_table = reinterpret_cast( + persistent_buffer_allocator_->AllocatePersistentBuffer( + sizeof(LookupTableData), alignof(LookupTableData))); + if (lut_table == nullptr) { + MicroPrintf( + "Compressions: failed to allocate memory for LookupTableData, " + "%d bytes required", + sizeof(LookupTableData)); + return kTfLiteError; + } + ctd->data.lut_data = lut_table; + + TfLiteStatus status = + internal::InitializeCompressionTensorDataFromFlatbuffer( + *model, subgraph_index, *lut_tensor, ctd); + if (status != kTfLiteOk) { + MicroPrintf("Compression: failed to initialize data for LutTensor %u", + lut_tensors_index); + return kTfLiteError; + } + + if (subgraph_allocations[subgraph_index].compressed.tensors == nullptr) { + size_t alloc_count = + model->subgraphs()->Get(subgraph_index)->tensors()->size(); + const CompressionTensorData** tensors = + reinterpret_cast( + 
persistent_buffer_allocator_->AllocatePersistentBuffer( + sizeof(CompressionTensorData*) * alloc_count, + alignof(CompressionTensorData*))); + if (tensors == nullptr) { + MicroPrintf( + "Compression: failed to allocate memory for compression tensor " + "list, %d bytes required", + sizeof(CompressionTensorData*) * alloc_count); + return kTfLiteError; + } + + subgraph_allocations[subgraph_index].compressed.tensors = tensors; + std::fill(tensors, tensors + alloc_count, nullptr); + } + + const size_t tensor_index = lut_tensor->tensor(); + if (subgraph_allocations[subgraph_index] + .compressed.tensors[tensor_index] != nullptr) { + MicroPrintf("Compression: duplicate LutTensor subgraph %u tensor %u", + subgraph_index, tensor_index); + return kTfLiteError; + } else { + subgraph_allocations[subgraph_index].compressed.tensors[tensor_index] = + ctd; + } + } + } + + return kTfLiteOk; +} + +#endif // USE_TFLM_COMPRESSION + TfLiteStatus MicroAllocator::AllocateTfLiteEvalTensors( const Model* model, SubgraphAllocations* subgraph_allocations) { TFLITE_DCHECK(subgraph_allocations != nullptr); diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h index 02317220e12..215bffc6a8c 100644 --- a/tensorflow/lite/micro/micro_allocator.h +++ b/tensorflow/lite/micro/micro_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,12 @@ limitations under the License. #include "tensorflow/lite/micro/tflite_bridge/flatbuffer_conversions_bridge.h" #include "tensorflow/lite/schema/schema_generated.h" +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/compression.h" + +#endif // USE_TFLM_COMPRESSION + namespace tflite { // TODO(b/199402574): rename to tflite_internal or just remove internal @@ -91,6 +97,9 @@ struct ScratchBufferHandle { struct SubgraphAllocations { NodeAndRegistration* node_and_registrations; TfLiteEvalTensor* tensors; +#ifdef USE_TFLM_COMPRESSION + CompressedTensorList compressed; +#endif // USE_TFLM_COMPRESSION }; // Allocator responsible for allocating memory for all intermediate tensors @@ -258,6 +267,15 @@ class MicroAllocator { MicroMemoryPlanner* memory_planner); virtual ~MicroAllocator(); +#ifdef USE_TFLM_COMPRESSION + + // Allocates an array in the arena of pointers to the compressions data + // required to decompress tensors for each subgraph within the model. + virtual TfLiteStatus AllocateCompressedTensorsList( + const Model* model, SubgraphAllocations* subgraph_allocations); + +#endif // USE_TFLM_COMPRESSION + // Allocates an array in the arena to hold pointers to the node and // registration pointers required to represent the inference graph of the // model. diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc index 295b3c34463..ea4fd8e8dc7 100644 --- a/tensorflow/lite/micro/micro_context.cc +++ b/tensorflow/lite/micro/micro_context.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,11 @@ limitations under the License. 
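The pointer array allocated in AllocateCompressedTensorsList above is sized to the subgraph's full tensor count and indexed by flatbuffer tensor index, so the later per-kernel queries reduce to a two-step hop from a node input slot to the compression record. A sketch of that lookup (variable names are illustrative; the real code also guards against optional, i.e. -1, tensor slots):

    // node input slot -> subgraph tensor index -> compression record (or null)
    const SubgraphAllocations& subgraph_alloc = allocations[subgraph_index];
    const int tensor_index = node->inputs->data[input_slot];
    const CompressionTensorData* ctd =
        subgraph_alloc.compressed.tensors != nullptr
            ? subgraph_alloc.compressed.tensors[tensor_index]
            : nullptr;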
#include #include +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/kernels/decompress.h" #include "tensorflow/lite/micro/micro_common.h" #include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_utils.h" namespace tflite { namespace { @@ -74,4 +77,66 @@ void MicroContextReportOpError(struct TfLiteContext* context, va_end(args); } +#ifdef USE_TFLM_COMPRESSION + +void* MicroContext::DecompressTensorToBuffer( + const TfLiteEvalTensor& tensor, + const CompressionTensorData& compression_data, void* buffer) { + TFLITE_DCHECK(compression_data.scheme == CompressionScheme::kBinQuant); + TFLITE_DCHECK(buffer != nullptr); + size_t count = ElementCount(*tensor.dims); + size_t num_channels = 1; + + if (compression_data.data.lut_data->is_per_channel_quantized) { + const size_t channel_axis = + compression_data.data.lut_data->use_alternate_axis + ? tensor.dims->size - 1 + : 0; + num_channels = tensor.dims->data[channel_axis]; + } + + DecompressionState ds(static_cast(tensor.data.data), count, + compression_data, num_channels, GetAlternateProfiler()); + + switch (tensor.type) { + case kTfLiteBool: { + return ds.DecompressToBuffer(buffer); + } break; + case kTfLiteInt8: { + return ds.DecompressToBuffer(buffer); + } break; + case kTfLiteInt16: { + return ds.DecompressToBuffer(buffer); + } break; + case kTfLiteInt32: { + return ds.DecompressToBuffer(buffer); + } break; + case kTfLiteInt64: { + return ds.DecompressToBuffer(buffer); + } break; + case kTfLiteFloat32: { + return ds.DecompressToBuffer(buffer); + } break; + default: { + MicroPrintf("Unsupported decompression tensor type %d", tensor.type); + } break; + } + + return nullptr; +} + +TfLiteStatus MicroContext::SetDecompressionMemory( + const std::initializer_list& regions) { + return kTfLiteError; +} + +void* MicroContext::AllocateDecompressionMemory(size_t bytes, + size_t alignment) { + return nullptr; +} + +void MicroContext::ResetDecompressionMemoryAllocations() {} + +#endif // USE_TFLM_COMPRESSION + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_context.h b/tensorflow/lite/micro/micro_context.h index 2dd3233a159..5b1ea9ca798 100644 --- a/tensorflow/lite/micro/micro_context.h +++ b/tensorflow/lite/micro/micro_context.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,15 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/micro_graph.h" +#include "tensorflow/lite/micro/micro_profiler_interface.h" + +#ifdef USE_TFLM_COMPRESSION + +#include + +#include "tensorflow/lite/micro/compression.h" + +#endif // USE_TFLM_COMPRESSION namespace tflite { // TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus. @@ -95,6 +104,69 @@ class MicroContext { virtual MicroGraph& graph() = 0; +#ifdef USE_TFLM_COMPRESSION + + // Available during Prepare & Eval. Returns false if tensor is not + // compressed. + virtual bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) = 0; + + // Only available during Prepare. The kernel is responsible for storing the + // scratch buffer handle. + virtual int AllocateDecompressionScratchBuffer(const TfLiteNode* node, + int tensor_idx) = 0; + + // Available during Prepare & Eval. Returns nullptr if tensor is not + // compressed. 
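As a concrete example of the channel-count selection in DecompressTensorToBuffer above (shape hypothetical): a per-channel-quantized filter with dims {16, 3, 3, 8} decompresses with num_channels = 16 when quantized along dimension 0, and with num_channels = 8 when use_alternate_axis is set (quantization along the last dimension); data that is not per-channel quantized always uses num_channels = 1.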
+ virtual const CompressionTensorData* GetTensorCompressionData( + const TfLiteNode* node, int tensor_idx) = 0; + + // Only available during Prepare & Eval. Returns nullptr on failure, otherwise + // returns a pointer to the buffer. + virtual void* DecompressTensorToBuffer( + const TfLiteEvalTensor& tensor, + const CompressionTensorData& compression_data, void* buffer); + + // Used for configuring alternate decompression memory + struct AlternateMemoryRegion { + void* address; + size_t bytes; + }; + + // Set the alternate decompression memory regions. + // Can only be called during the MicroInterpreter kInit state. + virtual TfLiteStatus SetDecompressionMemory( + const std::initializer_list& regions); + + // Return a pointer to memory that can be used for decompression. + // The pointer will be aligned to the value. + // Return nullptr if the requested size is not available. + // Can be called during kPrepare and kInvoke states. + virtual void* AllocateDecompressionMemory(size_t bytes, size_t alignment); + + // reset all allocation tracking + virtual void ResetDecompressionMemoryAllocations(); + +#endif // USE_TFLM_COMPRESSION + + // Set the alternate MicroProfilerInterface. + // This can be used to profile subsystems simultaneously with the profiling + // of kernels during the Eval phase. See (b/379584353). + // The alternate MicroProfilerInterface is currently used by the tensor + // decompression subsystem. + virtual TfLiteStatus SetAlternateProfiler( + MicroProfilerInterface* alt_profiler) { + return kTfLiteError; + } + + // Get the alternate MicroProfilerInterface. + // This can be used to profile subsystems simultaneously with the profiling + // of kernels during the Eval phase. See (b/379584353). + // The alternate MicroProfilerInterface is currently used by the tensor + // decompression subsystem. + virtual MicroProfilerInterface* GetAlternateProfiler() const { + return nullptr; + } + private: TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 7f4565e638a..666516c6a3a 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -334,4 +334,19 @@ TfLiteStatus MicroInterpreter::SetMicroExternalContext( return micro_context_.set_external_context(external_context_payload); } +TfLiteStatus MicroInterpreter::SetAlternateProfiler( + MicroProfilerInterface* alt_profiler) { + return micro_context_.SetAlternateProfiler(alt_profiler); +} + +#ifdef USE_TFLM_COMPRESSION + +TfLiteStatus MicroInterpreter::SetDecompressionMemory( + const std::initializer_list& + regions) { + return micro_context_.SetDecompressionMemory(regions); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 1c419962239..4a03c3fe825 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -18,6 +18,12 @@ limitations under the License. #include #include +#ifdef USE_TFLM_COMPRESSION + +#include + +#endif // USE_TFLM_COMPRESSION + #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" @@ -146,6 +152,25 @@ class MicroInterpreter { return allocator_.preserves_all_tensor(); } + // Set the alternate MicroProfilerInterface. + // This value is passed through to the MicroContext. 
+ // This can be used to profile subsystems simultaneously with the profiling + // of kernels during the Eval phase. See (b/379584353). + // The alternate MicroProfilerInterface is currently used by the tensor + // decompression subsystem. + TfLiteStatus SetAlternateProfiler(MicroProfilerInterface* alt_profiler); + +#ifdef USE_TFLM_COMPRESSION + + // Set the alternate decompression memory regions. + // Can only be called during the MicroInterpreter kInit state (i.e. must + // be called before MicroInterpreter::AllocateTensors). + TfLiteStatus SetDecompressionMemory( + const std::initializer_list& + regions); + +#endif // USE_TFLM_COMPRESSION + protected: const MicroAllocator& allocator() const { return allocator_; } const TfLiteContext& context() const { return context_; } diff --git a/tensorflow/lite/micro/micro_interpreter_context.cc b/tensorflow/lite/micro/micro_interpreter_context.cc index 098df15d522..38c6225be0a 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.cc +++ b/tensorflow/lite/micro/micro_interpreter_context.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,9 +17,38 @@ limitations under the License. #include +#ifdef USE_TFLM_COMPRESSION + +#include + +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_arena_constants.h" + +#endif // USE_TFLM_COMPRESSION + #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/micro_utils.h" namespace tflite { + +namespace { + +#ifdef USE_TFLM_COMPRESSION + +int GetInputTensorIndex(const TfLiteNode* node, const int index) { + if (index >= 0 && index < node->inputs->size) { + const int tensor_index = node->inputs->data[index]; + if (tensor_index != kTfLiteOptionalTensor) { + return tensor_index; + } + } + return -1; +} + +#endif // USE_TFLM_COMPRESSION + +} // namespace + MicroInterpreterContext::MicroInterpreterContext(MicroAllocator* allocator, const Model* model, MicroInterpreterGraph* graph) @@ -106,4 +135,146 @@ MicroInterpreterContext::GetInterpreterState() const { return state_; } +#ifdef USE_TFLM_COMPRESSION + +// Available during Prepare & Eval. Returns false if tensor is not +// compressed. +bool MicroInterpreterContext::IsTensorCompressed(const TfLiteNode* node, + int tensor_idx) { + TFLITE_DCHECK(state_ == InterpreterState::kPrepare || + state_ == InterpreterState::kInvoke); + + const SubgraphAllocations* allocations = + &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()]; + if (allocations->compressed.tensors == nullptr) { + return false; + } + int index = GetInputTensorIndex(node, tensor_idx); + if (index == -1) { + return false; + } + return allocations->compressed.tensors[index] != nullptr; +} + +// Only available during Prepare. The kernel is responsible for storing the +// scratch buffer handle. 
+int MicroInterpreterContext::AllocateDecompressionScratchBuffer( + const TfLiteNode* node, int tensor_idx) { + TFLITE_DCHECK(state_ == InterpreterState::kPrepare); + + const SubgraphAllocations* allocations = + &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()]; + if (allocations->compressed.tensors == nullptr) { + return -1; + } + int index = GetInputTensorIndex(node, tensor_idx); + if (index == -1 || allocations->compressed.tensors[index] == nullptr) { + return -1; + } + const TfLiteEvalTensor* tensor = &allocations->tensors[index]; + const size_t byte_count = EvalTensorBytes(tensor); + + if (AllocateDecompressionMemory(byte_count, MicroArenaBufferAlignment()) != + nullptr) { + // Tensor fits in alternate decompression memory, no need to allocate + // scratch buffer. + return -1; + } + + int scratch_index = -1; + TfLiteStatus result = RequestScratchBufferInArena(byte_count, &scratch_index); + TFLITE_DCHECK(scratch_index != -1 && result == kTfLiteOk); + + return scratch_index; +} + +// Available during Prepare & Eval. Returns nullptr if tensor is not +// compressed. +const CompressionTensorData* MicroInterpreterContext::GetTensorCompressionData( + const TfLiteNode* node, int tensor_idx) { + TFLITE_DCHECK(state_ == InterpreterState::kPrepare || + state_ == InterpreterState::kInvoke); + + const SubgraphAllocations* allocations = + &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()]; + if (allocations->compressed.tensors == nullptr) { + return nullptr; + } + int index = GetInputTensorIndex(node, tensor_idx); + if (index == -1) { + return nullptr; + } + return allocations->compressed.tensors[index]; +} + +// Only available during Prepare & Eval. Returns nullptr on failure, otherwise +// returns a pointer to the buffer. +void* MicroInterpreterContext::DecompressTensorToBuffer( + const TfLiteEvalTensor& tensor, + const CompressionTensorData& compression_data, void* buffer) { + TFLITE_DCHECK(state_ == InterpreterState::kPrepare || + state_ == InterpreterState::kInvoke); + + return MicroContext::DecompressTensorToBuffer(tensor, compression_data, + buffer); +} + +TfLiteStatus MicroInterpreterContext::SetDecompressionMemory( + const std::initializer_list& regions) { + if (state_ != InterpreterState::kInit) { + return kTfLiteError; + } + + decompress_regions_ = ®ions; + decompress_regions_allocations_ = static_cast( + AllocatePersistentBuffer(sizeof(size_t) * regions.size())); + if (decompress_regions_allocations_ == nullptr) { + return kTfLiteError; + } + ResetDecompressionMemoryAllocations(); + + return kTfLiteOk; +} + +void* MicroInterpreterContext::AllocateDecompressionMemory(size_t bytes, + size_t alignment) { + TFLITE_DCHECK(state_ == InterpreterState::kPrepare || + state_ == InterpreterState::kInvoke); + if (decompress_regions_ != nullptr) { + for (size_t i = 0; i < decompress_regions_->size(); i++) { + const AlternateMemoryRegion* region = &decompress_regions_->begin()[i]; + uint8_t* start = static_cast(region->address) + + decompress_regions_allocations_[i]; + uint8_t* aligned_start = AlignPointerUp(start, alignment); + size_t total = bytes + (aligned_start - start); + if (total + decompress_regions_allocations_[i] <= region->bytes) { + decompress_regions_allocations_[i] += total; + return aligned_start; + } + } + } + + return nullptr; +} + +void MicroInterpreterContext::ResetDecompressionMemoryAllocations() { + if (decompress_regions_ == nullptr) { + return; + } + TFLITE_DCHECK(decompress_regions_allocations_ != nullptr); + std::fill_n(decompress_regions_allocations_, 
decompress_regions_->size(), 0); +} + +#endif // USE_TFLM_COMPRESSION + +TfLiteStatus MicroInterpreterContext::SetAlternateProfiler( + tflite::MicroProfilerInterface* alt_profiler) { + alt_profiler_ = alt_profiler; + return kTfLiteOk; +} + +MicroProfilerInterface* MicroInterpreterContext::GetAlternateProfiler() const { + return alt_profiler_; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter_context.h b/tensorflow/lite/micro/micro_interpreter_context.h index 5986dc37fd2..a3927580d51 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.h +++ b/tensorflow/lite/micro/micro_interpreter_context.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -106,6 +106,59 @@ class MicroInterpreterContext : public MicroContext { // housekeeping in MicroInterpreterContext. void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles); +#ifdef USE_TFLM_COMPRESSION + + // Available during Prepare & Eval. Returns false if tensor is not + // compressed. + bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) override; + + // Only available during Prepare. The kernel is responsible for storing the + // scratch buffer handle. + int AllocateDecompressionScratchBuffer(const TfLiteNode* node, + int tensor_idx) override; + + // Available during Prepare & Eval. Returns nullptr if tensor is not + // compressed. + const CompressionTensorData* GetTensorCompressionData( + const TfLiteNode* node, int tensor_idx) override; + + // Only available during Prepare & Eval. Returns nullptr on failure, otherwise + // returns a pointer to the buffer. + void* DecompressTensorToBuffer(const TfLiteEvalTensor& tensor, + const CompressionTensorData& compression_data, + void* buffer) override; + + // Set the alternate decompression memory regions. + // Can only be called during the MicroInterpreter kInit state. + TfLiteStatus SetDecompressionMemory( + const std::initializer_list& regions) override; + + // Return a pointer to memory that can be used for decompression. + // The pointer will be aligned to the value. + // Return nullptr if the requested size is not available. + // Can be called during kPrepare and kInvoke states. + void* AllocateDecompressionMemory(size_t bytes, size_t alignment) override; + + // reset all allocation tracking + void ResetDecompressionMemoryAllocations() override; + +#endif // USE_TFLM_COMPRESSION + + // Set the alternate MicroProfilerInterface. + // This can be used to profile subsystems simultaneously with the profiling + // of kernels during the Eval phase. See (b/379584353). + // The alternate MicroProfilerInterface is currently used by the tensor + // decompression subsystem. + TfLiteStatus SetAlternateProfiler( + MicroProfilerInterface* alt_profiler) override; + + // Get the alternate MicroProfilerInterface. + // This can be used to profile subsystems simultaneously with the profiling + // of kernels during the Eval phase. See (b/379584353). + // The alternate MicroProfilerInterface is currently used by the tensor + // decompression subsystem. 
+ MicroProfilerInterface* GetAlternateProfiler() const override; + private: MicroAllocator& allocator_; MicroInterpreterGraph& graph_; @@ -114,6 +167,16 @@ class MicroInterpreterContext : public MicroContext { ScratchBufferHandle* scratch_buffer_handles_ = nullptr; void* external_context_payload_ = nullptr; + MicroProfilerInterface* alt_profiler_ = nullptr; + +#ifdef USE_TFLM_COMPRESSION + + const std::initializer_list* decompress_regions_ = + nullptr; + // array of size_t elements with length equal to decompress_regions_.size() + size_t* decompress_regions_allocations_; + +#endif // USE_TFLM_COMPRESSION TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_interpreter_graph.cc b/tensorflow/lite/micro/micro_interpreter_graph.cc index 7f096ae71c9..61dd06e7ce1 100644 --- a/tensorflow/lite/micro/micro_interpreter_graph.cc +++ b/tensorflow/lite/micro/micro_interpreter_graph.cc @@ -24,6 +24,12 @@ limitations under the License. #include "tensorflow/lite/micro/micro_profiler.h" #include "tensorflow/lite/schema/schema_generated.h" +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/micro_context.h" + +#endif // USE_TFLM_COMPRESSION + namespace tflite { namespace { @@ -115,6 +121,9 @@ TfLiteStatus MicroInterpreterGraph::PrepareSubgraphs() { current_operator_index_, prepare_status); return kTfLiteError; } +#ifdef USE_TFLM_COMPRESSION + GetMicroContext(context_)->ResetDecompressionMemoryAllocations(); +#endif // USE_TFLM_COMPRESSION } allocator_->FinishPrepareNodeAllocations( /*node_id=*/current_operator_index_); @@ -217,6 +226,9 @@ TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) { TFLITE_DCHECK(registration->invoke); TfLiteStatus invoke_status = registration->invoke(context_, node); +#ifdef USE_TFLM_COMPRESSION + GetMicroContext(context_)->ResetDecompressionMemoryAllocations(); +#endif // USE_TFLM_COMPRESSION // All TfLiteTensor structs used in the kernel are allocated from temp // memory in the allocator. This creates a chain of allocations in the diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc index e44de6b09aa..eebdd39f12a 100644 --- a/tensorflow/lite/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/lite/micro/micro_interpreter.h" #include +#include +#include #include "tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h" #include "tensorflow/lite/micro/compatibility.h" @@ -108,6 +110,215 @@ TF_LITE_MICRO_TEST(TestInterpreter) { TF_LITE_MICRO_EXPECT_EQ(tflite::testing::MockCustom::freed_, true); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(TestInterpreterCompression) { + const tflite::Model* model = tflite::testing::GetSimpleMockModelCompressed(); + TF_LITE_MICRO_EXPECT(nullptr != model); + tflite::testing::TestingOpResolver op_resolver; + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + tflite::testing::GetTestingOpResolver(op_resolver)); + + constexpr size_t kAllocatorBufferSize = 2000; + uint8_t allocator_buffer[kAllocatorBufferSize]; + + // Create a new scope so that we can test the destructor. 
+ { + tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer, + kAllocatorBufferSize); + TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.outputs_size()); + + TfLiteTensor* input = interpreter.input(0); + TF_LITE_MICRO_EXPECT(nullptr != input); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt16, input->type); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(static_cast(2), input->bytes); + TF_LITE_MICRO_EXPECT(nullptr != input->data.data); + static_cast(input->data.data)[0] = 42; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + const std::initializer_list kGolden = { + 43, 44, 45, 46, 47, 41, 40, 39, 38, 37, 43, 44, 45, 46, 47}; + const int kGoldenCount = kGolden.size(); + TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT(nullptr != output); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt16, output->type); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(kGoldenCount, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ( + static_cast(kGoldenCount * sizeof(*kGolden.begin())), + output->bytes); + TF_LITE_MICRO_EXPECT(nullptr != output->data.data); + for (int i = 0; i < kGoldenCount; i++) { + TF_LITE_MICRO_EXPECT_EQ(static_cast(output->data.data)[i], + kGolden.begin()[i]); + } + } +} + +TF_LITE_MICRO_TEST(TestInterpreterCompressionAltMemoryAfterInit) { + const tflite::Model* model = tflite::testing::GetSimpleMockModelCompressed(); + TF_LITE_MICRO_EXPECT(nullptr != model); + tflite::testing::TestingOpResolver op_resolver; + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + tflite::testing::GetTestingOpResolver(op_resolver)); + + constexpr size_t kAllocatorBufferSize = 2000; + uint8_t allocator_buffer[kAllocatorBufferSize]; + constexpr size_t kAltMemSize = 10; + int16_t alt_mem_1[kAltMemSize]; + int16_t alt_mem_2[kAltMemSize]; + std::initializer_list alt_mem = { + {alt_mem_1, sizeof(alt_mem_1)}, + {alt_mem_2, sizeof(alt_mem_2)}, + }; + + // Create a new scope so that we can test the destructor. + { + tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer, + kAllocatorBufferSize); + TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(interpreter.SetDecompressionMemory(alt_mem), + kTfLiteError); + } +} + +TF_LITE_MICRO_TEST(TestInterpreterCompressionAltMemoryTooSmall) { + const tflite::Model* model = tflite::testing::GetSimpleMockModelCompressed(); + TF_LITE_MICRO_EXPECT(nullptr != model); + tflite::testing::TestingOpResolver op_resolver; + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + tflite::testing::GetTestingOpResolver(op_resolver)); + + constexpr size_t kAllocatorBufferSize = 2000; + uint8_t allocator_buffer[kAllocatorBufferSize]; + constexpr size_t kAltMemSize = 10; + int16_t alt_mem_1[kAltMemSize] = {}; + int16_t alt_mem_2[kAltMemSize] = {}; + std::initializer_list alt_mem = { + {alt_mem_1, sizeof(alt_mem_1)}, + {alt_mem_2, sizeof(alt_mem_2)}, + }; + + // Create a new scope so that we can test the destructor. 
+ { + tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer, + kAllocatorBufferSize); + TF_LITE_MICRO_EXPECT_EQ(interpreter.SetDecompressionMemory(alt_mem), + kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.outputs_size()); + + TfLiteTensor* input = interpreter.input(0); + TF_LITE_MICRO_EXPECT(nullptr != input); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt16, input->type); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(static_cast(2), input->bytes); + TF_LITE_MICRO_EXPECT(nullptr != input->data.data); + static_cast(input->data.data)[0] = 42; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + const std::initializer_list kGolden = { + 43, 44, 45, 46, 47, 41, 40, 39, 38, 37, 43, 44, 45, 46, 47}; + const int kGoldenCount = kGolden.size(); + TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT(nullptr != output); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt16, output->type); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(kGoldenCount, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ( + static_cast(kGoldenCount * sizeof(*kGolden.begin())), + output->bytes); + TF_LITE_MICRO_EXPECT(nullptr != output->data.data); + for (int i = 0; i < kGoldenCount; i++) { + TF_LITE_MICRO_EXPECT_EQ(static_cast(output->data.data)[i], + kGolden.begin()[i]); + } + for (size_t i = 0; i < kAltMemSize; i++) { + TF_LITE_MICRO_EXPECT_EQ(alt_mem_1[i], 0); + TF_LITE_MICRO_EXPECT_EQ(alt_mem_2[i], 0); + } + } +} + +TF_LITE_MICRO_TEST(TestInterpreterCompressionAltMemory) { + const tflite::Model* model = tflite::testing::GetSimpleMockModelCompressed(); + TF_LITE_MICRO_EXPECT(nullptr != model); + tflite::testing::TestingOpResolver op_resolver; + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + tflite::testing::GetTestingOpResolver(op_resolver)); + + constexpr size_t kAllocatorBufferSize = 2000; + uint8_t allocator_buffer[kAllocatorBufferSize]; + constexpr size_t kAltMemSize = 10; + int16_t alt_mem_1[kAltMemSize] = {}; + int16_t alt_mem_2[kAltMemSize * 2] = {}; + std::initializer_list alt_mem = { + {alt_mem_1, sizeof(alt_mem_1)}, + {alt_mem_2, sizeof(alt_mem_2)}, + }; + + // Create a new scope so that we can test the destructor. 
+ { + tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer, + kAllocatorBufferSize); + TF_LITE_MICRO_EXPECT_EQ(interpreter.SetDecompressionMemory(alt_mem), + kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1), interpreter.outputs_size()); + + TfLiteTensor* input = interpreter.input(0); + TF_LITE_MICRO_EXPECT(nullptr != input); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt16, input->type); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(static_cast(2), input->bytes); + TF_LITE_MICRO_EXPECT(nullptr != input->data.data); + static_cast(input->data.data)[0] = 42; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + const std::initializer_list kGolden = { + 43, 44, 45, 46, 47, 41, 40, 39, 38, 37, 43, 44, 45, 46, 47}; + const int kGoldenCount = kGolden.size(); + TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT(nullptr != output); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt16, output->type); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(kGoldenCount, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ( + static_cast(kGoldenCount * sizeof(*kGolden.begin())), + output->bytes); + TF_LITE_MICRO_EXPECT(nullptr != output->data.data); + for (int i = 0; i < kGoldenCount; i++) { + TF_LITE_MICRO_EXPECT_EQ(static_cast(output->data.data)[i], + kGolden.begin()[i]); + } + std::initializer_list uncompressed = {1, 2, 3, 4, 5, -1, -2, -3, + -4, -5, 1, 2, 3, 4, 5}; + for (size_t i = 0; i < kAltMemSize; i++) { + TF_LITE_MICRO_EXPECT_EQ(alt_mem_1[i], 0); + } + for (size_t i = 0; i < uncompressed.size(); i++) { + TF_LITE_MICRO_EXPECT_EQ(alt_mem_2[i], uncompressed.begin()[i]); + } + for (size_t i = uncompressed.size(); + i < std::extent::value; i++) { + TF_LITE_MICRO_EXPECT_EQ(alt_mem_2[i], 0); + } + } +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(TestMultiTenantInterpreter) { tflite::testing::TestingOpResolver op_resolver; TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index f5f6e38e003..3ec00a6b614 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -44,8 +44,6 @@ TFLMRegistration* Register_DETECTION_POSTPROCESS(); template class MicroMutableOpResolver : public MicroOpResolver { public: - TF_LITE_REMOVE_VIRTUAL_DELETE - explicit MicroMutableOpResolver() {} const TFLMRegistration* FindOp(tflite::BuiltinOperator op) const override { @@ -704,6 +702,8 @@ class MicroMutableOpResolver : public MicroOpResolver { BuiltinOperator builtin_codes_[tOpCount]; TfLiteBridgeBuiltinParseFunction builtin_parsers_[tOpCount]; unsigned int num_buitin_ops_ = 0; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; }; // namespace tflite diff --git a/tensorflow/lite/micro/micro_profiler.cc b/tensorflow/lite/micro/micro_profiler.cc index ebead51a90d..e349bf73668 100644 --- a/tensorflow/lite/micro/micro_profiler.cc +++ b/tensorflow/lite/micro/micro_profiler.cc @@ -86,14 +86,14 @@ void MicroProfiler::LogTicksPerTagCsv() { TFLITE_DCHECK(tags_[i] != nullptr); int position = FindExistingOrNextPosition(tags_[i]); TFLITE_DCHECK(position >= 0); - total_ticks_per_tag[position].tag = tags_[i]; - total_ticks_per_tag[position].ticks = - total_ticks_per_tag[position].ticks + ticks; + 
total_ticks_per_tag_[position].tag = tags_[i]; + total_ticks_per_tag_[position].ticks = + total_ticks_per_tag_[position].ticks + ticks; total_ticks += ticks; } for (int i = 0; i < num_events_; ++i) { - TicksPerTag each_tag_entry = total_ticks_per_tag[i]; + TicksPerTag each_tag_entry = total_ticks_per_tag_[i]; if (each_tag_entry.tag == nullptr) { break; } @@ -112,7 +112,7 @@ void MicroProfiler::LogTicksPerTagCsv() { int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) { int pos = 0; for (; pos < num_events_; pos++) { - TicksPerTag each_tag_entry = total_ticks_per_tag[pos]; + TicksPerTag each_tag_entry = total_ticks_per_tag_[pos]; if (each_tag_entry.tag == nullptr || strcmp(each_tag_entry.tag, tag_name) == 0) { return pos; @@ -120,4 +120,13 @@ int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) { } return pos < num_events_ ? pos : -1; } + +void MicroProfiler::ClearEvents() { + for (int i = 0; i < num_events_; i++) { + total_ticks_per_tag_[i].tag = nullptr; + } + + num_events_ = 0; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_profiler.h b/tensorflow/lite/micro/micro_profiler.h index b52ebcb4ea9..fd8bc42ffd4 100644 --- a/tensorflow/lite/micro/micro_profiler.h +++ b/tensorflow/lite/micro/micro_profiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ class MicroProfiler : public MicroProfilerInterface { virtual void EndEvent(uint32_t event_handle) override; // Clears all the events that have been currently profiled. - void ClearEvents() { num_events_ = 0; } + void ClearEvents(); // Returns the sum of the ticks taken across all the events. This number // is only meaningful if all of the events are disjoint (the end time of @@ -83,7 +83,7 @@ class MicroProfiler : public MicroProfilerInterface { // In practice, the number of tags will be much lower than the number of // events. But it is theoretically possible that each event to be unique and // hence we allow total_ticks_per_tag to have kMaxEvents entries. - TicksPerTag total_ticks_per_tag[kMaxEvents] = {}; + TicksPerTag total_ticks_per_tag_[kMaxEvents] = {}; int FindExistingOrNextPosition(const char* tag_name); diff --git a/tensorflow/lite/micro/micro_resource_variable.cc b/tensorflow/lite/micro/micro_resource_variable.cc index 767e7d17d6f..843aac664bc 100644 --- a/tensorflow/lite/micro/micro_resource_variable.cc +++ b/tensorflow/lite/micro/micro_resource_variable.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
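// Usage sketch (not part of this patch): with the alternate-profiler hook and
// the per-tag accounting changes above, decompression time can be logged
// separately from kernel time. Assumes `model`, `op_resolver`, `kArena`, and
// `kArenaSize` are defined by the application (placeholder names), and that
// the standard MicroInterpreter constructor is used.
tflite::MicroProfiler op_profiler;      // kernel-level events
tflite::MicroProfiler decomp_profiler;  // decompression subsystem events
tflite::MicroInterpreter interpreter(model, op_resolver, kArena, kArenaSize,
                                     /*resource_variables=*/nullptr,
                                     &op_profiler);
interpreter.SetAlternateProfiler(&decomp_profiler);
interpreter.AllocateTensors();
interpreter.Invoke();
op_profiler.LogTicksPerTagCsv();
decomp_profiler.LogTicksPerTagCsv();
decomp_profiler.ClearEvents();  // reset per-tag totals between inferences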
@@ -113,8 +113,8 @@ TfLiteStatus MicroResourceVariables::Allocate(int id, TfLiteContext* context, return kTfLiteOk; } -TfLiteStatus MicroResourceVariables::Assign(int id, - const TfLiteEvalTensor* tensor) { +TfLiteStatus MicroResourceVariables::Assign(int id, size_t count_bytes, + const void* input_buffer) { if (id < 0 || id >= num_resource_variables_) { MicroPrintf("Attempting to read non-existent resource variable %d", id); return kTfLiteError; @@ -128,8 +128,9 @@ TfLiteStatus MicroResourceVariables::Assign(int id, "with a TfLiteTensor first."); return kTfLiteError; } - TFLITE_DCHECK(EvalTensorBytes(tensor) == variable.bytes); - memcpy(variable.resource_buffer, tensor->data.raw, variable.bytes); + TFLITE_DCHECK(count_bytes == variable.bytes); + TFLITE_DCHECK(input_buffer != nullptr); + memcpy(variable.resource_buffer, input_buffer, variable.bytes); return kTfLiteOk; } diff --git a/tensorflow/lite/micro/micro_resource_variable.h b/tensorflow/lite/micro/micro_resource_variable.h index fb9917d4784..57da6497b3a 100644 --- a/tensorflow/lite/micro/micro_resource_variable.h +++ b/tensorflow/lite/micro/micro_resource_variable.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,10 +46,10 @@ class MicroResourceVariables { TfLiteStatus Allocate(int id, TfLiteContext* context, const TfLiteTensor* tensor); - // Copies input tensor contents to the resource buffer. + // Copies input_buffer contents to the resource buffer. // AllocateResourceVariable with a TFLite tensor must have been called first // in order to allocate the resource buffer. - TfLiteStatus Assign(int id, const TfLiteEvalTensor* tensor); + TfLiteStatus Assign(int id, size_t count_bytes, const void* input_buffer); // Zeros out all resource buffers. TfLiteStatus ResetAll(); diff --git a/tensorflow/lite/micro/micro_resource_variable_test.cc b/tensorflow/lite/micro/micro_resource_variable_test.cc index 13868bb440d..a30718cb994 100644 --- a/tensorflow/lite/micro/micro_resource_variable_test.cc +++ b/tensorflow/lite/micro/micro_resource_variable_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/lite/micro/micro_resource_variable.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -120,7 +121,9 @@ TF_LITE_MICRO_TEST(VerifyAssignAndReadResourceBuffer) { .type = kTfLiteFloat32, }; - resource_variables->Assign(id, &assign_tensor); + resource_variables->Assign( + id, tflite::EvalTensorBytes(&assign_tensor), + tflite::micro::GetTensorData(&assign_tensor)); int32_t buffer[32]; TfLiteEvalTensor read_tensor = { diff --git a/tensorflow/lite/micro/micro_utils.h b/tensorflow/lite/micro/micro_utils.h index 98ef81dc8ed..b362d3402bb 100644 --- a/tensorflow/lite/micro/micro_utils.h +++ b/tensorflow/lite/micro/micro_utils.h @@ -90,12 +90,19 @@ void SymmetricQuantize(const float* input, T* output, int num_elements, template void SymmetricPerChannelQuantize(const float* input, T* output, int num_elements, int num_channels, - float* scales) { + float* scales, + size_t quantized_dimension = 0) { int elements_per_channel = num_elements / num_channels; for (int i = 0; i < num_channels; i++) { for (int j = 0; j < elements_per_channel; j++) { - output[i * elements_per_channel + j] = FloatToSymmetricQuantizedType( - input[i * elements_per_channel + j], scales[i]); + size_t offset; + if (quantized_dimension == 0) { + offset = i * elements_per_channel + j; + } else { + offset = i + elements_per_channel * j; + } + output[offset] = + FloatToSymmetricQuantizedType(input[offset], scales[i]); } } } diff --git a/tensorflow/lite/micro/recording_micro_allocator.cc b/tensorflow/lite/micro/recording_micro_allocator.cc index ee76196d255..18addaee5f7 100644 --- a/tensorflow/lite/micro/recording_micro_allocator.cc +++ b/tensorflow/lite/micro/recording_micro_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -78,14 +78,15 @@ RecordedAllocation RecordingMicroAllocator::GetRecordedAllocation( return recorded_node_and_registration_array_data_; case RecordedAllocationType::kOpData: return recorded_op_data_; - // the function MicroPrintf was never reached outside the switch, because - // each case has a return. 
As the intention of the MicroPrintf is to be - // called when no matching case is found, a default case was added to - // contemplate an invalid allocation type +#ifdef USE_TFLM_COMPRESSION + case RecordedAllocationType::kCompressionData: + return recorded_compression_data_; +#endif // USE_TFLM_COMPRESSION default: - MicroPrintf("Invalid allocation type supplied: %d", allocation_type); - return RecordedAllocation(); + break; } + MicroPrintf("Invalid allocation type supplied: %d", allocation_type); + return RecordedAllocation(); } const RecordingSingleArenaBufferAllocator* @@ -117,6 +118,13 @@ void RecordingMicroAllocator::PrintAllocations() const { "NodeAndRegistration structs"); PrintRecordedAllocation(RecordedAllocationType::kOpData, "Operator runtime data", "OpData structs"); + +#ifdef USE_TFLM_COMPRESSION + + PrintRecordedAllocation(RecordedAllocationType::kCompressionData, + "Persistent compression data", "allocations"); + +#endif // USE_TFLM_COMPRESSION } void* RecordingMicroAllocator::AllocatePersistentBuffer(size_t bytes) { @@ -233,6 +241,21 @@ TfLiteStatus RecordingMicroAllocator::PopulateTfLiteTensorFromFlatbuffer( return status; } +#ifdef USE_TFLM_COMPRESSION + +TfLiteStatus RecordingMicroAllocator::AllocateCompressedTensorsList( + const Model* model, SubgraphAllocations* subgraph_allocations) { + RecordedAllocation allocations = SnapshotAllocationUsage(); + + TfLiteStatus status = MicroAllocator::AllocateCompressedTensorsList( + model, subgraph_allocations); + + RecordAllocationUsage(allocations, recorded_compression_data_); + return status; +} + +#endif // USE_TFLM_COMPRESSION + RecordedAllocation RecordingMicroAllocator::SnapshotAllocationUsage() const { return {/*requested_bytes=*/recording_memory_allocator_->GetRequestedBytes(), /*used_bytes=*/recording_memory_allocator_->GetUsedBytes(), diff --git a/tensorflow/lite/micro/recording_micro_allocator.h b/tensorflow/lite/micro/recording_micro_allocator.h index b6f69264dc0..80f163240d3 100644 --- a/tensorflow/lite/micro/recording_micro_allocator.h +++ b/tensorflow/lite/micro/recording_micro_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
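// Usage sketch (not part of this patch): reading back the new compression
// allocation category recorded above. `arena` and `kArenaSize` are
// placeholder names; StartModelAllocation()/FinishModelAllocation() are
// exercised exactly as in the recording allocator test further below.
tflite::RecordingMicroAllocator* allocator =
    tflite::RecordingMicroAllocator::Create(arena, kArenaSize);
// ... StartModelAllocation() / FinishModelAllocation() ...
const tflite::RecordedAllocation recorded = allocator->GetRecordedAllocation(
    tflite::RecordedAllocationType::kCompressionData);
MicroPrintf("compression data: %u allocations, %u bytes used",
            static_cast<unsigned>(recorded.count),
            static_cast<unsigned>(recorded.used_bytes));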
@@ -33,6 +33,11 @@ enum class RecordedAllocationType { kTfLiteTensorVariableBufferData, kNodeAndRegistrationArray, kOpData, +#ifdef USE_TFLM_COMPRESSION + kCompressionData, +#endif // USE_TFLM_COMPRESSION + + kNumAllocationTypes, // must be last }; // Container for holding information about allocation recordings by a given @@ -93,6 +98,13 @@ class RecordingMicroAllocator : public MicroAllocator { int subgraph_index, bool allocate_temp) override; +#ifdef USE_TFLM_COMPRESSION + + TfLiteStatus AllocateCompressedTensorsList( + const Model* model, SubgraphAllocations* subgraph_allocations) override; + +#endif // USE_TFLM_COMPRESSION + private: RecordingMicroAllocator(RecordingSingleArenaBufferAllocator* memory_allocator, MicroMemoryPlanner* memory_planner); @@ -113,6 +125,9 @@ class RecordingMicroAllocator : public MicroAllocator { RecordedAllocation recorded_persistent_buffer_data_ = {}; RecordedAllocation recorded_tflite_tensor_variable_buffer_data_ = {}; RecordedAllocation recorded_node_and_registration_array_data_ = {}; +#ifdef USE_TFLM_COMPRESSION + RecordedAllocation recorded_compression_data_ = {}; +#endif // USE_TFLM_COMPRESSION // TODO(b/187993291): Re-enable OpData allocating tracking. RecordedAllocation recorded_op_data_ = {}; diff --git a/tensorflow/lite/micro/recording_micro_allocator_test.cc b/tensorflow/lite/micro/recording_micro_allocator_test.cc index 9d3a5965de4..121a74c3324 100644 --- a/tensorflow/lite/micro/recording_micro_allocator_test.cc +++ b/tensorflow/lite/micro/recording_micro_allocator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
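// Sketch (not part of this patch): because kNumAllocationTypes is kept last,
// every recorded category can be dumped with one loop. Assumes the enum
// values remain contiguous from zero and that `allocator` is a
// RecordingMicroAllocator as in the sketch above.
for (int i = 0;
     i < static_cast<int>(tflite::RecordedAllocationType::kNumAllocationTypes);
     ++i) {
  const tflite::RecordedAllocation a = allocator->GetRecordedAllocation(
      static_cast<tflite::RecordedAllocationType>(i));
  MicroPrintf("allocation type %d: count=%u used=%u bytes", i,
              static_cast<unsigned>(a.count),
              static_cast<unsigned>(a.used_bytes));
}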
@@ -317,6 +317,72 @@ TF_LITE_MICRO_TEST(TestMultiSubgraphModel) { num_tensors * TF_LITE_EVAL_TENSOR_STRUCT_SIZE); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(TestCompressedModel) { + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; + tflite::testing::TestingOpResolver ops_resolver; + const tflite::Model* model = tflite::testing::GetSimpleMockModelCompressed(); + const int arena_size = 2048; + + uint8_t arena[arena_size]; + + tflite::RecordingMicroAllocator* micro_allocator = + tflite::RecordingMicroAllocator::Create(arena, arena_size); + TF_LITE_MICRO_EXPECT(micro_allocator != nullptr); + TF_LITE_MICRO_CHECK_FAIL(); + + tflite::SubgraphAllocations* subgraph_allocations = + micro_allocator->StartModelAllocation(model); + TF_LITE_MICRO_EXPECT(nullptr != subgraph_allocations); + TF_LITE_MICRO_CHECK_FAIL(); + + TfLiteStatus status = micro_allocator->FinishModelAllocation( + model, subgraph_allocations, &scratch_buffer_handles); + TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + + micro_allocator->PrintAllocations(); + + size_t count_compression_allocations = 0; + size_t size_compression_allocations = 0; + for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size(); + subgraph_idx++) { + const tflite::CompressionTensorData** ctl = + subgraph_allocations[subgraph_idx].compressed.tensors; + if (ctl == nullptr) { + continue; + } + const tflite::SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx); + const size_t num_tensors = subgraph->tensors()->size(); + for (size_t i = 0; i < num_tensors; i++) { + if (ctl[i] != nullptr) { + count_compression_allocations++; + size_compression_allocations += sizeof(tflite::CompressionTensorData); + count_compression_allocations++; + size_compression_allocations += sizeof(tflite::LookupTableData); + } + } + // Add the CompressionTensorData array + count_compression_allocations++; + size_compression_allocations += + num_tensors * sizeof(tflite::CompressionTensorData*); + } + + tflite::RecordedAllocation recorded_allocation = + micro_allocator->GetRecordedAllocation( + tflite::RecordedAllocationType::kCompressionData); + + TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.count, + count_compression_allocations); + TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.requested_bytes, + size_compression_allocations); + TF_LITE_MICRO_EXPECT_GE(recorded_allocation.used_bytes, + size_compression_allocations); +} + +#endif // USE_TFLM_COMPRESSION + // TODO(b/158124094): Find a way to audit OpData allocations on // cross-architectures. diff --git a/tensorflow/lite/micro/test_helper_custom_ops.cc b/tensorflow/lite/micro/test_helper_custom_ops.cc index 374aabcc9df..4d71f7e3f8c 100644 --- a/tensorflow/lite/micro/test_helper_custom_ops.cc +++ b/tensorflow/lite/micro/test_helper_custom_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,6 +35,18 @@ limitations under the License. 
namespace tflite { namespace testing { +namespace { + +template +void BroadcastAdd(const T input_scalar, const T* weights, T* output, + const size_t count) { + for (size_t i = 0; i < count; i++) { + output[i] = input_scalar + weights[i]; + } +} + +} // namespace + const TFLMRegistration* PackerOp::getRegistration() { return GetMutableRegistration(); } @@ -107,5 +119,178 @@ TfLiteStatus PackerOp::Invoke(TfLiteContext* context, TfLiteNode* node) { bool PackerOp::freed_ = false; +const TFLMRegistration* BroadcastAddOp::getRegistration() { + return GetMutableRegistration(); +} + +TFLMRegistration* BroadcastAddOp::GetMutableRegistration() { + static TFLMRegistration r; + r.init = Init; + r.prepare = Prepare; + r.invoke = Invoke; + return &r; +} + +void* BroadcastAddOp::Init(TfLiteContext* context, const char* buffer, + size_t length) { +#ifdef USE_TFLM_COMPRESSION + + weight_scratch_index_ = -1; + +#endif // USE_TFLM_COMPRESSION + + // Do nothing. + return nullptr; +} + +TfLiteStatus BroadcastAddOp::Prepare(TfLiteContext* context, TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + + TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0); + TF_LITE_ENSURE(context, input != nullptr); + TfLiteTensor* weights = micro_context->AllocateTempInputTensor(node, 1); + TF_LITE_ENSURE(context, weights != nullptr); + TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0); + TF_LITE_ENSURE(context, output != nullptr); + + TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, weights->type); + TF_LITE_ENSURE( + context, input->type == kTfLiteFloat32 || input->type == kTfLiteInt8 || + input->type == kTfLiteInt16 || input->type == kTfLiteInt32 || + input->type == kTfLiteInt64); + TF_LITE_ENSURE(context, input->quantization.type == kTfLiteNoQuantization); + TF_LITE_ENSURE(context, weights->quantization.type == kTfLiteNoQuantization); + TF_LITE_ENSURE(context, output->quantization.type == kTfLiteNoQuantization); + TF_LITE_ENSURE(context, + ElementCount(*weights->dims) == ElementCount(*output->dims)); + TF_LITE_ENSURE(context, ElementCount(*input->dims) == 1); + TF_LITE_ENSURE(context, input->dims->size == 1); + TF_LITE_ENSURE(context, weights->dims->size == 1); + +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. 
+ weight_scratch_index_ = + micro_context->AllocateDecompressionScratchBuffer(node, 1); + if (!micro_context->IsTensorCompressed(node, 1)) { + TF_LITE_ENSURE(context, weight_scratch_index_ == -1); + } + +#endif // USE_TFLM_COMPRESSION + + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(weights); + micro_context->DeallocateTempTfLiteTensor(output); + + return kTfLiteOk; +} + +TfLiteStatus BroadcastAddOp::Invoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteEvalTensor* weights = + tflite::micro::GetEvalInput(context, node, 1); + TF_LITE_ENSURE(context, weights != nullptr); + TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, 1); + if (micro_context->IsTensorCompressed(node, 1)) { + TF_LITE_ENSURE(context, weights_comp_td != nullptr); + } else { + TF_LITE_ENSURE(context, weights_comp_td == nullptr); + } + +#endif // USE_TFLM_COMPRESSION + + switch (input->type) { + case kTfLiteFloat32: { + BroadcastAdd( + tflite::micro::GetTensorData(input)[0], +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, weights, weights_comp_td, weight_scratch_index_), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(weights), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(output), + ElementCount(*output->dims)); + } break; + + case kTfLiteInt8: { + BroadcastAdd( + tflite::micro::GetTensorData(input)[0], +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, weights, weights_comp_td, weight_scratch_index_), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(weights), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(output), + ElementCount(*output->dims)); + } break; + + case kTfLiteInt16: { + BroadcastAdd( + tflite::micro::GetTensorData(input)[0], +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, weights, weights_comp_td, weight_scratch_index_), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(weights), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(output), + ElementCount(*output->dims)); + } break; + + case kTfLiteInt32: { + BroadcastAdd( + tflite::micro::GetTensorData(input)[0], +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, weights, weights_comp_td, weight_scratch_index_), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(weights), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(output), + ElementCount(*output->dims)); + } break; + + case kTfLiteInt64: { + BroadcastAdd( + tflite::micro::GetTensorData(input)[0], +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, weights, weights_comp_td, weight_scratch_index_), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(weights), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(output), + ElementCount(*output->dims)); + } break; + + default: { + MicroPrintf("Input type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + } + + return kTfLiteOk; +} + +#ifdef USE_TFLM_COMPRESSION + +int BroadcastAddOp::weight_scratch_index_ = 
-1; + +#endif // USE_TFLM_COMPRESSION + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/test_helper_custom_ops.h b/tensorflow/lite/micro/test_helper_custom_ops.h index d28bb4038f1..53a8cc3bdd4 100644 --- a/tensorflow/lite/micro/test_helper_custom_ops.h +++ b/tensorflow/lite/micro/test_helper_custom_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,6 +43,23 @@ class PackerOp { static bool freed_; }; +// This op optionally supports compressed weights +class BroadcastAddOp { + public: + static const TFLMRegistration* getRegistration(); + static TFLMRegistration* GetMutableRegistration(); + static void* Init(TfLiteContext* context, const char* buffer, size_t length); + static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node); + static TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node); + + private: +#ifdef USE_TFLM_COMPRESSION + + static int weight_scratch_index_; // decompression scratch buffer index + +#endif // USE_TFLM_COMPRESSION +}; + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 3f0f5ec0826..9faa991dc9e 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/micro/test_helpers.h" +#include #include #include #include @@ -33,6 +34,12 @@ limitations under the License. #include "tensorflow/lite/micro/test_helper_custom_ops.h" #include "tensorflow/lite/schema/schema_generated.h" +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/compression/metadata_saved.h" + +#endif // USE_TFLM_COMPRESSION + // TODO(b/170464050): Use TFLM test only version of schema_utils. 
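// Usage sketch (not part of this patch): the custom op added above is
// resolved by name; outside of GetTestingOpResolver() it can be registered
// the same way (resolver capacity of 1 chosen arbitrarily for illustration).
tflite::MicroMutableOpResolver<1> resolver;
resolver.AddCustom("broadcast_add_op",
                   tflite::testing::BroadcastAddOp::GetMutableRegistration());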
namespace tflite { @@ -236,7 +243,7 @@ const Model* ModelBuilder::BuildModel( *builder_, 0, builder_->CreateVector(operator_codes_, next_operator_code_id_), builder_->CreateVector(subgraphs, subgraphs_size), - builder_->CreateString("teset_model"), + builder_->CreateString("test_model"), builder_->CreateVector(buffers, buffer_size), 0, builder_->CreateVector(metadata_, ModelBuilder::nbr_of_metadata_buffers_)); @@ -245,7 +252,7 @@ const Model* ModelBuilder::BuildModel( *builder_, 0, builder_->CreateVector(operator_codes_, next_operator_code_id_), builder_->CreateVector(subgraphs, subgraphs_size), - builder_->CreateString("teset_model"), + builder_->CreateString("test_model"), builder_->CreateVector(buffers, buffer_size)); } @@ -578,6 +585,117 @@ const Model* BuildSimpleMockModel() { return model; } +#ifdef USE_TFLM_COMPRESSION + +const flatbuffers::span BuildLutMetadata( + uint32_t tensor_index, uint32_t value_table_buffer_index, + uint32_t bit_width) { + using flatbuffers::Offset; + namespace compression = tflite::micro::compression; + + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + + auto lut_tensor = compression::CreateLutTensor( + *builder, tensor_index, value_table_buffer_index, bit_width); + auto subgraph = compression::CreateSubgraph( + *builder, builder->CreateVector(&lut_tensor, 1)); + constexpr uint32_t schema_version = 1; + auto metadata = compression::CreateMetadata( + *builder, schema_version, builder->CreateVector(&subgraph, 1)); + compression::FinishMetadataBuffer(*builder, metadata); + return builder->GetBufferSpan(); +} + +const Model* BuildSimpleMockModelCompressed() { + using flatbuffers::Offset; + using flatbuffers::Vector; + using tflite::micro::compression::LutTensor; + constexpr uint32_t kEmptyBuffer = 0; + constexpr uint32_t kMetadataBuffer = 1; + constexpr uint32_t kWeightsBuffer = 2; + constexpr uint32_t kValueTableBuffer = 3; + // constexpr uint32_t kInputTensor = 0; + constexpr uint32_t kWeightsTensor = 1; + // constexpr uint32_t kOutputTensor = 2; + constexpr uint32_t kCompressedBitWidth = 4; + + auto lut_tensors_span = + BuildLutMetadata(kWeightsTensor, kValueTableBuffer, kCompressedBitWidth); + + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + + // [1, 2, 3, 4, 5, -1, -2, -3, -4, -5, 1, 2, 3, 4, 5] + const std::initializer_list weights_data = {0x01, 0x23, 0x45, 0x98, + 0x76, 0x01, 0x23, 0x40}; + const std::initializer_list value_table_data = {1, 2, 3, 4, 5, + -1, -5, -4, -3, -2}; + auto value_table_offset = builder->CreateVector(value_table_data).o; + const std::initializer_list> buffers = { + CreateBuffer(*builder), + CreateBuffer(*builder, builder->CreateVector(lut_tensors_span)), + CreateBuffer(*builder, builder->CreateVector(weights_data)), + CreateBuffer(*builder, Offset>(value_table_offset)), + }; + + const std::initializer_list input_shape = {1}; + const std::initializer_list weights_shape = {15}; + const std::initializer_list output_shape = weights_shape; + const std::initializer_list> tensors = { + CreateTensor(*builder, builder->CreateVector(input_shape), + TensorType_INT16, kEmptyBuffer, + builder->CreateString("test_input_tensor"), 0, false), + CreateTensor(*builder, builder->CreateVector(weights_shape), + TensorType_INT16, kWeightsBuffer, + builder->CreateString("test_weight_tensor"), 0, false), + CreateTensor(*builder, builder->CreateVector(output_shape), + TensorType_INT16, kEmptyBuffer, + builder->CreateString("test_output_tensor"), 0, false), + }; + + const std::initializer_list subgraph_inputs = {0}; + const 
std::initializer_list subgraph_outputs = {2}; + const std::initializer_list operator_inputs = {0, 1}; + const std::initializer_list operator_outputs = {2}; + const std::initializer_list> operators = { + CreateOperator(*builder, 0, builder->CreateVector(operator_inputs), + builder->CreateVector(operator_outputs), + BuiltinOptions_NONE), + }; + + const std::initializer_list> subgraphs = { + CreateSubGraph(*builder, builder->CreateVector(tensors), + builder->CreateVector(subgraph_inputs), + builder->CreateVector(subgraph_outputs), + builder->CreateVector(operators), + builder->CreateString("test_subgraph")), + }; + + const std::initializer_list> operator_codes = { + CreateOperatorCodeDirect(*builder, /*deprecated_builtin_code=*/0, + "broadcast_add_op", + /*version=*/0, BuiltinOperator_CUSTOM), + }; + + const std::initializer_list> metadata = { + CreateMetadata(*builder, + builder->CreateString(kCompressionMetadataString), + kMetadataBuffer), + }; + + const Offset model_offset = CreateModel( + *builder, 0, builder->CreateVector(operator_codes), + builder->CreateVector(subgraphs), builder->CreateString("test_model"), + builder->CreateVector(buffers), 0, builder->CreateVector(metadata)); + + FinishModelBuffer(*builder, model_offset); + void* model_pointer = builder->GetBufferPointer(); + const Model* model = flatbuffers::GetRoot(model_pointer); + + return model; +} + +#endif // USE_TFLM_COMPRESSION + const Model* BuildComplexMockModel() { using flatbuffers::Offset; flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); @@ -1665,6 +1783,8 @@ TfLiteStatus GetTestingOpResolver( op_resolver.AddCustom("no_op", NoOp::GetMutableRegistration())); TF_LITE_ENSURE_STATUS(op_resolver.AddCustom( "custom_packer_op", PackerOp::GetMutableRegistration())); + TF_LITE_ENSURE_STATUS(op_resolver.AddCustom( + "broadcast_add_op", BroadcastAddOp::GetMutableRegistration())); TF_LITE_ENSURE_STATUS(op_resolver.AddIf()); return kTfLiteOk; } @@ -1698,6 +1818,18 @@ const Model* GetSimpleMockModel() { return model; } +#ifdef USE_TFLM_COMPRESSION + +const Model* GetSimpleMockModelCompressed() { + static Model* model = nullptr; + if (!model) { + model = const_cast(BuildSimpleMockModelCompressed()); + } + return model; +} + +#endif // USE_TFLM_COMPRESSION + const Model* GetSimpleMultipleInputsModel() { static Model* model = nullptr; if (!model) { @@ -1890,100 +2022,6 @@ TfLiteFloatArray* FloatArrayFromFloats(const float* floats) { return reinterpret_cast(const_cast(floats)); } -TfLiteTensor CreateQuantizedBiasTensor(const float* data, int16_t* quantized, - TfLiteIntArray* dims, float input_scale, - float weights_scale, bool is_variable) { - float bias_scale = input_scale * weights_scale; - tflite::SymmetricQuantize(data, quantized, ElementCount(*dims), bias_scale); - - // Quantized int16_t tensors always have a zero point of 0, since the range of - // int16_t values is large, and because zero point costs extra cycles during - // processing. - TfLiteTensor result = - CreateQuantizedTensor(quantized, dims, bias_scale, 0, is_variable); - return result; -} - -TfLiteTensor CreateQuantizedBiasTensor(const float* data, int32_t* quantized, - TfLiteIntArray* dims, float input_scale, - float weights_scale, bool is_variable) { - float bias_scale = input_scale * weights_scale; - tflite::SymmetricQuantize(data, quantized, ElementCount(*dims), bias_scale); - - // Quantized int32_t tensors always have a zero point of 0, since the range of - // int32_t values is large, and because zero point costs extra cycles during - // processing. 
- TfLiteTensor result = - CreateQuantizedTensor(quantized, dims, bias_scale, 0, is_variable); - return result; -} - -TfLiteTensor CreateQuantizedBiasTensor(const float* data, - std::int64_t* quantized, - TfLiteIntArray* dims, float input_scale, - float weights_scale, bool is_variable) { - float bias_scale = input_scale * weights_scale; - tflite::SymmetricQuantize(data, quantized, ElementCount(*dims), bias_scale); - - // Quantized int32_t tensors always have a zero point of 0, since the range of - // int32_t values is large, and because zero point costs extra cycles during - // processing. - TfLiteTensor result = - CreateQuantizedTensor(quantized, dims, bias_scale, 0, is_variable); - return result; -} - -// Quantizes int32_t bias tensor with per-channel weights determined by input -// scale multiplied by weight scale for each channel. -template -TfLiteTensor CreatePerChannelQuantizedBiasTensor( - const float* input, T* quantized, TfLiteIntArray* dims, float input_scale, - float* weight_scales, float* scales, int* zero_points, - TfLiteAffineQuantization* affine_quant, int quantized_dimension, - bool is_variable) { - int input_size = ElementCount(*dims); - int num_channels = dims->data[quantized_dimension]; - // First element is reserved for array length - zero_points[0] = num_channels; - scales[0] = static_cast(num_channels); - float* scales_array = &scales[1]; - for (int i = 0; i < num_channels; i++) { - scales_array[i] = input_scale * weight_scales[i]; - zero_points[i + 1] = 0; - } - - SymmetricPerChannelQuantize(input, quantized, input_size, num_channels, - scales_array); - - affine_quant->scale = FloatArrayFromFloats(scales); - affine_quant->zero_point = IntArrayFromInts(zero_points); - affine_quant->quantized_dimension = quantized_dimension; - - TfLiteTensor result = CreateTensor(quantized, dims, is_variable); - result.quantization = {kTfLiteAffineQuantization, affine_quant}; - return result; -} - -TfLiteTensor CreatePerChannelQuantizedBiasTensor( - const float* input, int32_t* quantized, TfLiteIntArray* dims, - float input_scale, float* weight_scales, float* scales, int* zero_points, - TfLiteAffineQuantization* affine_quant, int quantized_dimension, - bool is_variable) { - return CreatePerChannelQuantizedBiasTensor( - input, quantized, dims, input_scale, weight_scales, scales, zero_points, - affine_quant, quantized_dimension, is_variable); -} - -TfLiteTensor CreatePerChannelQuantizedBiasTensor( - const float* input, std::int64_t* quantized, TfLiteIntArray* dims, - float input_scale, float* weight_scales, float* scales, int* zero_points, - TfLiteAffineQuantization* affine_quant, int quantized_dimension, - bool is_variable) { - return CreatePerChannelQuantizedBiasTensor( - input, quantized, dims, input_scale, weight_scales, scales, zero_points, - affine_quant, quantized_dimension, is_variable); -} - TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( const float* input, int8_t* quantized, TfLiteIntArray* dims, float* scales, int* zero_points, TfLiteAffineQuantization* affine_quant, diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 6315b9fecdc..86eaf778f7b 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
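// Usage sketch (not part of this patch): the non-template bias helpers
// removed above are superseded by the templates added to test_helpers.h in
// the next hunk; existing call sites keep compiling through template
// argument deduction. Values below are illustrative only.
const float bias_data[] = {1.0f, -1.0f};
int32_t bias_quantized[2];
int bias_dims_data[] = {1, 2};  // IntArrayFromInts layout: length, then dims
TfLiteIntArray* bias_dims = tflite::testing::IntArrayFromInts(bias_dims_data);
TfLiteTensor bias = tflite::testing::CreateQuantizedBiasTensor(
    bias_data, bias_quantized, bias_dims,
    /*input_scale=*/0.5f, /*weights_scale=*/0.25f);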
@@ -31,6 +31,13 @@ limitations under the License. #include "tensorflow/lite/portable_type_to_tflitetype.h" #include "tensorflow/lite/schema/schema_generated.h" +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/compression.h" +#include "tensorflow/lite/micro/micro_log.h" + +#endif // TENSORFLOW_LITE_MICRO_TEST_HELPERS_H_ + namespace tflite { namespace testing { @@ -112,6 +119,15 @@ TfLiteStatus GetTestingOpResolver(TestingOpResolver& op_resolver); // 1 layer of weights, 1 output Tensor, and 1 operator. const Model* GetSimpleMockModel(); +#ifdef USE_TFLM_COMPRESSION + +// Returns a simple example flatbuffer TensorFlow Lite model. Contains 1 input, +// 1 layer of weights, 1 output Tensor, and 1 operator (BroadcastAddOp). The +// weights tensor is compressed. +const Model* GetSimpleMockModelCompressed(); + +#endif // USE_TFLM_COMPRESSION + // Returns a flatbuffer TensorFlow Lite model with more inputs, variable // tensors, and operators. const Model* GetComplexMockModel(); @@ -220,8 +236,6 @@ TfLiteTensor CreateTensor(const T* data, TfLiteIntArray* dims, result.is_variable = is_variable; result.allocation_type = kTfLiteMemNone; result.data.data = const_cast(data); - result.bytes = ElementCount(*dims) * sizeof(T); - result.data.data = const_cast(data); if (type == kTfLiteInt4) { result.type = kTfLiteInt4; @@ -233,7 +247,13 @@ TfLiteTensor CreateTensor(const T* data, TfLiteIntArray* dims, // a single CreateTensor method. A Const array should be used for immutable // input tensors and non-const array should be used for mutable and output // tensors. - result.type = typeToTfLiteType(); + if (type == kTfLiteNoType) { + result.type = typeToTfLiteType(); + } else { + result.type = type; + } + + result.bytes = ElementCount(*dims) * TfLiteTypeGetSize(result.type); } return result; } @@ -260,37 +280,95 @@ TfLiteTensor CreateQuantizedTensor(const float* input, T* quantized, type); } -TfLiteTensor CreateQuantizedBiasTensor(const float* data, int16_t* quantized, +template +TfLiteTensor CreateQuantizedBiasTensor(const float* data, T* quantized, TfLiteIntArray* dims, float input_scale, float weights_scale, - bool is_variable = false); + bool is_variable = false) { + float bias_scale = input_scale * weights_scale; + tflite::SymmetricQuantize(data, quantized, ElementCount(*dims), bias_scale); + + // Quantized bias tensors always have a zero point of 0, since the range of + // values is large, and because zero point costs extra cycles during + // processing. + TfLiteTensor result = + CreateQuantizedTensor(quantized, dims, bias_scale, 0, is_variable); + return result; +} -TfLiteTensor CreateQuantizedBiasTensor(const float* data, int32_t* quantized, - TfLiteIntArray* dims, float input_scale, - float weights_scale, - bool is_variable = false); +// Creates bias tensor with input data, and per-channel weights determined by +// input scale multiplied by weight scale for each channel. Input data will not +// be quantized. 
+template +TfLiteTensor CreatePerChannelQuantizedBiasTensor( + const T* input_data, TfLiteIntArray* dims, float input_scale, + const TfLiteFloatArray* weight_scales, TfLiteFloatArray* scales, + TfLiteIntArray* zero_points, TfLiteAffineQuantization* affine_quant, + int quantized_dimension, bool is_variable = false, + TfLiteType type = kTfLiteNoType) { + int num_channels = dims->data[quantized_dimension]; + zero_points->size = num_channels; + scales->size = num_channels; + for (int i = 0; i < num_channels; i++) { + scales->data[i] = input_scale * weight_scales->data[i]; + zero_points->data[i] = 0; + } -TfLiteTensor CreateQuantizedBiasTensor(const float* data, - std::int64_t* quantized, - TfLiteIntArray* dims, float input_scale, - float weights_scale, - bool is_variable = false); + affine_quant->scale = scales; + affine_quant->zero_point = zero_points; + affine_quant->quantized_dimension = quantized_dimension; -// Quantizes int32_t bias tensor with per-channel weights determined by input -// scale multiplied by weight scale for each channel. -TfLiteTensor CreatePerChannelQuantizedBiasTensor( - const float* input, int32_t* quantized, TfLiteIntArray* dims, - float input_scale, float* weight_scales, float* scales, int* zero_points, - TfLiteAffineQuantization* affine_quant, int quantized_dimension, - bool is_variable = false); + TfLiteTensor result = CreateTensor(input_data, dims, is_variable, type); + result.quantization = {kTfLiteAffineQuantization, affine_quant}; + return result; +} -// Quantizes int64_t bias tensor with per-channel weights determined by input +// Quantizes bias tensor with per-channel weights determined by input // scale multiplied by weight scale for each channel. +template TfLiteTensor CreatePerChannelQuantizedBiasTensor( - const float* input, std::int64_t* quantized, TfLiteIntArray* dims, - float input_scale, float* weight_scales, float* scales, int* zero_points, + const float* input, T* quantized, TfLiteIntArray* dims, float input_scale, + const float* weight_scales, float* scales, int* zero_points, TfLiteAffineQuantization* affine_quant, int quantized_dimension, - bool is_variable = false); + bool is_variable = false) { + int input_size = ElementCount(*dims); + int num_channels = dims->data[quantized_dimension]; + // First element is reserved for array length + zero_points[0] = num_channels; + scales[0] = static_cast(num_channels); + float* scales_array = &scales[1]; + for (int i = 0; i < num_channels; i++) { + scales_array[i] = input_scale * weight_scales[i]; + zero_points[i + 1] = 0; + } + + SymmetricPerChannelQuantize(input, quantized, input_size, num_channels, + scales_array); + + affine_quant->scale = FloatArrayFromFloats(scales); + affine_quant->zero_point = IntArrayFromInts(zero_points); + affine_quant->quantized_dimension = quantized_dimension; + + TfLiteTensor result = CreateTensor(quantized, dims, is_variable); + result.quantization = {kTfLiteAffineQuantization, affine_quant}; + + return result; +} + +template +TfLiteTensor CreatePerChannelQuantizedTensor( + const T* quantized, TfLiteIntArray* dims, TfLiteFloatArray* scales, + TfLiteIntArray* zero_points, TfLiteAffineQuantization* affine_quant, + int quantized_dimension, bool is_variable = false, + TfLiteType type = kTfLiteNoType) { + affine_quant->scale = scales; + affine_quant->zero_point = zero_points; + affine_quant->quantized_dimension = quantized_dimension; + + TfLiteTensor result = CreateTensor(quantized, dims, is_variable, type); + result.quantization = {kTfLiteAffineQuantization, affine_quant}; + 
return result; +} TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( const float* input, int8_t* quantized, TfLiteIntArray* dims, float* scales, @@ -329,6 +407,128 @@ inline int ZeroPointFromMinMax(const float min, const float max) { static_cast(roundf(-min / ScaleFromMinMax(min, max))); } +#ifdef USE_TFLM_COMPRESSION + +template +struct TestCompressionInfo { + T* value_table; + size_t value_table_stride; + int bit_width; + CompressionScheme scheme; +}; + +template +struct TestCompressionQuantizedInfo : TestCompressionInfo { + const uint8_t* compressed; + const float* data; + const int* dims_data; // TfLiteIntArray + const float* scales; // TfLiteFloatArray (may be computed) + const int* zero_points; // TfLiteIntArray (may be computed) +}; + +template +class TestCompressedList { + public: + template + TfLiteStatus AddInput(const TestCompressionInfo& tci, + const TfLiteTensor& tensor, const size_t tensor_index) { + if (next_input_index_ >= NINPUTS) { + MicroPrintf("TestCompressedList: too many inputs, max %u", NINPUTS); + return kTfLiteError; + } + inputs_comp_data_[next_input_index_].data.lut_data = + &inputs_ltd_[next_input_index_]; + inputs_comp_data_[next_input_index_].scheme = tci.scheme; + inputs_comp_data_[next_input_index_].data.lut_data->compressed_bit_width = + tci.bit_width; + inputs_comp_data_[next_input_index_].data.lut_data->value_table = + tci.value_table; + inputs_comp_data_[next_input_index_] + .data.lut_data->value_table_channel_stride = tci.value_table_stride; + inputs_comp_data_[next_input_index_] + .data.lut_data->is_per_channel_quantized = + IsPerChannelQuantized(tensor); + inputs_comp_data_[next_input_index_].data.lut_data->use_alternate_axis = + UsesAltAxis(tensor); + return SetCompressionData(tensor_index, + inputs_comp_data_[next_input_index_++]); + } + + const CompressedTensorList* GetCompressedTensorList() { + if (next_input_index_ == 0) { + return nullptr; + } + + return &ctl_; + } + + private: + size_t next_input_index_ = 0; + LookupTableData inputs_ltd_[NINPUTS] = {}; + CompressionTensorData inputs_comp_data_[NINPUTS] = {}; + const CompressionTensorData* ctdp_[NTENSORS] = {}; + const CompressedTensorList ctl_ = {ctdp_}; + + TfLiteStatus SetCompressionData(const size_t tensor_index, + const CompressionTensorData& cd) { + if (tensor_index >= NTENSORS) { + MicroPrintf("TestCompressedList: bad tensor index %u", tensor_index); + return kTfLiteError; + } + if (cd.data.lut_data->value_table == nullptr) { + MicroPrintf("TestCompressedList: null value_table pointer"); + return kTfLiteError; + } + if (cd.data.lut_data->value_table_channel_stride == 0) { + MicroPrintf("TestCompressedList: value_table_channel_stride not set"); + return kTfLiteError; + } + if (cd.scheme != CompressionScheme::kBinQuant) { + MicroPrintf("TestCompressedList: unsupported compression scheme"); + return kTfLiteError; + } + if (ctdp_[tensor_index] != nullptr) { + MicroPrintf("TestCompressedList: tensor index %d already in use", + tensor_index); + return kTfLiteError; + } + + ctdp_[tensor_index] = &cd; + return kTfLiteOk; + } + + bool IsPerChannelQuantized(const TfLiteTensor& tensor) { + if (tensor.quantization.type == kTfLiteAffineQuantization && + tensor.quantization.params != nullptr) { + const TfLiteAffineQuantization* qp = + static_cast( + tensor.quantization.params); + if (qp->scale->size > 1) { + return true; + } + } + + return false; + } + + bool UsesAltAxis(const TfLiteTensor& tensor) { + if (tensor.quantization.type == kTfLiteAffineQuantization && + tensor.quantization.params != 
nullptr) { + const TfLiteAffineQuantization* qp = + static_cast( + tensor.quantization.params); + if (qp->quantized_dimension != 0) { + TFLITE_DCHECK_EQ(qp->quantized_dimension, tensor.dims->size - 1); + return true; + } + } + + return false; + } +}; + +#endif // USE_TFLM_COMPRESSION + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/tools/benchmarking/Makefile.inc b/tensorflow/lite/micro/tools/benchmarking/Makefile.inc index 396e7016384..a79420cb982 100644 --- a/tensorflow/lite/micro/tools/benchmarking/Makefile.inc +++ b/tensorflow/lite/micro/tools/benchmarking/Makefile.inc @@ -20,6 +20,15 @@ endif $(GENERATED_SRCS_DIR)$(GENERIC_BENCHMARK_MODEL_DIR)$(GENERIC_BENCHMARK_MODEL_NAME)_model_data.h endif +ifeq ($(ENABLE_COMPRESSION), yes) +ifneq ($(GENERIC_BENCHMARK_ALT_MEM_ATTR),) + CXXFLAGS += -DGENERIC_BENCHMARK_ALT_MEM_ATTR=$(GENERIC_BENCHMARK_ALT_MEM_ATTR) +endif +ifneq ($(GENERIC_BENCHMARK_ALT_MEM_SIZE),) + CXXFLAGS += -DGENERIC_BENCHMARK_ALT_MEM_SIZE=$(GENERIC_BENCHMARK_ALT_MEM_SIZE) +endif +endif + GENERIC_BENCHMARK_SRCS := \ $(MICROLITE_BENCHMARK_ROOT_DIR)/generic_model_benchmark.cc \ $(MICROLITE_BENCHMARK_ROOT_DIR)/metrics.cc \ diff --git a/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh b/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh index c60bdf3ed72..424a1b8da65 100755 --- a/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh +++ b/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh @@ -52,7 +52,7 @@ function substitute_strings() { IFS=${SAVED_IFS} replacement=() for line in "${lines_array[@]}"; do - line=$(sed -e 's/"/\\"/g' <<< "${line}") + line=$(sed -e 's/\\/\\\\/g' -e 's/"/\\"/g' <<< "${line}") line=$(printf '"%s",\n ' "${line}") replacement+=( "${line}" ) done diff --git a/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc b/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc index f398963a00d..9af661fb3b8 100644 --- a/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc +++ b/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -56,19 +57,37 @@ limitations under the License. #endif // defind(GENERIC_BENCHMARK_USING_BUILTIN_MODEL) +#if defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && \ + !defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) +#error "GENERIC_BENCHMARK_ALT_MEM_SIZE missing from CXXFLAGS" +#endif // defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && + // !defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) + +#if defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && \ + !defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) +#error "GENERIC_BENCHMARK_ALT_MEM_ATTR missing from CXXFLAGS" +#endif // defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && + // !defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) + +#if defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && \ + defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && defined(USE_TFLM_COMPRESSION) +#define USE_ALT_DECOMPRESSION_MEM +#endif // defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && + // defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && + // defined(USE_TFLM_COMPRESSION) + /* - * Generic model benchmark. Evaluates runtime performance of a provided model - * with random inputs. + * Generic model benchmark. Evaluates runtime performance of a provided + * model with random inputs. */ namespace tflite { - namespace { using Profiler = ::tflite::MicroProfiler; -// Seed used for the random input. Input data shouldn't affect invocation timing -// so randomness isn't really needed. 
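Stepping back to the TestCompressedList / TestCompressionInfo scaffolding added to test_helpers.h above: a kernel test fills in the look-up-table (LUT) metadata for each compressed input and then fetches a CompressedTensorList to hand to the test harness. A rough usage sketch follows; the template arguments (a type parameter on TestCompressionInfo, count parameters such as NINPUTS and NTENSORS on TestCompressedList), the enum's namespace, and all literal values are assumptions for illustration, since they are not spelled out in the hunks above.

#ifdef USE_TFLM_COMPRESSION
// Hypothetical per-channel value table and weight tensor; a real test would
// build the tensor with CreateTensor() or CreatePerChannelQuantizedTensor().
static int8_t kWeightValueTable[16] = {};
static TfLiteTensor weights_tensor = {};

void RegisterCompressedWeights() {
  tflite::testing::TestCompressionInfo<int8_t> weight_info = {};
  weight_info.value_table = kWeightValueTable;
  weight_info.value_table_stride = 16;  // LUT entries per channel
  weight_info.bit_width = 4;            // compressed index width, in bits
  weight_info.scheme = tflite::CompressionScheme::kBinQuant;

  // One compressible input, three tensors on the node (assumed values).
  static tflite::testing::TestCompressedList<1, 3> compressed_list;
  if (compressed_list.AddInput(weight_info, weights_tensor,
                               /*tensor_index=*/1) != kTfLiteOk) {
    return;
  }
  // GetCompressedTensorList() returns nullptr when nothing was registered.
  const tflite::CompressedTensorList* comp_list =
      compressed_list.GetCompressedTensorList();
  (void)comp_list;  // would be passed on to the kernel test runner
}
#endif  // USE_TFLM_COMPRESSION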
+// Seed used for the random input. Input data shouldn't affect invocation +// timing so randomness isn't really needed. constexpr uint32_t kRandomSeed = 0xFB; #if !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL) @@ -80,6 +99,11 @@ constexpr size_t kTensorArenaSize = GENERIC_BENCHMARK_TENSOR_ARENA_SIZE; constexpr size_t kTensorArenaSize = 5e6 - MODEL_SIZE; #endif // !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL) +#if defined(USE_ALT_DECOMPRESSION_MEM) +constexpr size_t kAltMemorySize = GENERIC_BENCHMARK_ALT_MEM_SIZE; +alignas(16) GENERIC_BENCHMARK_ALT_MEM_ATTR uint8_t g_alt_memory[kAltMemorySize]; +#endif // defined(USE_ALT_DECOMPRESSION_MEM) + constexpr int kNumResourceVariable = 100; void SetRandomInput(const uint32_t random_seed, @@ -130,39 +154,145 @@ bool ReadFile(const char* file_name, void* buffer, size_t buffer_size) { } #endif // !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL) +uint32_t crctab[256]; + +void GenCRC32Table() { + constexpr uint32_t kPolyN = 0xEDB88320; + for (size_t index = 0; index < 256; index++) { + crctab[index] = index; + for (int i = 0; i < 8; i++) { + if (crctab[index] & 1) { + crctab[index] = (crctab[index] >> 1) ^ kPolyN; + } else { + crctab[index] >>= 1; + } + } + } +} + +uint32_t ComputeCRC32(const uint8_t* data, const size_t data_length) { + uint32_t crc32 = ~0U; + + for (size_t i = 0; i < data_length; i++) { + // crctab is an array of 256 32-bit constants + const uint32_t index = (crc32 ^ data[i]) & 0xFF; + crc32 = (crc32 >> 8) ^ crctab[index]; + } + + // invert all bits of result + crc32 ^= ~0U; + return crc32; +} + +void ShowOutputCRC32(tflite::MicroInterpreter* interpreter) { + GenCRC32Table(); + for (size_t i = 0; i < interpreter->outputs_size(); ++i) { + TfLiteTensor* output = interpreter->output_tensor(i); + uint8_t* output_values = tflite::GetTensorData(output); + uint32_t crc32_value = ComputeCRC32(output_values, output->bytes); + MicroPrintf("Output CRC32: 0x%X", crc32_value); + } +} + +void ShowInputCRC32(tflite::MicroInterpreter* interpreter) { + GenCRC32Table(); + for (size_t i = 0; i < interpreter->inputs_size(); ++i) { + TfLiteTensor* input = interpreter->input_tensor(i); + uint8_t* input_values = tflite::GetTensorData(input); + uint32_t crc32_value = ComputeCRC32(input_values, input->bytes); + MicroPrintf("Input CRC32: 0x%X", crc32_value); + } +} + int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) { - Profiler profiler; + static Profiler profiler; + static Profiler profiler2; + TfLiteStatus status; + +// use this to keep the application size stable regardless of whether +// compression is being used +#ifdef USE_TFLM_COMPRESSION + constexpr bool using_compression = true; +#else // USE_TFLM_COMPRESSION + constexpr bool using_compression = false; +#endif // USE_TFLM_COMPRESSION + alignas(16) static uint8_t tensor_arena[kTensorArenaSize]; - uint32_t event_handle = profiler.BeginEvent("TfliteGetModel"); +#ifdef USE_ALT_DECOMPRESSION_MEM + std::initializer_list + alt_memory_region = {{g_alt_memory, kAltMemorySize}}; +#endif // USE_ALT_DECOMPRESSION_MEM + + uint32_t event_handle = profiler.BeginEvent("tflite::GetModel"); const tflite::Model* model = tflite::GetModel(model_data); profiler.EndEvent(event_handle); + event_handle = profiler.BeginEvent("tflite::CreateOpResolver"); TflmOpResolver op_resolver; - TF_LITE_ENSURE_STATUS(CreateOpResolver(op_resolver)); + status = CreateOpResolver(op_resolver); + if (status != kTfLiteOk) { + MicroPrintf("tflite::CreateOpResolver failed"); + return -1; + } + 
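GenCRC32Table() and ComputeCRC32() above implement the standard reflected CRC-32 (polynomial 0xEDB88320, initial value 0xFFFFFFFF, final XOR 0xFFFFFFFF), so the input and output checksums the benchmark prints can be cross-checked against any common CRC-32 tool. A small self-contained check using the well-known "123456789" test vector, which should produce 0xCBF43926:

// Standalone sanity check mirroring the table-driven CRC-32 used above.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint32_t table[256];
  for (uint32_t i = 0; i < 256; ++i) {
    uint32_t c = i;
    for (int k = 0; k < 8; ++k) {
      c = (c & 1) ? (c >> 1) ^ 0xEDB88320u : c >> 1;
    }
    table[i] = c;
  }

  const char* data = "123456789";
  uint32_t crc = ~0u;
  for (size_t i = 0; i < std::strlen(data); ++i) {
    crc = (crc >> 8) ^ table[(crc ^ static_cast<uint8_t>(data[i])) & 0xFF];
  }
  crc ^= ~0u;

  std::printf("CRC32 = 0x%08X (expected 0xCBF43926)\n", crc);
  return crc == 0xCBF43926u ? 0 : 1;
}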
profiler.EndEvent(event_handle); + event_handle = profiler.BeginEvent("tflite::RecordingMicroAllocator::Create"); tflite::RecordingMicroAllocator* allocator( tflite::RecordingMicroAllocator::Create(tensor_arena, kTensorArenaSize)); + profiler.EndEvent(event_handle); + event_handle = profiler.BeginEvent("tflite::MicroInterpreter instantiation"); tflite::RecordingMicroInterpreter interpreter( model, op_resolver, allocator, tflite::MicroResourceVariables::Create(allocator, kNumResourceVariable), &profiler); - TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors()); + profiler.EndEvent(event_handle); + +#ifdef USE_ALT_DECOMPRESSION_MEM + event_handle = + profiler.BeginEvent("tflite::MicroInterpreter::SetDecompressionMemory"); + status = interpreter.SetDecompressionMemory(alt_memory_region); + if (status != kTfLiteOk) { + MicroPrintf("tflite::MicroInterpreter::SetDecompressionMemory failed"); + return -1; + } + profiler.EndEvent(event_handle); +#endif // USE_ALT_DECOMPRESSION_MEM + + event_handle = + profiler.BeginEvent("tflite::MicroInterpreter::AllocateTensors"); + status = interpreter.AllocateTensors(); + if (status != kTfLiteOk) { + MicroPrintf("tflite::MicroInterpreter::AllocateTensors failed"); + return -1; + } + profiler.EndEvent(event_handle); - profiler.Log(); + profiler.LogTicksPerTagCsv(); profiler.ClearEvents(); + if (using_compression) { + status = interpreter.SetAlternateProfiler(&profiler2); + if (status != kTfLiteOk) { + MicroPrintf("tflite::MicroInterpreter::SetAlternateProfiler failed"); + return -1; + } + } + MicroPrintf(""); // null MicroPrintf serves as a newline. - // For streaming models, the interpreter will return kTfLiteAbort if the model - // does not yet have enough data to make an inference. As such, we need to - // invoke the interpreter multiple times until we either receive an error or - // kTfLiteOk. This loop also works for non-streaming models, as they'll just - // return kTfLiteOk after the first invocation. + // For streaming models, the interpreter will return kTfLiteAbort if the + // model does not yet have enough data to make an inference. As such, we + // need to invoke the interpreter multiple times until we either receive an + // error or kTfLiteOk. This loop also works for non-streaming models, as + // they'll just return kTfLiteOk after the first invocation. uint32_t seed = kRandomSeed; while (true) { SetRandomInput(seed++, interpreter); - TfLiteStatus status = interpreter.Invoke(); + ShowInputCRC32(&interpreter); + MicroPrintf(""); // null MicroPrintf serves as a newline. + + status = interpreter.Invoke(); if ((status != kTfLiteOk) && (static_cast(status) != kTfLiteAbort)) { MicroPrintf("Model interpreter invocation failed: %d\n", status); return -1; @@ -174,6 +304,17 @@ int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) { MicroPrintf(""); // null MicroPrintf serves as a newline. profiler.ClearEvents(); + if (using_compression) { + profiler2.Log(); + MicroPrintf(""); // null MicroPrintf serves as a newline. + profiler2.LogTicksPerTagCsv(); + MicroPrintf(""); // null MicroPrintf serves as a newline. + profiler2.ClearEvents(); + } + + ShowOutputCRC32(&interpreter); + MicroPrintf(""); // null MicroPrintf serves as a newline. 
+ if (status == kTfLiteOk) { break; } diff --git a/tensorflow/lite/micro/tools/benchmarking/metrics.cc b/tensorflow/lite/micro/tools/benchmarking/metrics.cc index 3a4bf7e4917..f71a4cd139e 100644 --- a/tensorflow/lite/micro/tools/benchmarking/metrics.cc +++ b/tensorflow/lite/micro/tools/benchmarking/metrics.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,7 +46,8 @@ struct LogAllocationRecord { constexpr int kArenaRows = 3; constexpr int kArenaColumns = 3; -constexpr int kAllocationTypes = 7; +constexpr int kAllocationTypes = + static_cast(tflite::RecordedAllocationType::kNumAllocationTypes); constexpr int kAllocationColumns = 6; constexpr int kMaxBufSize = 100; @@ -85,16 +86,25 @@ LogAllocationRecord GetLogAllocationRecord( tflite::RecordedAllocationType::kPersistentBufferData, tflite::RecordedAllocationType::kTfLiteTensorVariableBufferData, tflite::RecordedAllocationType::kNodeAndRegistrationArray, - tflite::RecordedAllocationType::kOpData}; + tflite::RecordedAllocationType::kOpData, +#ifdef USE_TFLM_COMPRESSION + tflite::RecordedAllocationType::kCompressionData, +#endif // USE_TFLM_COMPRESSION + }; static_assert(std::extent::value == kAllocationTypes, "kAllocationTypes mismatch"); - const char* titles[] = {"Eval tensor data", - "Persistent tensor data", - "Persistent quantization data", - "Persistent buffer data", - "Tensor variable buffer data", - "Node and registration array", - "Operation data"}; + const char* titles[] = { + "Eval tensor data", + "Persistent tensor data", + "Persistent quantization data", + "Persistent buffer data", + "Tensor variable buffer data", + "Node and registration array", + "Operation data", +#ifdef USE_TFLM_COMPRESSION + "Compression data", +#endif // USE_TFLM_COMPRESSION + }; static_assert(std::extent::value == kAllocationTypes, "kAllocationTypes mismatch"); const size_t total_bytes = diff --git a/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template b/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template index a2102a48e1c..8ec4e512f7a 100644 --- a/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template +++ b/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template @@ -20,6 +20,13 @@ limitations under the License. 
#include "tensorflow/lite/micro/micro_log.h" #include "tensorflow/lite/micro/tools/benchmarking/show_meta_data.h" +#ifndef XTENSA +#undef HIFI3 +#undef HIFI4 +#undef HIFI5 +#undef VISION_P6 +#endif // XTENSA + #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5) #include "NatureDSP_Signal_id.h" #include "xa_nnlib_standards.h" diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index a5a2286a898..e6912e91705 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -378,6 +378,8 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_common.cc \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc index 92527adce53..8a3ee7b6716 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc @@ -97,4 +97,10 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa) $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc + + # override KERNEL_OPTIMIZATION_LEVEL to enable higher performance + # Xtensa intrinsics. +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ endif