diff --git a/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h b/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h index ebd376466b6..69b049404c0 100644 --- a/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h +++ b/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h @@ -74,8 +74,6 @@ class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator { // takes in account any temporary allocations. size_t GetAvailableMemory(size_t alignment) const override; - TF_LITE_REMOVE_VIRTUAL_DELETE - private: // The memory arena that this allocator manages. uint8_t* const buffer_head_; @@ -97,6 +95,8 @@ class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator { // Count of outstanding temp buffers. int temp_buffer_count_ = 0; bool resizable_buffer_allocated_ = false; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; } // namespace tflite diff --git a/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h b/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h index 2c8e3dca53b..a86d425d7c6 100644 --- a/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h +++ b/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h @@ -39,7 +39,6 @@ class PersistentArenaBufferAllocator : public IPersistentBufferAllocator { // Returns the size of all persistent allocations in bytes. size_t GetPersistentUsedBytes() const override; - TF_LITE_REMOVE_VIRTUAL_DELETE private: // The memory arena that this allocator manages. uint8_t* const buffer_head_; @@ -51,6 +50,8 @@ class PersistentArenaBufferAllocator : public IPersistentBufferAllocator { // So in essence, the allocated region grows from the bottom and emulates // SingleArenaBufferAllocator's persistent part. uint8_t* tail_temp_; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; } // namespace tflite diff --git a/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h b/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h index a2e39588963..771c2deb436 100644 --- a/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h +++ b/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h @@ -110,8 +110,6 @@ class SingleArenaBufferAllocator : public INonPersistentBufferAllocator, // account any temporary allocations. size_t GetUsedBytes() const; - TF_LITE_REMOVE_VIRTUAL_DELETE - protected: // Returns a pointer to the current end of the head buffer. uint8_t* head() const; @@ -137,6 +135,8 @@ class SingleArenaBufferAllocator : public INonPersistentBufferAllocator, intptr_t temp_buffer_ptr_check_sum_ = 0; // Count of outstanding temp buffers. 
int temp_buffer_count_ = 0; + + TF_LITE_REMOVE_VIRTUAL_DELETE }; } // namespace tflite diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py index 2e58d8080f1..6a8afc8bac8 100644 --- a/tensorflow/lite/micro/compression/model_facade.py +++ b/tensorflow/lite/micro/compression/model_facade.py @@ -100,10 +100,37 @@ def __init__(self, operator, index, subgraph): def opcode(self) -> tflite.OperatorCodeT: return self.subgraph.model.operatorCodes[self.operator.opcodeIndex] + @property + def builtin_opcode(self) -> int: + result: int = self.opcode.deprecatedBuiltinCode + if result == tflite.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES: + result = self.opcode.builtinCode + return result + @property def inputs(self): return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) + @property + def outputs(self): + return _IndirectIterator(self.operator.outputs, self.subgraph.tensors) + + @property + def inputs_indices(self): + return self.operator.inputs + + @property + def outputs_indices(self): + return self.operator.outputs + + @property + def builtin_options_type(self) -> int: + return self.operator.builtinOptionsType + + @property + def builtin_options(self): + return self.operator.builtinOptions + _NP_DTYPES = { tflite.TensorType.FLOAT16: np.dtype("<f2"), diff --git a/tensorflow/lite/micro/kernels/assign_variable.cc b/tensorflow/lite/micro/kernels/assign_variable.cc --- a/tensorflow/lite/micro/kernels/assign_variable.cc +++ b/tensorflow/lite/micro/kernels/assign_variable.cc @@ ... @@ +#ifdef USE_TFLM_COMPRESSION + +struct OpData { + // scratch buffer index for the compressed input value tensor + int scratch_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpData)); +} + +#endif // USE_TFLM_COMPRESSION + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0); @@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, input_value)); } +#ifdef USE_TFLM_COMPRESSION + + TFLITE_DCHECK(node->user_data != nullptr); + OpData* data = static_cast<OpData*>(node->user_data); + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + data->scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kInputValue); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input_value); return kTfLiteOk; } @@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { "ResourceVariables and pass it to the interpreter."); return kTfLiteError; } + +#ifdef USE_TFLM_COMPRESSION + OpData* data = static_cast<OpData*>(node->user_data); + const CompressionTensorData* comp_td = + micro_context->GetTensorCompressionData(node, kInputValue); + const void* buffer = tflite::micro::GetTensorData<void>( + micro_context, input_value, comp_td, data->scratch_index); +#else // USE_TFLM_COMPRESSION + const void* buffer = tflite::micro::GetTensorData<void>(input_value); +#endif // USE_TFLM_COMPRESSION + TF_LITE_ENSURE_OK(context, - resources->Assign(input_id->data.i32[0], input_value)); + resources->Assign(input_id->data.i32[0], + EvalTensorBytes(input_value), buffer)); return kTfLiteOk; } } // namespace.
+#ifdef USE_TFLM_COMPRESSION + +TFLMRegistration Register_ASSIGN_VARIABLE() { + return tflite::micro::RegisterOp(Init, Prepare, Eval); + +#else // USE_TFLM_COMPRESSION + TFLMRegistration Register_ASSIGN_VARIABLE() { return tflite::micro::RegisterOp(nullptr, Prepare, Eval); + +#endif // USE_TFLM_COMPRESSION } } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc index 0df35fce4eb..3e4fb62318d 100644 --- a/tensorflow/lite/micro/kernels/conv.cc +++ b/tensorflow/lite/micro/kernels/conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,15 +45,35 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); const auto& data = *(static_cast<const OpDataConv*>(node->user_data)); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: { tflite::reference_ops::Conv( ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<float>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<float>(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<float>(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData<float>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<float>(output), tflite::micro::GetTensorShape(nullptr), nullptr); @@ -67,9 +87,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int16_t>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<int8_t>(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData<std::int32_t>( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<int8_t>(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData<std::int32_t>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int16_t>(output)); } else if (bias->type == kTfLiteInt64) { @@ -79,9 +108,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int16_t>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<int8_t>(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData<std::int64_t>( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<int8_t>(filter), tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int16_t>(output)); } else { @@ -119,9 +157,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int8_t>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<int8_t>(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData<int32_t>( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<int8_t>(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData<int32_t>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int8_t>(output)); break; diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h index 0c8073f48f0..0090053e03c 100644 --- a/tensorflow/lite/micro/kernels/conv.h +++ b/tensorflow/lite/micro/kernels/conv.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,6 +49,14 @@ struct OpDataConv { // A buffer used to store unpacked filter values. This is used if the source // tensor is of n-bit precision that cannot be easily processed by kernels. int filter_buffer_index; + +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int weights_scratch_index; + int bias_scratch_index; + +#endif // USE_TFLM_COMPRESSION }; extern const int kConvInputTensor; diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc index 51c7a6ff2d6..9f0f2f79588 100644 --- a/tensorflow/lite/micro/kernels/conv_common.cc +++ b/tensorflow/lite/micro/kernels/conv_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -209,6 +209,23 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) { &data->filter_buffer_index); } +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kConvWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, + kConvWeightsTensor); + data->bias_scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(filter); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(output); diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc index 0fb9411a3f0..48eddeb9958 100644 --- a/tensorflow/lite/micro/kernels/conv_test.cc +++ b/tensorflow/lite/micro/kernels/conv_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/conv_test.h" +#include <type_traits> + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" @@ -46,6 +48,90 @@ static int kOutputShape[] = {4, 2, 1, 2, 3}; static const float kGoldenData[kOutputElements] = {18, 2, 5, 18, 2, 5, 17, 4, 3, 37, 4, 3}; +#ifdef USE_TFLM_COMPRESSION + +// compressed filter data for kBinQuant scheme, matches kFilterData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterData[] = { + 0x05, 0x38, 0x20, 0x90, 0x00, +}; +constexpr float kBinQuantFilterValueTable[] = { + 1, 2, 3, 4, -1, +}; +constexpr size_t kBinQuantFilterValueTableElements = + std::extent<decltype(kBinQuantFilterValueTable)>::value; +constexpr int kBinQuantFilterBitWidth = 3; +// compressed bias data for kBinQuant scheme, matches kBiasData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18}; +constexpr int kBinQuantBiasBitWidth = 2; + +// Common inputs and outputs for quantized compressed tensor tests. +// Values from TfLite conv_test.cc SimplePerChannelTest. +static int kInputShapeQ1[] = {4, 1, 2, 3, 2}; +static const float kInputDataQ1[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ1 = std::extent<decltype(kInputDataQ1)>::value; + +constexpr int kNumChannelsQ1 = 2; +static int kFilterShapeQ1[] = {4, 2, 2, 2, 2}; +// Original filter data: +// static constexpr float kFilterDataQ1[] = { +// // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] +// 1, 2, // out channel = 0, y = 0, x = 0 +// 3, 4, // out channel = 0, y = 0, x = 1 +// 3, 4, // out channel = 0, y = 1, x = 0 +// 5, 6, // out channel = 0, y = 1, x = 1 +// 7, 8, // out channel = 1, y = 0, x = 0 +// 5, 6, // out channel = 1, y = 0, x = 1 +// 3, 4, // out channel = 1, y = 1, x = 0 +// 1, 2, // out channel = 1, y = 1, x = 1 +// }; + +static int kBiasShapeQ1[] = {1, 2}; +static const float kBiasDataQ1[] = {3, -2}; +constexpr size_t kBiasElementsQ1 = std::extent<decltype(kBiasDataQ1)>::value; + +static int kOutputShapeQ1[] = {4, 1, 1, 2, 2}; +static const float kGoldenDataQ1[] = {31, 64, -57, -46}; +constexpr int kOutputElementsQ1 = std::extent<decltype(kGoldenDataQ1)>::value; +static const float kGoldenDataQ1_16[] = {31, 63.99804688, -57, -46}; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = { + 0x05, 0x34, 0xE5, 0xDE, 0x54, 0xC1, +}; +constexpr float kBinQuantFilterValueTableQ1[] = { + 1, 2, 3, 4, 5, 6, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, +}; +constexpr size_t kBinQuantFilterValueTableElementsQ1 = + std::extent<decltype(kBinQuantFilterValueTableQ1)>::value; +constexpr int kBinQuantFilterBitWidthQ1 = 3; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +static TfLiteConvParams common_conv_params_q1 = { + kTfLitePaddingValid, // padding + 1, // stride_width + 1, // stride_height + kTfLiteActNone,
// activation + 1, // dilation_width_factor + 1, // dilation_height_factor + kTfLiteNoType // quantized_bias_type +}; + +#endif // USE_TFLM_COMPRESSION + static TfLiteConvParams common_conv_params = { kTfLitePaddingValid, // padding 2, // stride_width 2, // stride_height @@ -122,6 +208,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelCompressed) { + const float input_scale = 0.5f; + const float output_scale = 0.5f; + const int input_zero_point = -1; + const int output_zero_point = -1; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent<decltype(filter_scales)>::value] = {}; + int bias_zero_points[std::extent<decltype(filter_zero_points)>::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo<int32_t> bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestFloat) { float output_data[tflite::testing::kOutputElements]; @@ -136,6 +281,40 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) { + tflite::testing::TestCompressionInfo<const float> filter_comp_info = {}; + tflite::testing::TestCompressionInfo<const float> bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table =
tflite::testing::kBinQuantFilterValueTable; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElements; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::kBiasData; + bias_comp_info.value_table_stride = tflite::testing::kBiasElements; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + float output_data[tflite::testing::kOutputElements]; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvFloat( + tflite::testing::kInputShape, tflite::testing::kInputData, + tflite::testing::kFilterShape, + reinterpret_cast<const float*>(tflite::testing::kBinQuantFilterData), + tflite::testing::kBiasShape, + reinterpret_cast<const float*>(tflite::testing::kBinQuantBiasData), + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + &tflite::testing::common_conv_params, tflite::Register_CONV_2D(), + output_data, &filter_comp_info, &bias_comp_info)); +} + +#endif + TF_LITE_MICRO_TEST(InputAndFilterSameWidthHeight) { const int output_dims_count = 2; float output_data[output_dims_count]; @@ -246,6 +425,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBias) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBiasCompressed) { + const float input_scale = 128.0f / 65536; + const float output_scale = 128.0f / 65536; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent<decltype(filter_scales)>::value] = {}; + int bias_zero_points[std::extent<decltype(filter_zero_points)>::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int64_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo<int64_t> bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, +
tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBias) { const int output_dims_count = 12; int16_t output_data[output_dims_count]; @@ -276,6 +514,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBias) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBiasCompressed) { + const float input_scale = 128.0f / 65536; + const float output_scale = 128.0f / 65536; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent<decltype(filter_scales)>::value] = {}; + int bias_zero_points[std::extent<decltype(filter_zero_points)>::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo<int32_t> bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) { const int output_dims_count = 24; int8_t
output_data[output_dims_count]; diff --git a/tensorflow/lite/micro/kernels/conv_test.h b/tensorflow/lite/micro/kernels/conv_test.h index c655f043bcc..642f4c76d7a 100644 --- a/tensorflow/lite/micro/kernels/conv_test.h +++ b/tensorflow/lite/micro/kernels/conv_test.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/conv.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/test_helpers.h" @@ -26,35 +27,101 @@ limitations under the License. namespace tflite { namespace testing { -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data); +constexpr int kConvMaxTensors = 4; +constexpr int kConvMaxInputTensors = 3; +template <typename T> TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, int8_t* output_data); - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const float* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - float* output_data, float tolerance = 1e-5); - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const int8_t* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - int8_t* output_data, float tolerance = 1e-5); - -TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - int* output_dims_data, - const float* expected_output_data, - TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data); + int output_length, const TfLiteConvParams* conv_params, + TFLMRegistration registration, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* comp_list_p = nullptr +#endif // USE_TFLM_COMPRESSION +) { + // TODO(b/358165875): support optional bias tensor + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, conv_params +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); + + const char* init_data = reinterpret_cast<const char*>(conv_params); + TfLiteStatus status = runner.InitAndPrepare(init_data); + if (status != kTfLiteOk) { + return status; + } + return runner.Invoke(); +} + +template <typename T, typename TF = void, typename TB = void> +TfLiteStatus ValidateConvGoldens( + TfLiteTensor* tensors, int tensors_size, const T* expected_output_data, + int output_length, const TfLiteConvParams* conv_params, + TFLMRegistration registration, T* output_data, float tolerance = 1e-5 +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo<TF>* filter_comp_info = nullptr, + const TestCompressionInfo<TB>* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +)
{ +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList<kConvMaxInputTensors> tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, tensors[kConvWeightsTensor], + kConvWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kConvBiasTensor], + kConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + + TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length, + conv_params, registration, output_data +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); + if (status != kTfLiteOk) { + return status; + } + for (int i = 0; i < output_length; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], + tolerance); + } + return kTfLiteOk; +} + +TfLiteStatus TestConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + int* output_dims_data, const float* expected_output_data, + TfLiteConvParams* conv_params, TFLMRegistration registration, + float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo<const float>* filter_comp_info = nullptr, + const TestCompressionInfo<const float>* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +); TfLiteStatus TestConvQuantizedPerChannel( int* input_dims_data, const float* input_data, int8_t* input_quantized, @@ -88,6 +155,74 @@ TfLiteStatus TestConvQuantizedPerChannel( float output_scale, int output_zero_point, TfLiteConvParams* conv_params, TFLMRegistration registration, int16_t* output_data); +#ifdef USE_TFLM_COMPRESSION + +template <typename TIO, typename TBIAS> +TfLiteStatus TestConvQuantizedPerChannelCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, TFLMRegistration registration, + const TestCompressionQuantizedInfo<int8_t>* filter_comp_info, + const TestCompressionQuantizedInfo<TBIAS>* bias_comp_info) { + // TODO(b/358165875): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ?
true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, kConvQuantizedDimension, + false /* is_variable */, typeToTfLiteType<TBIAS>()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + constexpr int tensors_size = kConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + filter_tensor, + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateConvGoldens(tensors, tensors_size, expected_output_quantized, + output_dims_count, conv_params, registration, + output_quantized, 1.0e-5f /* tolerance */, + filter_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc index a0f733b8e42..3825e05373c 100644 --- a/tensorflow/lite/micro/kernels/conv_test_common.cc +++ b/tensorflow/lite/micro/kernels/conv_test_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,88 +18,18 @@ limitations under the License.
namespace tflite { namespace testing { -template <typename T> -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, T* output_data) { - int inputs_array_data[] = {3, 0, 1, 2}; - TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); - int outputs_array_data[] = {1, 3}; - TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - - micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, conv_params); - - const char* init_data = reinterpret_cast<const char*>(conv_params); - TfLiteStatus status = runner.InitAndPrepare(init_data); - if (status != kTfLiteOk) { - return status; - } - return runner.Invoke(); -} - -template <typename T> -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const T* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, T* output_data, - float tolerance) { - TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length, - conv_params, registration, output_data); - if (status != kTfLiteOk) { - return status; - } - for (int i = 0; i < output_length; ++i) { - TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], - tolerance); - } - return kTfLiteOk; -} - -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data) { - return InvokeConv<float>(tensors, tensors_size, output_length, conv_params, - registration, output_data); -} - -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, int8_t* output_data) { - return InvokeConv<int8_t>(tensors, tensors_size, output_length, conv_params, - registration, output_data); -} - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const float* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - float* output_data, float tolerance) { - return ValidateConvGoldens<float>(tensors, tensors_size, expected_output_data, - output_length, conv_params, registration, - output_data, tolerance); -} - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const int8_t* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - int8_t* output_data, float tolerance) { - return ValidateConvGoldens<int8_t>( - tensors, tensors_size, expected_output_data, output_length, conv_params, - registration, output_data, tolerance); -} - -TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - int* output_dims_data, - const float* expected_output_data, - TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data) { +TfLiteStatus TestConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + int* output_dims_data, const float* expected_output_data, + TfLiteConvParams* conv_params, TFLMRegistration registration, + float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo<const float>* filter_comp_info, + const TestCompressionInfo<const float>* bias_comp_info +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims =
IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -117,7 +47,12 @@ TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, return ValidateConvGoldens(tensors, tensors_size, expected_output_data, output_dims_count, conv_params, registration, - output_data); + output_data +#ifdef USE_TFLM_COMPRESSION + , + 1e-5f, filter_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } template <typename T> diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.cc b/tensorflow/lite/micro/kernels/depthwise_conv.cc index fa55a705606..4d6cb4c4979 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,6 +52,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: { tflite::reference_ops::DepthwiseConv( @@ -59,9 +71,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<float>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<float>(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<float>(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData<float>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<float>(output)); break; @@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int8_t>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<int8_t>(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData<int32_t>( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<int8_t>(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData<int32_t>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int8_t>(output)); break; @@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int16_t>(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData<int8_t>(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), +
tflite::micro::GetTensorData<std::int64_t>( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData<int8_t>(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData<std::int64_t>(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int16_t>(output)); break; diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc index 52804de3315..0813d2b028e 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv( micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); - micro_context->DeallocateTempTfLiteTensor(bias); + if (has_bias) { + micro_context->DeallocateTempTfLiteTensor(bias); + } micro_context->DeallocateTempTfLiteTensor(output); return kTfLiteOk; @@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) { context, node, params, input_width, input_height, filter_width, filter_height, output_width, output_height, input->type, data)); +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kDepthwiseConvWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc index b50b40ae6d6..adedcaeb04e 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc @@ -1,5 +1,5 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <type_traits> + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -32,17 +34,99 @@ constexpr int kOutputTensorIndex = 3; constexpr int kMaxFilterChannels = 64; constexpr int kMaxBiasChannels = 64; +#ifdef USE_TFLM_COMPRESSION + +constexpr size_t kDepthwiseConvMaxTensors = 4; +constexpr size_t kDepthwiseConvMaxInputTensors = 3; + +// Common inputs and outputs (quantized multi channel).
+// data from TfLite test: +// PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift +static int kInputShapeQ1[] = {4, 1, 2, 3, 2}; +static constexpr float kInputDataQ1[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ1 = std::extent<decltype(kInputDataQ1)>::value; + +constexpr int kNumChannelsQ1 = 4; +static int kFilterShapeQ1[] = {4, 1, 2, 2, 4}; +static constexpr float kFilterDataQ1[] = { + // This is a compact value table. Original data is: + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + // depth multiplier = 2 + // 1, 2, 3, 4, y = 0, x = 0 + // 3, 4, 5, 6, y = 0, x = 1 + // 7, 8, 5, 6, y = 1, x = 0 + // 3, 4, 1, 2, y = 1, x = 1 + 1, 3, 7, 8, 2, 4, 1, 3, 5, 2, 4, 6, +}; +constexpr size_t kFilterElementsQ1 = + std::extent<decltype(kFilterDataQ1)>::value; + +static int kBiasShapeQ1[] = {1, 4}; +static constexpr float kBiasDataQ1[] = {3, -2, 4, 6}; +constexpr size_t kBiasElementsQ1 = std::extent<decltype(kBiasDataQ1)>::value; + +static int kOutputShapeQ1[] = {4, 1, 1, 2, 4}; +static constexpr float kGoldenDataQ1[] = {43, 48, 21, 22, 3, -4, -30, -36}; +constexpr int kOutputElementsQ1 = std::extent<decltype(kGoldenDataQ1)>::value; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = {0x15, 0x6A, 0x8A, + 0x60}; +constexpr int kBinQuantFilterBitWidthQ1 = 2; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +#endif // USE_TFLM_COMPRESSION + // Creates a DepthwiseConv opeerator, calls it with the provided input tensors // and some defaults parameters, and compares the output with // expected_output_data. // // The tensors parameter contains both the input tensors as well as a // preallocated output tensor into which the output is stored.
-template <typename T> +template <typename T, typename TF = void, typename TB = void> TfLiteStatus ValidateDepthwiseConvGoldens( const T* expected_output_data, int output_length, TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size, - TfLiteTensor* tensors) { + TfLiteTensor* tensors +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo<TF>* filter_comp_info = nullptr, + const TestCompressionInfo<TB>* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList<kDepthwiseConvMaxInputTensors> tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, tensors[kDepthwiseConvWeightsTensor], + kDepthwiseConvWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kDepthwiseConvBiasTensor], + kDepthwiseConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + int inputs_array_data[] = {3, 0, 1, 2}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 3}; @@ -50,8 +134,12 @@ TfLiteStatus ValidateDepthwiseConvGoldens( const TFLMRegistration registration = Register_DEPTHWISE_CONV_2D(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, - reinterpret_cast<void*>(conv_params)); + outputs_array, reinterpret_cast<void*>(conv_params) +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); int input_depth = tensors[0].dims->data[3]; int output_depth = tensors[1].dims->data[3]; @@ -183,18 +271,93 @@ void TestDepthwiseConvQuantizedPerChannel( output_scale, output_zero_point, conv_params, filter_packed_type); } +#ifdef USE_TFLM_COMPRESSION + +template <typename TIO, typename TBIAS> +TfLiteStatus TestDepthwiseConvQuantizedCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + TfLiteDepthwiseConvParams* conv_params, const unsigned int tolerance, + const TestCompressionQuantizedInfo<int8_t>* filter_comp_info, + const TestCompressionQuantizedInfo<TBIAS>* bias_comp_info) { + // TODO(b/360169306): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ? true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kDepthwiseConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + // Value tables are always in channel order, therefore do not use the + // quantized dimension.
+ SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data, 0 /* see comment above */); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, + 0 /* quantized dimension for bias tensor */, false /* is_variable */, + typeToTfLiteType<TBIAS>()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + constexpr int tensors_size = kDepthwiseConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + filter_tensor, + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateDepthwiseConvGoldens( + expected_output_quantized, output_dims_count, conv_params, tolerance, + tensors_size, tensors, filter_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + +// TODO(ddavis-2015): is this still valid? // Xtensa kernels do not support float activations, and the corresponding tests // are disabled. As a result, helper functions that are only needed for float // kernel tests also need to be ifdef'd out to avoid build errors due to unused // functions. #if !defined(XTENSA) -void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - const float* expected_output_data, - int* output_dims_data, - TfLiteDepthwiseConvParams* conv_params, - float* output_data) { +void TestDepthwiseConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + const float* expected_output_data, int* output_dims_data, + TfLiteDepthwiseConvParams* conv_params, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo<const float>* filter_comp_info = nullptr, + const TestCompressionInfo<const float>* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -212,7 +375,12 @@ void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data, }; ValidateDepthwiseConvGoldens(expected_output_data, output_dims_count, - conv_params, 1e-5, tensors_size, tensors); + conv_params, 1e-5, tensors_size, tensors +#ifdef USE_TFLM_COMPRESSION + , + filter_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } #endif // !defined(XTENSA) @@ -253,6 +421,60 @@ TF_LITE_MICRO_TEST(SimpleTest) { bias_values, golden, output_shape, &conv_params, output_data); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestCompressed) { + int input_shape[] = {4, 1, 3, 2, 2}; + const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}; + int filter_shape[] = {4, 1, 2, 2, 4}; + // Filter values: + // {1, 2, 3, 4, -9, 10, -11, 12, 5, 6, 7, 8, 13, -14, 15, -16}
// Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t kBinQuantFilterData[] = {0x01, 0x23, 0xF8, 0xE9, + 0x45, 0x67, 0xAD, 0xBC}; + const float kBinQuantFilterValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8, + 10, 12, 13, 15, -16, -14, -11, -9}; + int bias_shape[] = {4, 1, 1, 1, 4}; + const float bias_values[] = {1, 2, 3, 4}; + // Align the tensor data the same as a Buffer in the schema + alignas(16) const uint8_t kBinQuantBiasData[] = {0x1B}; + const float golden[] = { + 71, -34, 99, -20, 91, -26, 127, -4, + }; + int output_shape[] = {4, 1, 2, 1, 4}; + const int output_dims_count = std::extent<decltype(golden)>::value; + float output_data[output_dims_count]; + + tflite::testing::TestCompressionInfo<const float> filter_comp_info = {}; + tflite::testing::TestCompressionInfo<const float> bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = kBinQuantFilterValueTable; + filter_comp_info.value_table_stride = + std::extent<decltype(kBinQuantFilterValueTable)>::value; + filter_comp_info.bit_width = 4; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_values; + bias_comp_info.value_table_stride = std::extent<decltype(bias_values)>::value; + bias_comp_info.bit_width = 2; + + TfLiteDepthwiseConvParams conv_params; + conv_params.activation = kTfLiteActNone; + conv_params.dilation_width_factor = 1; + conv_params.dilation_height_factor = 1; + conv_params.stride_height = 1; + conv_params.stride_width = 1; + + tflite::testing::TestDepthwiseConvFloat( + input_shape, input_values, filter_shape, + reinterpret_cast<const float*>(kBinQuantFilterData), bias_shape, + reinterpret_cast<const float*>(kBinQuantBiasData), golden, output_shape, + &conv_params, output_data, &filter_comp_info, &bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestRelu) { int input_shape[] = {4, 1, 3, 2, 2}; const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}; @@ -1068,4 +1290,144 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16InputInt8Filter) { bias_quantized, output_shape, golden, golden_quantized, output_data, output_scale, output_zero_point, &conv_params); } + +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt8Compressed) { + // data from TfLite test: + // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift + const float input_scale = 0.5f; + const float output_scale = 0.5f; + const int input_zero_point = -1; + const int output_zero_point = -1; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ1, 0, 0, 0, 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent<decltype(filter_scales)>::value] = {}; + int bias_zero_points[std::extent<decltype(filter_zero_points)>::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kFilterElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo<int32_t> bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width =
+      tflite::testing::kBinQuantFilterBitWidthQ1;
+  filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
+  filter_comp_info.data = tflite::testing::kFilterDataQ1;
+  filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
+  filter_comp_info.scales = filter_scales;
+  filter_comp_info.zero_points = filter_zero_points;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride =
+      tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
+  bias_comp_info.data = tflite::testing::kBiasDataQ1;
+  bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
+  bias_comp_info.scales = bias_scales;
+  bias_comp_info.zero_points = bias_zero_points;
+
+  TfLiteDepthwiseConvParams conv_params = {};
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  // A tolerance of 3 quantized steps is approx. 1.5f at an output scale of
+  // 0.5f.
+  // TODO(ddavis-2015): why does the tolerance differ from TfLite test???
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      tflite::testing::TestDepthwiseConvQuantizedCompressed(
+          tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1,
+          input_quantized, input_scale, input_zero_point,
+          tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1,
+          golden_quantized, output_quantized, output_scale, output_zero_point,
+          &conv_params, 3, &filter_comp_info, &bias_comp_info));
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16Compressed) {
+  // data from TfLite test:
+  // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+  const float input_scale =
+      tflite::testing::SymmetricScaleFromMinMax<int16_t>(-4.0f, 4.0f);
+  const float output_scale =
+      tflite::testing::SymmetricScaleFromMinMax<int16_t>(-63.5f, 64.0f);
+  const int input_zero_point = 0;
+  const int output_zero_point = 0;
+  constexpr float filter_scales[] = {
+      tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f,
+  };
+  constexpr int filter_zero_points[] = {
+      tflite::testing::kNumChannelsQ1, 0, 0, 0, 0,
+  };
+  // bias scales and zero points will be computed
+  float bias_scales[std::extent<decltype(filter_scales)>::value] = {};
+  int bias_zero_points[std::extent<decltype(filter_scales)>::value] = {};
+
+  int16_t input_quantized[tflite::testing::kInputElementsQ1];
+  int8_t filter_quantized[tflite::testing::kFilterElementsQ1];
+  int64_t bias_quantized[tflite::testing::kBiasElementsQ1];
+  int16_t golden_quantized[tflite::testing::kOutputElementsQ1];
+  int16_t output_quantized[tflite::testing::kOutputElementsQ1];
+
+  tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = filter_quantized;
+  filter_comp_info.value_table_stride =
+      tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1;
+  filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
+  filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
+  filter_comp_info.data = tflite::testing::kFilterDataQ1;
+  filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
+  filter_comp_info.scales = filter_scales;
+  filter_comp_info.zero_points = filter_zero_points;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
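+  // bias_quantized (used as the value table) is populated from the float
+  // bias data by SymmetricPerChannelQuantize() inside the test helper, so
+  // it is intentionally left uninitialized here.
+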
bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TfLiteDepthwiseConvParams conv_params = {}; + conv_params.activation = kTfLiteActNone; + conv_params.dilation_width_factor = 1; + conv_params.dilation_height_factor = 1; + conv_params.stride_height = 1; + conv_params.stride_width = 1; + + // tolerance of 512 is approx. 1.0f + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestDepthwiseConvQuantizedCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &conv_params, 512, &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/transpose_conv.cc b/tensorflow/lite/micro/kernels/transpose_conv.cc index ea0efae0607..7d65dc3de7c 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv.cc +++ b/tensorflow/lite/micro/kernels/transpose_conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,30 +27,26 @@ limitations under the License. #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/transpose_conv.h" #include "tensorflow/lite/micro/micro_log.h" namespace tflite { namespace { -// For the TfLite transpose_conv implementation, input tensor 0 corresponds to -// the OutputShapeTensor. However, since TFLM does not support dynamic tensors, -// the TFLM implementation ignores input tensor 0 and the only inputs we care -// about are kFilterTensor, kInputTensor and kBiasTensor. -constexpr int kFilterTensor = 1; -constexpr int kInputTensor = 2; -constexpr int kBiasTensor = 3; -constexpr int kOutputTensor = 0; - -// Conv is quantized along dimension 0: -// https://www.tensorflow.org/lite/performance/quantization_spec -constexpr int kConvQuantizedDimension = 0; - struct OpData { ConvParams params; // A scratch buffer is required for quantized implementations. int scratch_buffer_index; +#ifdef USE_TFLM_COMPRESSION + + // scratch buffers for compressed tensors + int filter_scratch_index; + int bias_scratch_index; + +#endif // USE_TFLM_COMPRESSION + // Index to the converted 64-bit bias buffer from 16-bit bias. This is // required to handle 16x8 transpose convolutions where a 16-bit bias is // provided, whereas the kernel expects 64-bit biases. 
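The remaining hunks all apply the same two-step pattern: Prepare() reserves a per-tensor decompression scratch area, and Eval() resolves tensor data through a helper that decompresses into that area only when the tensor is actually compressed. A minimal sketch of the Eval() side, using the helper names from this patch (the uncompressed fallback is inferred from the kernels calling the helper unconditionally):

    const CompressionTensorData* filter_comp_td =
        micro_context->GetTensorCompressionData(node,
                                                kTransposeConvFilterTensor);
    // When filter_comp_td != nullptr this decompresses into the scratch area
    // reserved in Prepare(); otherwise it resolves to the tensor's own buffer.
    const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
        micro_context, filter, filter_comp_td, data.filter_scratch_index);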
@@ -102,17 +98,17 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, MicroContext* micro_context = GetMicroContext(context); TfLiteTensor* input = - micro_context->AllocateTempInputTensor(node, kInputTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor); TF_LITE_ENSURE(context, input != nullptr); - TfLiteTensor* filter = - micro_context->AllocateTempInputTensor(node, kFilterTensor); + TfLiteTensor* filter = micro_context->AllocateTempInputTensor( + node, kTransposeConvFilterTensor); TF_LITE_ENSURE(context, filter != nullptr); TfLiteTensor* bias = - micro_context->AllocateTempInputTensor(node, kBiasTensor); - TfLiteTensor* output = - micro_context->AllocateTempOutputTensor(node, kOutputTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvBiasTensor); + TfLiteTensor* output = micro_context->AllocateTempOutputTensor( + node, kTransposeConvOutputTensor); TF_LITE_ENSURE(context, output != nullptr); - int output_channels = filter->dims->data[kConvQuantizedDimension]; + int output_channels = filter->dims->data[kTransposeConvQuantizedDimension]; TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( context, input, filter, bias, output, kTfLiteActNone, @@ -164,13 +160,13 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { MicroContext* micro_context = GetMicroContext(context); TfLiteTensor* output = - micro_context->AllocateTempOutputTensor(node, kOutputTensor); + micro_context->AllocateTempOutputTensor(node, kTransposeConvOutputTensor); TF_LITE_ENSURE(context, output != nullptr); TfLiteTensor* input = - micro_context->AllocateTempInputTensor(node, kInputTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvInputTensor); TF_LITE_ENSURE(context, input != nullptr); TfLiteTensor* filter = - micro_context->AllocateTempInputTensor(node, kFilterTensor); + micro_context->AllocateTempInputTensor(node, kTransposeConvFilterTensor); TF_LITE_ENSURE(context, filter != nullptr); TF_LITE_ENSURE_MSG( @@ -186,7 +182,7 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { const int filter_height = SizeOfDimension(filter, 1); // Dynamically allocate per-channel quantization parameters. - const int num_channels = filter->dims->data[kConvQuantizedDimension]; + const int num_channels = filter->dims->data[kTransposeConvQuantizedDimension]; data->per_channel_output_multiplier = static_cast(context->AllocatePersistentBuffer( context, num_channels * sizeof(int32_t))); @@ -223,10 +219,10 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, affine_quantization->scale); TF_LITE_ENSURE(context, affine_quantization->zero_point); - TF_LITE_ENSURE(context, - affine_quantization->scale->size == 1 || - affine_quantization->scale->size == - filter->dims->data[kConvQuantizedDimension]); + TF_LITE_ENSURE( + context, affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kTransposeConvQuantizedDimension]); TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, affine_quantization->zero_point->size); } @@ -244,6 +240,18 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { data->params.stride_width = params->stride_width; data->params.stride_height = params->stride_height; +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. 
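+  // (These indices are the ones later passed to
+  // tflite::micro::GetTensorData() in TransposeConvEval() to locate the
+  // decompression scratch area.)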
+ data->filter_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kTransposeConvFilterTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kTransposeConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); @@ -252,15 +260,26 @@ TfLiteStatus TransposeConvPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); + tflite::micro::GetEvalInput(context, node, kTransposeConvInputTensor); const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kFilterTensor); + tflite::micro::GetEvalInput(context, node, kTransposeConvFilterTensor); const TfLiteEvalTensor* bias = (NumInputs(node) == 4) - ? tflite::micro::GetEvalInput(context, node, kBiasTensor) + ? tflite::micro::GetEvalInput(context, node, kTransposeConvBiasTensor) : nullptr; TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); + tflite::micro::GetEvalOutput(context, node, kTransposeConvOutputTensor); + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, kTransposeConvFilterTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kTransposeConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); @@ -280,9 +299,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { op_params, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); @@ -296,9 +323,17 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, filter_comp_td, data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); @@ -311,16 +346,29 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { auto* bias_converted_buffer = static_cast(context->GetScratchBuffer( 
context, data.bias_converted_buffer_index)); + const int16_t* const bias_int16_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION + static_cast(bias->data.data); +#endif // USE_TFLM_COMPRESSION for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize(); i++) { - bias_converted_buffer[i] = bias->data.i16[i]; + bias_converted_buffer[i] = bias_int16_data[i]; } reference_integer_ops::TransposeConv( data.params, data.per_channel_output_multiplier, data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(bias), bias_converted_buffer, tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), @@ -331,9 +379,18 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { data.per_channel_output_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.filter_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); diff --git a/tensorflow/lite/micro/kernels/transpose_conv.h b/tensorflow/lite/micro/kernels/transpose_conv.h index 3a99ccbf847..ec0416e067f 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv.h +++ b/tensorflow/lite/micro/kernels/transpose_conv.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,19 @@ limitations under the License. namespace tflite { +// For the TfLite transpose_conv implementation, input tensor 0 corresponds to +// the OutputShapeTensor. However, since TFLM does not support dynamic tensors, +// the TFLM implementation ignores input tensor 0 and the only inputs we care +// about are kFilterTensor, kInputTensor and kBiasTensor. +constexpr int kTransposeConvFilterTensor = 1; +constexpr int kTransposeConvInputTensor = 2; +constexpr int kTransposeConvBiasTensor = 3; +constexpr int kTransposeConvOutputTensor = 0; + +// Conv is quantized along dimension 0: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kTransposeConvQuantizedDimension = 0; + // This is the most generic TFLMRegistration. The actual supported types // may still be target dependent. The only requirement is that every // implementation (reference or optimized) must define this function. 
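Hoisting the tensor-index constants into transpose_conv.h lets the reference kernel, optimized variants, and the test file below share one definition instead of keeping private copies. Kernels then address tensors symbolically, as in this condensed excerpt from the hunks above:

    // Slot 0 of a TRANSPOSE_CONV node holds the output-shape tensor, which
    // TFLM ignores; the meaningful inputs are named by the constants.
    const TfLiteEvalTensor* filter =
        tflite::micro::GetEvalInput(context, node, kTransposeConvFilterTensor);
    const TfLiteEvalTensor* bias =
        (NumInputs(node) == 4)
            ? tflite::micro::GetEvalInput(context, node,
                                          kTransposeConvBiasTensor)
            : nullptr;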
diff --git a/tensorflow/lite/micro/kernels/transpose_conv_test.cc b/tensorflow/lite/micro/kernels/transpose_conv_test.cc index 49d2c90f439..2d5f3a0ba4e 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/micro/kernels/transpose_conv_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/micro/kernels/transpose_conv.h" + +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/micro/kernels/conv_test.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" @@ -47,20 +50,127 @@ static const float kGoldenData[kOutputElements] = { 184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, 2715, 3057, 2442, 1968, 3352, 3652, 2760}; +#ifdef USE_TFLM_COMPRESSION + +constexpr size_t kTransposeConvMaxTensors = 5; +constexpr size_t kTransposeConvMaxInputTensors = 4; + +// compressed filter data for kBinQuant scheme, matches kFilterData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterData[] = { + 0x00, 0x44, 0x32, 0x14, 0xC7, 0x42, 0x54, 0xB6, 0x35, 0xCF, 0x84, 0x40}; +constexpr int kBinQuantFilterBitWidth = 5; +// compressed bias data for kBinQuant scheme, matches kBiasData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x00}; +constexpr int kBinQuantBiasBitWidth = 1; + +// Common inputs and outputs (quantized single channel). +// data from TfLite test: SimpleBiasTestQuantizedPerChannelSingleChannel +static int kInputShapeQ1[] = {4, 1, 4, 4, 1}; +static constexpr float kInputDataQ1[] = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; +constexpr size_t kInputElementsQ1 = std::extent::value; + +constexpr int kNumChannelsQ1 = 1; +static int kFilterShapeQ1[] = {4, 1, 3, 3, 1}; +static constexpr float kFilterDataQ1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; +constexpr size_t kFilterElementsQ1 = + std::extent::value; + +static int kBiasShapeQ1[] = {1, 1}; +static constexpr float kBiasDataQ1[] = {1}; +constexpr size_t kBiasElementsQ1 = std::extent::value; + +static int kOutputShapeQ1[] = {4, 1, 4, 4, 1}; +static constexpr float kGoldenDataQ1[] = { + 30, 62, 84, 76, 100, 194, 238, 200, 208, 372, 418, 330, 264, 446, 486, 366}; +constexpr int kOutputElementsQ1 = std::extent::value; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = {0x01, 0x23, 0x45, 0x67, + 0x80}; +constexpr int kBinQuantFilterBitWidthQ1 = 4; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +// Common inputs and outputs (quantized multi channel). 
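+// Note: the per-channel value tables below are padded to a common stride;
+// kBinQuantFilterValueTableQ2 holds six distinct values plus two zero pads
+// for output channel 0, and eight values for output channel 1.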
+// data from TfLite test: SimpleBiasTestQuantizedPerChannel16x8Bias64 +static int kInputShapeQ2[] = {4, 1, 2, 3, 2}; +static constexpr float kInputDataQ2[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ2 = std::extent::value; + +constexpr int kNumChannelsQ2 = 2; +static int kFilterShapeQ2[] = {4, 2, 2, 2, 2}; +// Original filter data: +// static constexpr float kFilterDataQ2[] = { +// // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] +// 1, 2, // out channel = 0, y = 0, x = 0 +// 3, 4, // out channel = 0, y = 0, x = 1 +// 3, 4, // out channel = 0, y = 1, x = 0 +// 5, 6, // out channel = 0, y = 1, x = 1 +// 7, 8, // out channel = 1, y = 0, x = 0 +// 5, 6, // out channel = 1, y = 0, x = 1 +// 3, 4, // out channel = 1, y = 1, x = 0 +// 1, 2, // out channel = 1, y = 1, x = 1 +// }; + +static int kBiasShapeQ2[] = {1, 2}; +static constexpr float kBiasDataQ2[] = {3, -2}; +constexpr size_t kBiasElementsQ2 = std::extent::value; + +static int kOutputShapeQ2[] = {4, 1, 2, 3, 2}; +static constexpr float kGoldenDataQ2[] = {10, 35, 19, 24, -6, -41, + 30, 64, 51, 40, -29, -64}; +constexpr int kOutputElementsQ2 = std::extent::value; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ2 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ2[] = {0x05, 0x34, 0xE5, + 0xDE, 0x54, 0xC1}; +constexpr float kBinQuantFilterValueTableQ2[] = {1, 2, 3, 4, 5, 6, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8}; +constexpr size_t kBinQuantFilterValueTableElementsQ2 = + std::extent::value; +constexpr int kBinQuantFilterBitWidthQ2 = 3; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ2 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ2[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ2 = 1; + +#endif // USE_TFLM_COMPRESSION + // Transpose conv uses TfLiteConvParams. 
-static TfLiteConvParams common_conv_params = {kTfLitePaddingSame, // padding - 1, // stride_width - 1, // stride_height - kTfLiteActNone, - 1, - 1, - kTfLiteNoType}; +static const TfLiteConvParams common_conv_params = { + kTfLitePaddingSame, // padding + 1, // stride_width + 1, // stride_height + kTfLiteActNone, + 1, + 1, + kTfLiteNoType}; template -TfLiteStatus InvokeTransposeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, - TfLiteConvParams* conv_params, - T* output_data) { +TfLiteStatus InvokeTransposeConv( + TfLiteTensor* tensors, int tensors_size, int output_length, + const TfLiteConvParams* conv_params, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* comp_list_p = nullptr +#endif // USE_TFLM_COMPRESSION +) { + // TODO(b/358151309): support optional bias tensor int inputs_array_data[] = {4, 0, 1, 2, 3}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 4}; @@ -68,7 +178,12 @@ TfLiteStatus InvokeTransposeConv(TfLiteTensor* tensors, int tensors_size, const TFLMRegistration registration = tflite::Register_TRANSPOSE_CONV(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, conv_params); + outputs_array, conv_params +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); const char* init_data = reinterpret_cast(conv_params); TfLiteStatus status = runner.InitAndPrepare(init_data); @@ -78,15 +193,45 @@ TfLiteStatus InvokeTransposeConv(TfLiteTensor* tensors, int tensors_size, return runner.Invoke(); } -template -TfLiteStatus ValidateTransposeConvGoldens(TfLiteTensor* tensors, - int tensors_size, - const T* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - T* output_data, float tolerance) { +template +TfLiteStatus ValidateTransposeConvGoldens( + TfLiteTensor* tensors, int tensors_size, const T* expected_output_data, + int output_length, const TfLiteConvParams* conv_params, T* output_data, + float tolerance = 1e-5f +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, tensors[kTransposeConvFilterTensor], + kTransposeConvFilterTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kTransposeConvBiasTensor], + kTransposeConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + TfLiteStatus status = InvokeTransposeConv( - tensors, tensors_size, output_length, conv_params, output_data); + tensors, tensors_size, output_length, conv_params, output_data +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); if (status != kTfLiteOk) { return status; } @@ -97,11 +242,18 @@ TfLiteStatus ValidateTransposeConvGoldens(TfLiteTensor* tensors, return kTfLiteOk; } +template TfLiteStatus TestTransposeConvFloat( int* input_dims_data, const float* input_data, int* filter_dims_data, const float* filter_data, int* bias_dims_data, const float* bias_data, int* output_dims_data, const float* expected_output_data, - TfLiteConvParams* conv_params, float* output_data) { + const 
TfLiteConvParams* conv_params, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -125,7 +277,12 @@ TfLiteStatus TestTransposeConvFloat( return ValidateTransposeConvGoldens(tensors, tensors_size, expected_output_data, output_dims_count, - conv_params, output_data, 0.001f); + conv_params, output_data +#ifdef USE_TFLM_COMPRESSION + , + 1e-5, filter_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } TfLiteStatus TestTransposeConvQuantized( @@ -135,8 +292,8 @@ TfLiteStatus TestTransposeConvQuantized( int* bias_dims_data, const float* bias_data, int32_t* bias_quantized, float* bias_scales, int* bias_zero_points, int* output_dims_data, const float* expected_output_data, int8_t* expected_output_quantized, - float output_scale, int output_zero_point, TfLiteConvParams* conv_params, - int8_t* output_data) { + float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, int8_t* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -181,8 +338,8 @@ TfLiteStatus TestTransposeConvQuantized( int* bias_dims_data, const float* bias_data, T* bias_quantized, float* bias_scales, int* bias_zero_points, int* output_dims_data, const float* expected_output_data, int16_t* expected_output_quantized, - float output_scale, int output_zero_point, TfLiteConvParams* conv_params, - int16_t* output_data) { + float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, int16_t* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -221,6 +378,80 @@ TfLiteStatus TestTransposeConvQuantized( conv_params, output_data, 4.0f); } +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestTransposeConvQuantizedCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, const unsigned int tolerance, + const TestCompressionQuantizedInfo* filter_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + // TODO(b/358151309): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ? 
true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kTransposeConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, + kTransposeConvQuantizedDimension, false /* is_variable */, + typeToTfLiteType()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + int output_shape_dims_data[] = {1, 0}; + int32_t* output_shape = nullptr; + TfLiteIntArray* output_shape_dims = IntArrayFromInts(output_shape_dims_data); + + constexpr int tensors_size = kTransposeConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(output_shape, output_shape_dims), + filter_tensor, + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateTransposeConvGoldens( + tensors, tensors_size, expected_output_quantized, output_dims_count, + conv_params, output_quantized, tolerance, filter_comp_info, + bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace } // namespace testing } // namespace tflite @@ -240,6 +471,41 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) { &tflite::testing::common_conv_params, output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) { + tflite::testing::TestCompressionInfo filter_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = tflite::testing::kFilterData; + filter_comp_info.value_table_stride = + std::extent::value; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::kBiasData; + bias_comp_info.value_table_stride = + std::extent::value; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + float output_data[tflite::testing::kOutputElements]; + + TF_LITE_MICRO_EXPECT_EQ( + 
kTfLiteOk, + tflite::testing::TestTransposeConvFloat( + tflite::testing::kInputShape, tflite::testing::kInputData, + tflite::testing::kFilterShape, + reinterpret_cast(tflite::testing::kBinQuantFilterData), + tflite::testing::kBiasShape, + reinterpret_cast(tflite::testing::kBinQuantBiasData), + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + &tflite::testing::common_conv_params, output_data, &filter_comp_info, + &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(fusedRELUTest) { float output_data[tflite::testing::kOutputElements]; float golden_data[] = {29, 24, 0, 0, 99, 72, 0, 0, @@ -476,4 +742,202 @@ TF_LITE_MICRO_TEST(HybridModeIsError) { &tflite::testing::common_conv_params, output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelSingleChannelCompressed) { + // data from TfLite test: SimpleBiasTestQuantizedPerChannelSingleChannel + const float input_scale = 16.0f / 255.0f; + const float output_scale = 2.0f; + const int input_zero_point = -128; + const int output_zero_point = -128; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ1, + 9.0f / 127.0f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ1, + 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kFilterElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kFilterDataQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestTransposeConvQuantizedCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params, 0, &filter_comp_info, + &bias_comp_info)); +} + +TF_LITE_MICRO_TEST( + SimpleBiasTestQuantizedPerChannelBias16MultiChannelCompressed) { + // 
data from TfLite test: SimpleBiasTestQuantizedPerChannel16x8Bias64 + const float input_scale = 4.0f / 127.0f; + const float output_scale = 128.0f / 65536.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ2, + 7.0f / 127.0f, + 8.0f / 127.0f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ2, + 0, + 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ2]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ2]; + int16_t bias_quantized[tflite::testing::kBiasElementsQ2]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ2]; + int16_t output_quantized[tflite::testing::kOutputElementsQ2]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ2 / + tflite::testing::kNumChannelsQ2; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ2; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ2; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ2; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ2; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ2 / tflite::testing::kNumChannelsQ2; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ2; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ2; + bias_comp_info.data = tflite::testing::kBiasDataQ2; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ2; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + // The quantized output is compared to the expected output (quantized). + // A tolerance of 81 is approx. 0.1582f which is less than the TfLite + // tolerance of 0.19f. 
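+  // (81 quantized steps * output scale of 128.0f / 65536.0f = 0.15820f.)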
+ TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestTransposeConvQuantizedCompressed( + tflite::testing::kInputShapeQ2, tflite::testing::kInputDataQ2, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ2, tflite::testing::kGoldenDataQ2, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params, 81, &filter_comp_info, + &bias_comp_info)); +} + +TF_LITE_MICRO_TEST( + SimpleBiasTestQuantizedPerChannelBias64MultiChannelCompressed) { + // data from TfLite test: SimpleBiasTestQuantizedPerChannel16x8Bias64 + const float input_scale = 4.0f / 127.0f; + const float output_scale = 128.0f / 65536.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = { + tflite::testing::kNumChannelsQ2, + 7.0f / 127.0f, + 8.0f / 127.0f, + }; + constexpr int filter_zero_points[] = { + tflite::testing::kNumChannelsQ2, + 0, + 0, + }; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ2]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ2]; + int64_t bias_quantized[tflite::testing::kBiasElementsQ2]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ2]; + int16_t output_quantized[tflite::testing::kOutputElementsQ2]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ2 / + tflite::testing::kNumChannelsQ2; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ2; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ2; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ2; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ2; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ2 / tflite::testing::kNumChannelsQ2; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ2; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ2; + bias_comp_info.data = tflite::testing::kBiasDataQ2; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ2; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + // The quantized output is compared to the expected output (quantized). + // A tolerance of 81 is approx. 0.1582f which is less than the TfLite + // tolerance of 0.19f. 
+ TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestTransposeConvQuantizedCompressed( + tflite::testing::kInputShapeQ2, tflite::testing::kInputDataQ2, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ2, tflite::testing::kGoldenDataQ2, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params, 81, &filter_comp_info, + &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc index 384dba9f7ac..5eb7a1bb7d4 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,14 +52,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: { +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION tflite::reference_ops::Conv( ConvParamsFloat(params, op_data.reference_op_data), tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + op_data.reference_op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + op_data.reference_op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output), tflite::micro::GetTensorShape(nullptr), nullptr); diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc index 1d2d7ec253e..965fce23167 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc @@ -145,9 +145,29 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node, const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int16_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int8_t* filter_data = tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.reference_op_data.weights_scratch_index); + const int64_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.reference_op_data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const 
int8_t* filter_data = tflite::micro::GetTensorData(filter); const int64_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int16_t* output_data = tflite::micro::GetTensorData(output); int output_data_format = 0; @@ -179,7 +199,6 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node, } else { void* p_scratch = static_cast( context->GetScratchBuffer(context, data.scratch_tensor_index)); - for (int batch = 0; batch < batches; ++batch) { int16_t* p_out_temp; p_out_temp = &output_data[batch * out_length]; @@ -243,8 +262,25 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node, const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int8_t* input_data = tflite::micro::GetTensorData(input); +#ifdef USE_TFLM_COMPRESSION + const int32_t* bias_data = tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, + data.reference_op_data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION const int32_t* bias_data = tflite::micro::GetTensorData(bias); +#endif // USE_TFLM_COMPRESSION int8_t* output_data = tflite::micro::GetTensorData(output); const int8_t* filter_data; @@ -257,7 +293,13 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node, tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data); filter_data = unpacked_filter_data; } else { +#ifdef USE_TFLM_COMPRESSION + filter_data = tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.reference_op_data.weights_scratch_index); +#else // USE_TFLM_COMPRESSION filter_data = tflite::micro::GetTensorData(filter); +#endif // USE_TFLM_COMPRESSION } int output_data_format = 0; diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc b/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc index 2492d4b348b..0f583cdaceb 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc @@ -45,6 +45,17 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) { ? 
tflite::micro::GetEvalInput(context, node, kConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + if (bias == nullptr || bias->type == kTfLiteInt32) { reference_integer_ops::ConvPerChannel( ConvParamsQuantized(params, op_data), @@ -52,9 +63,18 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } else if (bias->type == kTfLiteInt64) { @@ -64,9 +84,18 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + op_data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); } else { diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc b/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc index 6ac07bab403..ba746f0ff8f 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc @@ -45,6 +45,17 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) { ? 
tflite::micro::GetEvalInput(context, node, kConvBiasTensor) : nullptr; +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int8_t* filter_data; if (filter->type == kTfLiteInt4) { int8_t* unpacked_filter_data = static_cast( @@ -54,7 +65,12 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data); filter_data = unpacked_filter_data; } else { +#ifdef USE_TFLM_COMPRESSION + filter_data = tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, op_data.weights_scratch_index); +#else // USE_TFLM_COMPRESSION filter_data = tflite::micro::GetTensorData(filter); +#endif // USE_TFLM_COMPRESSION } reference_integer_ops::ConvPerChannel( @@ -64,7 +80,12 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), filter_data, tflite::micro::GetTensorShape(bias), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, + op_data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc index 812ab60ebf2..8a0330907c3 100644 --- a/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc +++ b/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
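The vision (P6) kernel differs from the HiFi paths above: it needs the raw coefficients at Prepare() time for reordering, so compressed tensors are decompressed into temp buffers that are released before Prepare() returns, rather than into Eval-time scratch areas. The shape of the hunk below, condensed (names from the patch; error checks elided):

    const CompressionTensorData* filter_comp_td =
        micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
    if (filter_comp_td != nullptr) {
      // Decompress into a temp buffer that lives only for the reorder step.
      filter_data = micro_context->AllocateTempBuffer(filter_data_size,
                                                      sizeof(int8_t));
      filter_data =
          static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
              *filter_eval, *filter_comp_td, filter_data));
      // ... xiConvSetContext() / xiConvDoCoeffReorder() consume it ...
      micro_context->DeallocateTempBuffer(filter_data);
    }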
@@ -104,6 +104,58 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { filter_int8 = *filter; } +#ifdef USE_TFLM_COMPRESSION + + uint8_t* filter_data = nullptr; + int32_t* bias_data = nullptr; + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + if (filter_comp_td != nullptr) { + const size_t filter_data_size = + NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8); + filter_data = + micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t)); + if (filter_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* filter_eval = + tflite::micro::GetEvalInput(context, node, kConvWeightsTensor); + filter_data = static_cast(micro_context->DecompressTensorToBuffer( + *filter_eval, *filter_comp_td, filter_data)); + } else { + filter_data = GetTensorData(&filter_int8); + } + + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + if (bias_comp_td != nullptr) { + const size_t bias_data_size = + NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32); + bias_data = reinterpret_cast( + micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t))); + if (bias_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* bias_eval = + tflite::micro::GetEvalInput(context, node, kConvBiasTensor); + bias_data = static_cast(micro_context->DecompressTensorToBuffer( + *bias_eval, *bias_comp_td, bias_data)); + } else { + bias_data = GetTensorData(bias); + } + + if (filter_data == nullptr || bias_data == nullptr) { + return kTfLiteError; + } + +#else // USE_TFLM_COMPRESSION + + uint8_t* filter_data = GetTensorData(&filter_int8); + int32_t* bias_data = GetTensorData(bias); + +#endif // USE_TFLM_COMPRESSION + status = xiConvSetContext( data->p_context, data->context_size, input_depth, input_width, input_height, output_depth, output_width, output_height, filter_width, @@ -112,8 +164,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { data->reference_op_data.output_multiplier, data->reference_op_data.output_shift, data->reference_op_data.output_activation_min, - data->reference_op_data.output_activation_max, - (uint8_t*)GetTensorData(&filter_int8), + data->reference_op_data.output_activation_max, filter_data, data->reference_op_data.padding.width, data->reference_op_data.padding.height); if (status) { @@ -138,9 +189,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { status = xiConvDoCoeffReorder( data->p_context, data->context_size, reinterpret_cast(data->reorder_coefficient_bias), - data->reorder_coefficient_bias_size, - const_cast(GetTensorData(&filter_int8)), - const_cast(GetTensorData(bias))); + data->reorder_coefficient_bias_size, filter_data, bias_data); if (status) { return kTfLiteError; } @@ -149,6 +198,17 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) { micro_context->DeallocateTempBuffer(GetTensorData(&filter_int8)); } +#ifdef USE_TFLM_COMPRESSION + + if (filter_comp_td) { + micro_context->DeallocateTempBuffer(filter_data); + } + if (bias_comp_td) { + micro_context->DeallocateTempBuffer(reinterpret_cast(bias_data)); + } + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/kernels/xtensa/decompress.cc b/tensorflow/lite/micro/kernels/xtensa/decompress.cc new file mode 
100644 index 00000000000..13d2ce2dec7 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decompress.cc @@ -0,0 +1,711 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/kernels/decompress.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" +#include "tensorflow/lite/micro/micro_utils.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#endif // HIFI5 + +namespace tflite { +namespace { + +#ifdef HIFI5 + +struct DecompressionStateXtensa : DecompressionState { + DecompressionStateXtensa() = delete; + + DecompressionStateXtensa(const DecompressionState& other) + : DecompressionState(other) {} + + void DecompressToBufferWidth4_Xtensa(int8_t* buffer); + void DecompressToBufferWidth3_Xtensa(int8_t* buffer); + void DecompressToBufferWidth2_Xtensa(int8_t* buffer); + + void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer); + void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer); + void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer); + void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer); +}; + +void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + 
+    }
+
+    value_table += stride;
+    if (elements_per_channel_t_rem) {
+      ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);
+      AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
+                     (elements_per_channel_t_rem >>
+                      1)); /* Load rem/2 bytes: 4 bits per remaining weight */
+      AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index);
+      AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);
+      AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
+                     elements_per_channel_t_rem);
+    }
+  }
+  AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
+}
+
+void DecompressionStateXtensa::DecompressToBufferWidth3_Xtensa(int8_t* buffer) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  int i, j;
+  ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;
+  ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_;
+  const uint8_t* __restrict value_table =
+      static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);
+
+  const uint8_t* __restrict value_table_t = value_table;
+
+  int num_channels_t = num_channels_;
+  const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
+
+  int elements_per_channel_t_by_4 = elements_per_channel_ >> 4;
+  int elements_per_channel_t_rem = elements_per_channel_ & 0xF;
+
+  ae_int8x8 d_index, d_dummy;
+  ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11;
+  ae_int8x8 d_out1, d_out2;
+
+  ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);
+
+  ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL);
+  ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL);
+  ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL);
+  ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL);
+  ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL);
+
+  ae_valignx2 align_store = AE_ZALIGN128();
+
+  for (i = 0; i < num_channels_t; i++) {
+    ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO());
+    ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO());
+
+    value_table_t = value_table;
+
+    ae_valign align_vtab = AE_LA64_PP(value_table_t);
+    AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t);
+    AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t);
+
+    for (j = 0; j < elements_per_channel_t_by_4; j++) {
+      AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
+                     6); /* Loading 48 bits for decoding 16 weight values */
+
+      d1 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1));
+      d2 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
+      d3 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3));
+      d4 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4));
+
+      d1 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL));
+      d2 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL));
+      d3 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL));
+      d4 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL));
+
+      d5 = d1 | d2;
+      d6 = d3 | d4;
+
+      d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4));
+      d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4));
+
+      d9 = AE_SEL8X8(d5, d7, d_shuffle_t1);
+      d10 = AE_SEL8X8(d6, d8, d_shuffle_t2);
+      d11 = AE_SEL8X8(d9, d10, d_shuffle_t3);
+
+      AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11);
+      AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);
+
+      AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp);
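+      // 16 3-bit indices (48 bits) are decoded per iteration: the shift/mask
+      // ladder above respaces each 3-bit field onto a nibble boundary so the
+      // same dual-select table lookup as the 4-bit path can be reused.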
+    }
+    if (elements_per_channel_t_rem) {
+      AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
+                     3); /* Load 24 bits to decode the remaining 8 3-bit
+                            indices */
+
+      d1 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1));
+      d2 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
+      d3 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3));
+      d4 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4));
+
+      d1 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL));
+      d2 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL));
+      d3 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL));
+      d4 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL));
+
+      d5 = d1 | d2;
+      d6 = d3 | d4;
+
+      d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4));
+      d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4));
+
+      d9 = AE_SEL8X8(d5, d7, d_shuffle_t1);
+      d10 = AE_SEL8X8(d6, d8, d_shuffle_t2);
+      d11 = AE_SEL8X8(d9, d10, d_shuffle_t3);
+
+      AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11);
+      AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t);
+
+      AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
+                     elements_per_channel_t_rem);
+    }
+
+    value_table = value_table + stride;
+  }
+  AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
+}
+
+void DecompressionStateXtensa::DecompressToBufferWidth2_Xtensa(int8_t* buffer) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  int i, j;
+  ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;
+  ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_;
+  const uint8_t* __restrict value_table =
+      static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);
+
+  const uint8_t* __restrict value_table_t = value_table;
+
+  int num_channels_t = num_channels_;
+  const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
+
+  int elements_per_channel_t_by_5 = elements_per_channel_ >> 5;
+  int elements_per_channel_t_rem = elements_per_channel_ & 0x1F;
+  int elements_per_channel_t_rem_minus_16 = 0;
+  if (elements_per_channel_t_rem > 16) {
+    elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16;
+  }
+
+  ae_int8x8 d_index, d_dummy;
+  ae_int8x8 d0, d1, d2, d3, d4, d5;
+  ae_int8x8 q0, q1, q2, q3;
+  ae_int8x8 d_out1, d_out2;
+
+  ae_valignx2 align_index = AE_LA128_PP(pIn_tmp);
+
+  ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL);
+  ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL);
+  ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL);
+
+  ae_valignx2 align_store = AE_ZALIGN128();
+
+  for (i = 0; i < num_channels_t; i++) {
+    ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO());
+    ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO());
+
+    value_table_t = value_table;
+
+    ae_valign align_vtab = AE_LA64_PP(value_table_t);
+    AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t);
+    AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t);
+
+    for (j = 0; j < elements_per_channel_t_by_5; j++) {
+      AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
+                     8); /* Loading 64 bits for decoding 32 weight values */
+      d0 = d_index;
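+      // The masked copies below separate the alternating 2-bit index fields
+      // of d0 and its 2-bit-shifted image into two registers; the select ops
+      // then gather them back into byte-per-index order for the lookup.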
+      d1 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
+
+      d2 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d0),
+                   0x3333333333333333LL));  // i1,i3,i5, ....
+      d3 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d1),
+                   0x3333333333333333LL));  // i0,i2,i4, ....
+
+      AE_DSEL8X8(d4, d5, d3, d2,
+                 d_shuffle_t1);  // d4 = i0,i2,i1,i3,i4,i6,... d5 =
+                                 // i16,i18, i17,i19, ....
+
+      AE_DSEL8X8(q0, q1, d_value_0, d_value_1,
+                 d4);  // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15
+      AE_DSEL8X8(
+          q2, q3, d_value_0, d_value_1,
+          d5);  // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31
+
+      AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2);
+      AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp);
+
+      AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2);
+      AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp);
+    }
+    if (elements_per_channel_t_rem) {
+      AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp,
+                     (elements_per_channel_t_rem >>
+                      2)); /* Load rem/4 bytes: 2 bits per remaining weight */
+      d0 = d_index;
+      d1 =
+          AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2));
+      d2 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d0),
+                   0x3333333333333333LL));  // i1,i3,i5, ....
+      d3 = AE_MOVINT8X8_FROMINT64(
+          AE_AND64(AE_MOVINT64_FROMINT8X8(d1),
+                   0x3333333333333333LL));  // i0,i2,i4, ....
+
+      AE_DSEL8X8(d4, d5, d3, d2,
+                 d_shuffle_t1);  // d4 = i0,i2,i1,i3,i4,i6,... d5 =
+                                 // i16,i18, i17,i19, ....
+
+      AE_DSEL8X8(q0, q1, d_value_0, d_value_1,
+                 d4);  // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15
+      AE_DSEL8X8(
+          q2, q3, d_value_0, d_value_1,
+          d5);  // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31
+
+      AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2);
+
+      AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
+                     elements_per_channel_t_rem);
+
+      AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2);
+
+      AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp,
+                     elements_per_channel_t_rem_minus_16);
+    }
+
+    value_table = value_table + stride;
+  }
+  AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp);
+}
+
+void DecompressionStateXtensa::DecompressToBufferWidthAnyInt8_Xtensa(
+    int8_t* buffer) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  const int stride = comp_data_.data.lut_data->value_table_channel_stride;
+  const uint8_t* __restrict value_table =
+      static_cast<const uint8_t*>(comp_data_.data.lut_data->value_table);
+
+  int num_channels_t = num_channels_;
+  short* __restrict p_stream = (short*)compressed_indices_;
+  uint32_t index;
+  ae_int8* __restrict p_out_tmp = (ae_int8*)buffer;
+  const size_t bw = compressed_bit_width_;
+
+  WUR_AE_BITPTR(0);
+  WUR_AE_BITHEAD(0);
+
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+
+  if (comp_data_.data.lut_data->use_alternate_axis) {
+    int count = count_indices_;
+    const uint8_t* __restrict value_table_t = value_table;
+
+    while (count > 0) {
+      value_table = value_table_t;
+
+      for (int channel = 0; channel < num_channels_t; channel++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index);
+        AE_S8_0_IP(d_tmp, p_out_tmp, 1);
+        value_table += stride;
+      }
+
+      count -= num_channels_t;
+    }
+  } else {
+    int elements_per_channel_t = elements_per_channel_;
+    uint32_t index_1, index_2;
+    uint32_t mask_bits = (1 << compressed_bit_width_) - 1;
+
+    for (int i = 0; i < num_channels_t; i++) {
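+      // Generic bit-width path: pull bw bits per index from the bitstream;
+      // once the output pointer is 2-byte aligned, indices are decoded in
+      // pairs and written as packed 16-bit stores.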
+      elements_per_channel_t = elements_per_channel_;
+      /* if output pointer is not 2 byte aligned */
+      if ((unsigned int)p_out_tmp & 0x1) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index);
+        AE_S8_0_IP(d_tmp, p_out_tmp, 1);
+        elements_per_channel_t = elements_per_channel_t - 1;
+      }
+      for (int j = 0; j < (elements_per_channel_t >> 1); j++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, 2 * bw);
+        index_1 = (index >> compressed_bit_width_) & mask_bits;
+        index_2 = index & mask_bits;
+        ae_int8x8 d_tmp1 = AE_L8_X((const ae_int8*)value_table, index_1);
+        ae_int8x8 d_tmp2 = AE_L8_X((const ae_int8*)value_table, index_2);
+        ae_int16x4 d_tmp =
+            AE_MOVINT16X4_FROMINT8X8(AE_SEL8X8I(d_tmp2, d_tmp1, 21));
+        AE_S16_0_IP(d_tmp, (ae_int16*)p_out_tmp, 2);
+      }
+      if (elements_per_channel_t & 0x1) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index);
+        AE_S8_0_IP(d_tmp, p_out_tmp, 1);
+      }
+      value_table += stride;
+    }
+  }
+}
+
+void DecompressionStateXtensa::DecompressToBufferWidthAnyInt16_Xtensa(
+    int16_t* buffer) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  const int stride = comp_data_.data.lut_data->value_table_channel_stride;
+  const uint16_t* __restrict value_table =
+      static_cast<const uint16_t*>(comp_data_.data.lut_data->value_table);
+
+  int num_channels_t = num_channels_;
+  short* __restrict p_stream = (short*)compressed_indices_;
+  uint32_t index;
+  ae_int16* __restrict p_out_tmp = (ae_int16*)buffer;
+  const size_t bw = compressed_bit_width_;
+
+  WUR_AE_BITPTR(0);
+  WUR_AE_BITHEAD(0);
+
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+
+  if (comp_data_.data.lut_data->use_alternate_axis) {
+    int count = count_indices_;
+    const uint16_t* __restrict value_table_t = value_table;
+
+    while (count > 0) {
+      value_table = value_table_t;
+
+      for (int channel = 0; channel < num_channels_t; channel++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1);
+        AE_S16_0_IP(d_tmp, p_out_tmp, 2);
+        value_table += stride;
+      }
+
+      count -= num_channels_t;
+    }
+  } else {
+    int elements_per_channel_t = elements_per_channel_;
+
+    for (int i = 0; i < num_channels_t; i++) {
+      for (int j = 0; j < elements_per_channel_t; j++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1);
+        AE_S16_0_IP(d_tmp, p_out_tmp, 2);
+      }
+
+      value_table += stride;
+    }
+  }
+}
+
+void DecompressionStateXtensa::DecompressToBufferWidthAnyInt32_Xtensa(
+    int32_t* buffer) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  const int stride = comp_data_.data.lut_data->value_table_channel_stride;
+  const uint32_t* __restrict value_table =
+      static_cast<const uint32_t*>(comp_data_.data.lut_data->value_table);
+
+  int num_channels_t = num_channels_;
+  short* __restrict p_stream = (short*)compressed_indices_;
+  uint32_t index;
+  ae_int32* __restrict p_out_tmp = (ae_int32*)buffer;
+  const size_t bw = compressed_bit_width_;
+
+  WUR_AE_BITPTR(0);
+  WUR_AE_BITHEAD(0);
+
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+
+  if (comp_data_.data.lut_data->use_alternate_axis) {
+    int count = count_indices_;
+    const uint32_t* __restrict value_table_t = value_table;
+
+    while (count > 0) {
+      value_table = value_table_t;
+
+      for (int channel = 0; channel < num_channels_t; channel++) {
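+        // Alternate-axis layout interleaves indices across channels: decode
+        // one index per channel per pass, stepping the value table by the
+        // per-channel stride.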
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2);
+        AE_S32_L_IP(d_tmp, p_out_tmp, 4);
+        value_table += stride;
+      }
+
+      count -= num_channels_t;
+    }
+  } else {
+    int elements_per_channel_t = elements_per_channel_;
+
+    for (int i = 0; i < num_channels_t; i++) {
+      for (int j = 0; j < elements_per_channel_t; j++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2);
+        AE_S32_L_IP(d_tmp, p_out_tmp, 4);
+      }
+
+      value_table += stride;
+    }
+  }
+}
+
+void DecompressionStateXtensa::DecompressToBufferWidthAnyInt64_Xtensa(
+    int64_t* buffer) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  const int stride = comp_data_.data.lut_data->value_table_channel_stride;
+  const uint64_t* __restrict value_table =
+      static_cast<const uint64_t*>(comp_data_.data.lut_data->value_table);
+
+  int num_channels_t = num_channels_;
+  short* __restrict p_stream = (short*)compressed_indices_;
+  uint32_t index;
+  ae_int64* __restrict p_out_tmp = (ae_int64*)buffer;
+  const size_t bw = compressed_bit_width_;
+
+  WUR_AE_BITPTR(0);
+  WUR_AE_BITHEAD(0);
+
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+  AE_DBI_IP((const unsigned short*)p_stream, 16);
+
+  if (comp_data_.data.lut_data->use_alternate_axis) {
+    int count = count_indices_;
+    const uint64_t* __restrict value_table_t = value_table;
+
+    while (count > 0) {
+      value_table = value_table_t;
+
+      for (int channel = 0; channel < num_channels_t; channel++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3);
+        AE_S64_IP(d_tmp, p_out_tmp, 8);
+        value_table += stride;
+      }
+
+      count -= num_channels_t;
+    }
+  } else {
+    int elements_per_channel_t = elements_per_channel_;
+
+    for (int i = 0; i < num_channels_t; i++) {
+      for (int j = 0; j < elements_per_channel_t; j++) {
+        AE_LB_DB_IP((unsigned short*)p_stream, index, bw);
+        ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3);
+        AE_S64_IP(d_tmp, p_out_tmp, 8);
+      }
+
+      value_table += stride;
+    }
+  }
+}
+
+#endif  // HIFI5
+
+}  // namespace
+
+#ifdef HIFI5
+
+template <>
+bool* DecompressionState::DecompressToBuffer<bool>(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  DecompressionStateXtensa dsx(*this);
+
+  dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
+
+  return static_cast<bool*>(buffer);
+}
+
+template <>
+int8_t* DecompressionState::DecompressToBuffer<int8_t>(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  DecompressionStateXtensa dsx(*this);
+
+  if (comp_data_.data.lut_data->compressed_bit_width == 4 &&
+      !comp_data_.data.lut_data->use_alternate_axis) {
+    if (!(elements_per_channel_ & 0x01)) {
+      dsx.DecompressToBufferWidth4_Xtensa(static_cast<int8_t*>(buffer));
+    } else {
+      dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
+    }
+  } else if (comp_data_.data.lut_data->compressed_bit_width == 3 &&
+             !comp_data_.data.lut_data->use_alternate_axis) {
+    if (!(elements_per_channel_ & 0x07)) {
+      dsx.DecompressToBufferWidth3_Xtensa(static_cast<int8_t*>(buffer));
+    } else {
+      dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
+    }
+  } else if (comp_data_.data.lut_data->compressed_bit_width == 2 &&
+             !comp_data_.data.lut_data->use_alternate_axis) {
+    if (!(elements_per_channel_ & 0x03)) {
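+      // The vectorized 2-bit decoder assumes elements_per_channel_ is a
+      // multiple of 4; otherwise fall through to the generic decoder below.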
+      dsx.DecompressToBufferWidth2_Xtensa(static_cast<int8_t*>(buffer));
+    } else {
+      dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
+    }
+  } else {
+    dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
+  }
+
+  return static_cast<int8_t*>(buffer);
+}
+
+template <>
+int16_t* DecompressionState::DecompressToBuffer<int16_t>(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  DecompressionStateXtensa dsx(*this);
+
+  dsx.DecompressToBufferWidthAnyInt16_Xtensa(static_cast<int16_t*>(buffer));
+
+  return static_cast<int16_t*>(buffer);
+}
+
+template <>
+int32_t* DecompressionState::DecompressToBuffer<int32_t>(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  DecompressionStateXtensa dsx(*this);
+
+  dsx.DecompressToBufferWidthAnyInt32_Xtensa(static_cast<int32_t*>(buffer));
+
+  return static_cast<int32_t*>(buffer);
+}
+
+template <>
+float* DecompressionState::DecompressToBuffer<float>(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  DecompressionStateXtensa dsx(*this);
+
+  dsx.DecompressToBufferWidthAnyInt32_Xtensa(static_cast<int32_t*>(buffer));
+
+  return static_cast<float*>(buffer);
+}
+
+template <>
+int64_t* DecompressionState::DecompressToBuffer<int64_t>(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  DecompressionStateXtensa dsx(*this);
+
+  dsx.DecompressToBufferWidthAnyInt64_Xtensa(static_cast<int64_t*>(buffer));
+
+  return static_cast<int64_t*>(buffer);
+}
+
+#else  // HIFI5
+
+template <typename T>
+T* DecompressionState::DecompressToBuffer(void* buffer) {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  if (std::is_same<T, int8_t>::value &&
+      comp_data_.data.lut_data->compressed_bit_width == 4 &&
+      !comp_data_.data.lut_data->use_alternate_axis) {
+    DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
+  } else if (std::is_same<T, int8_t>::value &&
+             comp_data_.data.lut_data->compressed_bit_width == 3 &&
+             !comp_data_.data.lut_data->use_alternate_axis) {
+    DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
+  } else if (std::is_same<T, int8_t>::value &&
+             comp_data_.data.lut_data->compressed_bit_width == 2 &&
+             !comp_data_.data.lut_data->use_alternate_axis) {
+    DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
+  } else {
+    DecompressToBufferWidthAny(static_cast<T*>(buffer));
+  }
+
+  return static_cast<T*>(buffer);
+}
+
+template bool* DecompressionState::DecompressToBuffer<bool>(void*);
+template float* DecompressionState::DecompressToBuffer<float>(void*);
+template int8_t* DecompressionState::DecompressToBuffer<int8_t>(void*);
+template int16_t* DecompressionState::DecompressToBuffer<int16_t>(void*);
+template int32_t* DecompressionState::DecompressToBuffer<int32_t>(void*);
+template int64_t* DecompressionState::DecompressToBuffer<int64_t>(void*);
+
+#endif  // HIFI5
+
+}  // namespace tflite
+
+#endif  // USE_TFLM_COMPRESSION
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
index 8536ff79507..088420aee17 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
@@ -1,5 +1,5 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -93,6 +93,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
       context, op_data.reference_op_data.filter_buffer_index, filter);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteInt8: {
       switch (filter_int8.type) {
@@ -111,9 +123,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
             tflite::micro::GetTensorShape(input),
             tflite::micro::GetTensorData<int8_t>(input),
             tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+            tflite::micro::GetTensorData<int8_t>(
+                micro_context, &filter_int8, filter_comp_td,
+                op_data.reference_op_data.weights_scratch_index),
+            tflite::micro::GetTensorShape(bias),
+            tflite::micro::GetTensorData<int32_t>(
+                micro_context, bias, bias_comp_td,
+                op_data.reference_op_data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
             tflite::micro::GetTensorData<int8_t>(&filter_int8),
             tflite::micro::GetTensorShape(bias),
             tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
             tflite::micro::GetTensorShape(output),
             tflite::micro::GetTensorData<int8_t>(output));
 #endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
@@ -136,9 +158,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
             tflite::micro::GetTensorShape(input),
             tflite::micro::GetTensorData<int16_t>(input),
             tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+            tflite::micro::GetTensorData<int8_t>(
+                micro_context, &filter_int8, filter_comp_td,
+                op_data.reference_op_data.weights_scratch_index),
+            tflite::micro::GetTensorShape(bias),
+            tflite::micro::GetTensorData<int64_t>(
+                micro_context, bias, bias_comp_td,
+                op_data.reference_op_data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
             tflite::micro::GetTensorData<int8_t>(&filter_int8),
             tflite::micro::GetTensorShape(bias),
             tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
             tflite::micro::GetTensorShape(output),
             tflite::micro::GetTensorData<int16_t>(output));
             break;
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
index 8c2052b23e7..9bff6d997a3 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -97,10 +97,22 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
                                    const TfLiteEvalTensor* filter,
                                    const TfLiteEvalTensor* bias,
                                    TfLiteEvalTensor* output) {
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
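+
+  // (GetTensorCompressionData returns nullptr for an uncompressed tensor;
+  // the GetTensorData overloads used below then fall back to the plain
+  // tensor data.)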
+
   // If dilation is not required use the optimized NN Library kernel.
   // Otherwise call the reference implementation.
   if ((params.dilation_width_factor == 1) &&
-      (params.dilation_height_factor == 1)) {
+      (params.dilation_height_factor == 1) && bias != nullptr) {
     const int stride_width = params.stride_width;
     const int stride_height = params.stride_height;
     const int pad_width = data.reference_op_data.padding.width;
@@ -133,8 +145,17 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
     TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
 
     const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+#ifdef USE_TFLM_COMPRESSION
+    const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
+        micro_context, filter, filter_comp_td,
+        data.reference_op_data.weights_scratch_index);
+    const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(
+        micro_context, bias, bias_comp_td,
+        data.reference_op_data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
     const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
     const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
     int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
 
     int32_t input_data_format = 0;
@@ -178,9 +199,19 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
       tflite::micro::GetTensorShape(input),
       tflite::micro::GetTensorData<int8_t>(input),
       tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(
+          micro_context, filter, filter_comp_td,
+          data.reference_op_data.weights_scratch_index),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<int32_t>(
+          micro_context, bias, bias_comp_td,
+          data.reference_op_data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorData<int8_t>(filter),
       tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorShape(output),
       tflite::micro::GetTensorData<int8_t>(output));
 
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
index 35fa8cf1c1a..23e18dc8342 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context,
   TF_LITE_ENSURE(context, filter != nullptr);
   TfLiteTensor* bias =
       micro_context->AllocateTempInputTensor(node, kDepthwiseConvBiasTensor);
-  TF_LITE_ENSURE(context, filter != nullptr);
+  TF_LITE_ENSURE(context, bias != nullptr);
 
   // Dynamically allocate per-channel quantization parameters.
   const int num_channels =
@@ -135,18 +135,81 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context,
     filter_int8 = *filter;
   }
 
+#ifdef USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = nullptr;
+  int32_t* bias_data = nullptr;
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  if (filter_comp_td != nullptr) {
+    const size_t filter_data_size =
+        NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8);
+    filter_data =
+        micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t));
+    if (filter_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* filter_eval =
+        tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
+    filter_data =
+        static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
+            *filter_eval, *filter_comp_td, filter_data));
+  } else {
+    filter_data = GetTensorData<uint8_t>(&filter_int8);
+  }
+
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+  if (bias_comp_td != nullptr) {
+    const size_t bias_data_size =
+        NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
+    bias_data = reinterpret_cast<int32_t*>(
+        micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t)));
+    if (bias_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* bias_eval =
+        tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor);
+    bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
+        *bias_eval, *bias_comp_td, bias_data));
+  } else {
+    bias_data = GetTensorData<int32_t>(bias);
+  }
+
+  if (filter_data == nullptr || bias_data == nullptr) {
+    return kTfLiteError;
+  }
+
+#else   // USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = GetTensorData<uint8_t>(&filter_int8);
+  int32_t* bias_data = GetTensorData<int32_t>(bias);
+
+#endif  // USE_TFLM_COMPRESSION
+
   status = xiDepthwiseConvDoCoeffReorder(
       data->p_context, data->context_size,
       reinterpret_cast<uint8_t*>(data->reorder_coefficient_bias),
-      data->reorder_coefficient_bias_size,
-      const_cast<uint8_t*>(GetTensorData<uint8_t>(&filter_int8)),
-      const_cast<int32_t*>(GetTensorData<int32_t>(bias)));
+      data->reorder_coefficient_bias_size, filter_data, bias_data);
   if (status) {
     return kTfLiteError;
   }
 
   if (filter->type == kTfLiteInt4) {
     micro_context->DeallocateTempBuffer(GetTensorData<uint8_t>(&filter_int8));
   }
+
+#ifdef USE_TFLM_COMPRESSION
+
+  if (filter_comp_td) {
+    micro_context->DeallocateTempBuffer(filter_data);
+  }
+  if (bias_comp_td) {
+    micro_context->DeallocateTempBuffer(reinterpret_cast<uint8_t*>(bias_data));
+  }
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index df5458001b7..4a141784d8f 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -45,6 +45,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   TFLITE_DCHECK(node->user_data != nullptr);
 
   const auto& data =
@@ -58,9 +70,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              weights_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
+                                              data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -93,9 +114,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int16_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               weights_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int16_t>(output));
       break;
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc
index cf87c5ff1ed..91b9f40c907 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -120,6 +120,23 @@ TfLiteStatus XtensaPrepareFullyConnected(TfLiteContext* context,
       context, params->activation, input->type, input, filter, bias, output,
       data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
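+  // INT4-packed weights are unpacked at Eval time and cannot also be
+  // LUT-decompressed, so that combination is rejected up front.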
+  if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kFullyConnectedWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
   if (bias != nullptr) {
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc
index f850c0c0fca..32dfba2e5a8 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -32,12 +32,37 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8(
     const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
     const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
 #if !defined(VISION_P6)
+
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   const int32_t* bias_data =
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int32_t>(micro_context, bias, bias_comp_td,
+                                            data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
       tflite::micro::GetOptionalTensorData<int32_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
+
+  const int8_t* filter_data =
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(
+          micro_context, filter, weights_comp_td, data.weights_scratch_index);
+#else   // USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(filter);
+#endif  // USE_TFLM_COMPRESSION
 
   // P6 Vision will handle INT4 filters as a reference operation.
   // For all other architectures, unpack INT4 here.
-  const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
   if (filter->type == kTfLiteInt4) {
     int8_t* unpacked_filter_data = static_cast<int8_t*>(
         context->GetScratchBuffer(context, data.filter_buffer_index));
@@ -47,6 +72,7 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8(
         tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data);
     filter_data = unpacked_filter_data;
   }
+
 #endif  // !defined(VISION_P6)
 
 #if defined(HIFIMINI)
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc
index 14bb9a12b15..24fd1258277 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -108,14 +108,66 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context,
     filter_int8 = *filter;
   }
 
+#ifdef USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = nullptr;
+  int32_t* bias_data = nullptr;
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  if (filter_comp_td != nullptr) {
+    const size_t filter_data_size =
+        NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8);
+    filter_data =
+        micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t));
+    if (filter_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* filter_eval = tflite::micro::GetEvalInput(
+        context, node, kFullyConnectedWeightsTensor);
+    filter_data =
+        static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
+            *filter_eval, *filter_comp_td, filter_data));
+  } else {
+    filter_data = GetTensorData<uint8_t>(&filter_int8);
+  }
+
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+  if (bias_comp_td != nullptr) {
+    const size_t bias_data_size =
+        NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
+    bias_data = reinterpret_cast<int32_t*>(
+        micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t)));
+    if (bias_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* bias_eval =
+        tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
+    bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
+        *bias_eval, *bias_comp_td, bias_data));
+  } else {
+    bias_data = GetTensorData<int32_t>(bias);
+  }
+
+  if (filter_data == nullptr || (bias != nullptr && bias_data == nullptr)) {
+    return kTfLiteError;
+  }
+
+#else   // USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = GetTensorData<uint8_t>(&filter_int8);
+  int32_t* bias_data = GetTensorData<int32_t>(bias);
+
+#endif  // USE_TFLM_COMPRESSION
+
   status = xiFullyConnectedSetContext(
       data->p_context, data->context_size, inputDims, outputDims, filterDims, 1,
       input->params.zero_point, filter->params.zero_point,
       output->params.zero_point, data->reference_op_data.output_multiplier,
       data->reference_op_data.output_shift,
       data->reference_op_data.output_activation_min,
-      data->reference_op_data.output_activation_max,
-      (uint8_t*)GetTensorData<uint8_t>(&filter_int8));
+      data->reference_op_data.output_activation_max, filter_data);
 
   if (status) {
     return kTfLiteError;
@@ -139,9 +191,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context,
   status = xiFullyConnectedDoCoeffReorder(
       data->p_context, data->context_size,
       reinterpret_cast<uint8_t*>(data->reorder_coefficient_bias),
-      data->reorder_coefficient_bias_size,
-      const_cast<uint8_t*>(GetTensorData<uint8_t>(&filter_int8)),
-      const_cast<int32_t*>(GetTensorData<int32_t>(bias)));
+      data->reorder_coefficient_bias_size, filter_data, bias_data);
   if (status) {
     return kTfLiteError;
   }
@@ -149,6 +199,18 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context,
   if (filter->type == kTfLiteInt4) {
     micro_context->DeallocateTempBuffer(GetTensorData<uint8_t>(&filter_int8));
   }
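+
+  // The decompressed filter/bias copies are only needed while the vision
+  // library reorders coefficients, so the temp buffers are released before
+  // Prepare returns.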
+
+#ifdef USE_TFLM_COMPRESSION
+
+  if (filter_comp_td) {
+    micro_context->DeallocateTempBuffer(filter_data);
+  }
+  if (bias_comp_td) {
+    micro_context->DeallocateTempBuffer(reinterpret_cast<uint8_t*>(bias_data));
+  }
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
diff --git a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
index 44a9f86049c..d46ef0ba88a 100644
--- a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -49,6 +49,14 @@ struct OpData {
   // A scratch buffer is required for quantized implementations.
   int scratch_buffer_index;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // scratch buffers for compressed tensors
+  int filter_scratch_index;
+  int bias_scratch_index;
+
+#endif  // USE_TFLM_COMPRESSION
+
   // TODO(b/192090531): Remove this once all 8x16 transpose conv models use
   // 64-bit biases.
   int bias_converted_buffer_index;
@@ -268,6 +276,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   data->params.stride_width = params->stride_width;
   data->params.stride_height = params->stride_height;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  data->filter_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(node, kFilterTensor);
+  data->bias_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(node, kBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
@@ -286,6 +305,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node, kFilterTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<OpData*>(node->user_data));
 
@@ -309,9 +339,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           op_params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(
+              micro_context, filter, filter_comp_td, data.filter_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
+                                              data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output),
           tflite::micro::GetTensorShape(nullptr), nullptr);
@@ -321,7 +359,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       int32_t* scratch_buffer = static_cast<int32_t*>(
          context->GetScratchBuffer(context, data.scratch_buffer_index));
 #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
-      if (bias->type == kTfLiteInt32) {
+      if (bias != nullptr && bias->type == kTfLiteInt32) {
         const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
         const RuntimeShape& filter_shape =
             tflite::micro::GetTensorShape(filter);
@@ -343,9 +381,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
        const int output_height = output_shape.Dims(1);
        const int output_width = output_shape.Dims(2);
 
        const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+#ifdef USE_TFLM_COMPRESSION
+        const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
+            micro_context, filter, filter_comp_td, data.filter_scratch_index);
+        const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(
+            micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
        const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
        const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
        int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
 
        const int num_elements = output_shape.FlatSize();
@@ -369,9 +414,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.filter_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
@@ -382,9 +436,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(
+              micro_context, filter, filter_comp_td, data.filter_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
@@ -396,20 +458,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          context->GetScratchBuffer(context, data.scratch_buffer_index));
       // TODO(b/192090531): Remove this once all 8x16 transpose conv models use
       // 64-bit biases.
-      if (bias->type == kTfLiteInt16) {
-        std::int64_t* bias_converted_buffer =
-            static_cast<std::int64_t*>(context->GetScratchBuffer(
-                context, data.bias_converted_buffer_index));
-        for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();
-             i++) {
-          bias_converted_buffer[i] = bias->data.i16[i];
+      if (bias == nullptr || bias->type == kTfLiteInt16) {
+        std::int64_t* bias_converted_buffer = nullptr;
+        if (bias != nullptr) {
+          bias_converted_buffer =
+              static_cast<std::int64_t*>(context->GetScratchBuffer(
+                  context, data.bias_converted_buffer_index));
+          const int16_t* const bias_int16_data =
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int16_t>(
+                  micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
+              static_cast<const int16_t*>(bias->data.data);
+#endif  // USE_TFLM_COMPRESSION
+          for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();
+               i++) {
+            bias_converted_buffer[i] = bias_int16_data[i];
+          }
         }
         reference_integer_ops::TransposeConv(
             data.params, data.per_channel_output_multiplier,
             data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input),
             tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+            tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                 filter_comp_td,
+                                                 data.filter_scratch_index),
+#else   // USE_TFLM_COMPRESSION
             tflite::micro::GetTensorData<int8_t>(filter),
+#endif  // USE_TFLM_COMPRESSION
             tflite::micro::GetTensorShape(bias), bias_converted_buffer,
             tflite::micro::GetTensorShape(output),
             tflite::micro::GetTensorData<int16_t>(output),
@@ -438,9 +516,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
        const int output_width = output_shape.Dims(2);
 
        const int16_t* input_data = tflite::micro::GetTensorData<int16_t>(input);
+#ifdef USE_TFLM_COMPRESSION
+        const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
+            micro_context, filter, filter_comp_td, data.filter_scratch_index);
+        const int64_t* bias_data = tflite::micro::GetTensorData<int64_t>(
+            micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
        const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
        const int64_t* bias_data = tflite::micro::GetTensorData<int64_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
        int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
 
        const int num_elements = output_shape.FlatSize();
@@ -457,15 +542,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
             data.per_channel_output_shift, data.per_channel_output_multiplier,
             scratch_buffer);
       }
-#else  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+#else   // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
       reference_integer_ops::TransposeConv(
           data.params, data.per_channel_output_multiplier,
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int16_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.filter_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int64_t>(bias),
+          tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int16_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
diff --git a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc
index a087b236cc9..62099308c7e 100644
--- a/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc
+++ b/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc
@@ -31,7 +31,7 @@ char GetOrdinalCharacter(int i) {
   } else if (i < 62) {
     return 'A' + (i - 36);
   }
-  return '*';
+  return GetOrdinalCharacter(i % 62);
 }
 
 }  // namespace
@@ -335,9 +335,14 @@ void GreedyMemoryPlanner::PrintMemoryPlan() {
   CalculateOffsetsIfNeeded();
 
   for (int i = 0; i < buffer_count_; ++i) {
-    MicroPrintf("%c (id=%d): size=%d, offset=%d, first_used=%d last_used=%d",
-                GetOrdinalCharacter(i), i, requirements_[i].size,
-                buffer_offsets_[i], requirements_[i].first_time_used,
+    char c = '*';
+    if (requirements_[i].first_time_used != requirements_[i].last_time_used) {
+      // not a scratch buffer nor subgraph output tensor
+      c = GetOrdinalCharacter(i);
+    }
+    MicroPrintf("%c (id=%d): size=%d, offset=%d, first_used=%d last_used=%d", c,
+                i, requirements_[i].size, buffer_offsets_[i],
+                requirements_[i].first_time_used,
                 requirements_[i].last_time_used);
   }
 
@@ -379,7 +384,12 @@ void GreedyMemoryPlanner::PrintMemoryPlan() {
     const int line_end = ((offset + size) * kLineWidth) / max_size;
     for (int n = line_start; n < line_end; ++n) {
       if (line[n] == '.') {
-        line[n] = GetOrdinalCharacter(i);
+        if (requirements->first_time_used == requirements->last_time_used) {
+          // scratch buffer or subgraph output tensor
+          line[n] = '*';
+        } else {
+          line[n] = GetOrdinalCharacter(i);
+        }
       } else {
         line[n] = '!';
       }
@@ -387,7 +397,7 @@ void GreedyMemoryPlanner::PrintMemoryPlan() {
     }
     line[kLineWidth] = 0;
 
-    MicroPrintf("%s%d: %s (%dk)", t < 10 ? " " : "", t, (const char*)line,
+    MicroPrintf("%4d: %s (%dk)", t, (const char*)line,
                 (memory_use + 1023) / 1024);
   }
 }
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index f5f6e38e003..3ec00a6b614 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -44,8 +44,6 @@ TFLMRegistration* Register_DETECTION_POSTPROCESS();
 template <unsigned int tOpCount>
 class MicroMutableOpResolver : public MicroOpResolver {
  public:
-  TF_LITE_REMOVE_VIRTUAL_DELETE
-
   explicit MicroMutableOpResolver() {}
 
   const TFLMRegistration* FindOp(tflite::BuiltinOperator op) const override {
@@ -704,6 +702,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
   BuiltinOperator builtin_codes_[tOpCount];
   TfLiteBridgeBuiltinParseFunction builtin_parsers_[tOpCount];
   unsigned int num_buitin_ops_ = 0;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_profiler.cc b/tensorflow/lite/micro/micro_profiler.cc
index ebead51a90d..e349bf73668 100644
--- a/tensorflow/lite/micro/micro_profiler.cc
+++ b/tensorflow/lite/micro/micro_profiler.cc
@@ -86,14 +86,14 @@ void MicroProfiler::LogTicksPerTagCsv() {
     TFLITE_DCHECK(tags_[i] != nullptr);
     int position = FindExistingOrNextPosition(tags_[i]);
     TFLITE_DCHECK(position >= 0);
-    total_ticks_per_tag[position].tag = tags_[i];
-    total_ticks_per_tag[position].ticks =
-        total_ticks_per_tag[position].ticks + ticks;
+    total_ticks_per_tag_[position].tag = tags_[i];
+    total_ticks_per_tag_[position].ticks =
+        total_ticks_per_tag_[position].ticks + ticks;
     total_ticks += ticks;
   }
 
   for (int i = 0; i < num_events_; ++i) {
-    TicksPerTag each_tag_entry = total_ticks_per_tag[i];
+    TicksPerTag each_tag_entry = total_ticks_per_tag_[i];
     if (each_tag_entry.tag == nullptr) {
       break;
     }
@@ -112,7 +112,7 @@ void MicroProfiler::LogTicksPerTagCsv() {
 int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) {
   int pos = 0;
   for (; pos < num_events_; pos++) {
-    TicksPerTag each_tag_entry = total_ticks_per_tag[pos];
+    TicksPerTag each_tag_entry = total_ticks_per_tag_[pos];
     if (each_tag_entry.tag == nullptr ||
         strcmp(each_tag_entry.tag, tag_name) == 0) {
       return pos;
@@ -120,4 +120,13 @@ int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) {
   }
   return pos < num_events_ ? pos : -1;
 }
+
+void MicroProfiler::ClearEvents() {
+  for (int i = 0; i < num_events_; i++) {
+    total_ticks_per_tag_[i].tag = nullptr;
+  }
+
+  num_events_ = 0;
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_profiler.h b/tensorflow/lite/micro/micro_profiler.h
index b52ebcb4ea9..fd8bc42ffd4 100644
--- a/tensorflow/lite/micro/micro_profiler.h
+++ b/tensorflow/lite/micro/micro_profiler.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@ class MicroProfiler : public MicroProfilerInterface {
   virtual void EndEvent(uint32_t event_handle) override;
 
   // Clears all the events that have been currently profiled.
-  void ClearEvents() { num_events_ = 0; }
+  void ClearEvents();
 
   // Returns the sum of the ticks taken across all the events. This number
   // is only meaningful if all of the events are disjoint (the end time of
@@ -83,7 +83,7 @@ class MicroProfiler : public MicroProfilerInterface {
   // In practice, the number of tags will be much lower than the number of
   // events. But it is theoretically possible that each event to be unique and
   // hence we allow total_ticks_per_tag to have kMaxEvents entries.
-  TicksPerTag total_ticks_per_tag[kMaxEvents] = {};
+  TicksPerTag total_ticks_per_tag_[kMaxEvents] = {};
 
   int FindExistingOrNextPosition(const char* tag_name);
 
diff --git a/tensorflow/lite/micro/micro_resource_variable.cc b/tensorflow/lite/micro/micro_resource_variable.cc
index 767e7d17d6f..843aac664bc 100644
--- a/tensorflow/lite/micro/micro_resource_variable.cc
+++ b/tensorflow/lite/micro/micro_resource_variable.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -113,8 +113,8 @@ TfLiteStatus MicroResourceVariables::Allocate(int id, TfLiteContext* context,
   return kTfLiteOk;
 }
 
-TfLiteStatus MicroResourceVariables::Assign(int id,
-                                            const TfLiteEvalTensor* tensor) {
+TfLiteStatus MicroResourceVariables::Assign(int id, size_t count_bytes,
+                                            const void* input_buffer) {
   if (id < 0 || id >= num_resource_variables_) {
     MicroPrintf("Attempting to read non-existent resource variable %d", id);
     return kTfLiteError;
@@ -128,8 +128,9 @@ TfLiteStatus MicroResourceVariables::Assign(int id,
                 "with a TfLiteTensor first.");
     return kTfLiteError;
   }
-  TFLITE_DCHECK(EvalTensorBytes(tensor) == variable.bytes);
-  memcpy(variable.resource_buffer, tensor->data.raw, variable.bytes);
+  TFLITE_DCHECK(count_bytes == variable.bytes);
+  TFLITE_DCHECK(input_buffer != nullptr);
+  memcpy(variable.resource_buffer, input_buffer, variable.bytes);
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/micro/micro_resource_variable.h b/tensorflow/lite/micro/micro_resource_variable.h
index fb9917d4784..57da6497b3a 100644
--- a/tensorflow/lite/micro/micro_resource_variable.h
+++ b/tensorflow/lite/micro/micro_resource_variable.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -46,10 +46,10 @@ class MicroResourceVariables {
   TfLiteStatus Allocate(int id, TfLiteContext* context,
                         const TfLiteTensor* tensor);
 
-  // Copies input tensor contents to the resource buffer.
+  // Copies input_buffer contents to the resource buffer.
   // AllocateResourceVariable with a TFLite tensor must have been called first
   // in order to allocate the resource buffer.
-  TfLiteStatus Assign(int id, const TfLiteEvalTensor* tensor);
+  TfLiteStatus Assign(int id, size_t count_bytes, const void* input_buffer);
 
   // Zeros out all resource buffers.
   TfLiteStatus ResetAll();
diff --git a/tensorflow/lite/micro/micro_resource_variable_test.cc b/tensorflow/lite/micro/micro_resource_variable_test.cc
index 13868bb440d..a30718cb994 100644
--- a/tensorflow/lite/micro/micro_resource_variable_test.cc
+++ b/tensorflow/lite/micro/micro_resource_variable_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_resource_variable.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -120,7 +121,9 @@ TF_LITE_MICRO_TEST(VerifyAssignAndReadResourceBuffer) { .type = kTfLiteFloat32, }; - resource_variables->Assign(id, &assign_tensor); + resource_variables->Assign( + id, tflite::EvalTensorBytes(&assign_tensor), + tflite::micro::GetTensorData(&assign_tensor)); int32_t buffer[32]; TfLiteEvalTensor read_tensor = { diff --git a/tensorflow/lite/micro/micro_utils.h b/tensorflow/lite/micro/micro_utils.h index 98ef81dc8ed..b362d3402bb 100644 --- a/tensorflow/lite/micro/micro_utils.h +++ b/tensorflow/lite/micro/micro_utils.h @@ -90,12 +90,19 @@ void SymmetricQuantize(const float* input, T* output, int num_elements, template void SymmetricPerChannelQuantize(const float* input, T* output, int num_elements, int num_channels, - float* scales) { + float* scales, + size_t quantized_dimension = 0) { int elements_per_channel = num_elements / num_channels; for (int i = 0; i < num_channels; i++) { for (int j = 0; j < elements_per_channel; j++) { - output[i * elements_per_channel + j] = FloatToSymmetricQuantizedType( - input[i * elements_per_channel + j], scales[i]); + size_t offset; + if (quantized_dimension == 0) { + offset = i * elements_per_channel + j; + } else { + offset = i + elements_per_channel * j; + } + output[offset] = + FloatToSymmetricQuantizedType(input[offset], scales[i]); } } } diff --git a/tensorflow/lite/micro/recording_micro_allocator.cc b/tensorflow/lite/micro/recording_micro_allocator.cc index ee76196d255..18addaee5f7 100644 --- a/tensorflow/lite/micro/recording_micro_allocator.cc +++ b/tensorflow/lite/micro/recording_micro_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -78,14 +78,15 @@ RecordedAllocation RecordingMicroAllocator::GetRecordedAllocation( return recorded_node_and_registration_array_data_; case RecordedAllocationType::kOpData: return recorded_op_data_; - // the function MicroPrintf was never reached outside the switch, because - // each case has a return. 
diff --git a/tensorflow/lite/micro/recording_micro_allocator.cc b/tensorflow/lite/micro/recording_micro_allocator.cc
index ee76196d255..18addaee5f7 100644
--- a/tensorflow/lite/micro/recording_micro_allocator.cc
+++ b/tensorflow/lite/micro/recording_micro_allocator.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -78,14 +78,15 @@ RecordedAllocation RecordingMicroAllocator::GetRecordedAllocation(
       return recorded_node_and_registration_array_data_;
     case RecordedAllocationType::kOpData:
       return recorded_op_data_;
-      // the function MicroPrintf was never reached outside the switch, because
-      // each case has a return. As the intention of the MicroPrintf is to be
-      // called when no matching case is found, a default case was added to
-      // contemplate an invalid allocation type
+#ifdef USE_TFLM_COMPRESSION
+    case RecordedAllocationType::kCompressionData:
+      return recorded_compression_data_;
+#endif  // USE_TFLM_COMPRESSION
     default:
-      MicroPrintf("Invalid allocation type supplied: %d", allocation_type);
-      return RecordedAllocation();
+      break;
   }
+  MicroPrintf("Invalid allocation type supplied: %d", allocation_type);
+  return RecordedAllocation();
 }
 const RecordingSingleArenaBufferAllocator*
@@ -117,6 +118,13 @@ void RecordingMicroAllocator::PrintAllocations() const {
                           "NodeAndRegistration structs");
   PrintRecordedAllocation(RecordedAllocationType::kOpData,
                           "Operator runtime data", "OpData structs");
+
+#ifdef USE_TFLM_COMPRESSION
+
+  PrintRecordedAllocation(RecordedAllocationType::kCompressionData,
+                          "Persistent compression data", "allocations");
+
+#endif  // USE_TFLM_COMPRESSION
 }
 void* RecordingMicroAllocator::AllocatePersistentBuffer(size_t bytes) {
@@ -233,6 +241,21 @@ TfLiteStatus RecordingMicroAllocator::PopulateTfLiteTensorFromFlatbuffer(
   return status;
 }
+#ifdef USE_TFLM_COMPRESSION
+
+TfLiteStatus RecordingMicroAllocator::AllocateCompressedTensorsList(
+    const Model* model, SubgraphAllocations* subgraph_allocations) {
+  RecordedAllocation allocations = SnapshotAllocationUsage();
+
+  TfLiteStatus status = MicroAllocator::AllocateCompressedTensorsList(
+      model, subgraph_allocations);
+
+  RecordAllocationUsage(allocations, recorded_compression_data_);
+  return status;
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 RecordedAllocation RecordingMicroAllocator::SnapshotAllocationUsage() const {
   return {/*requested_bytes=*/recording_memory_allocator_->GetRequestedBytes(),
           /*used_bytes=*/recording_memory_allocator_->GetUsedBytes(),
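Every recording override in this file follows the same snapshot/delegate/record shape, with the delta attributed to one RecordedAllocation bucket; the new AllocateCompressedTensorsList simply adds a kCompressionData bucket. The pattern in isolation (my sketch; 'SomeStep' and 'some_bucket_' are hypothetical stand-ins, the helpers are the real ones from this file):

    TfLiteStatus RecordingMicroAllocator::SomeStep(/* args */) {
      RecordedAllocation before = SnapshotAllocationUsage();  // counters now
      TfLiteStatus status = MicroAllocator::SomeStep(/* args */);  // real work
      RecordAllocationUsage(before, some_bucket_);  // bucket += (now - before)
      return status;
    }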
diff --git a/tensorflow/lite/micro/recording_micro_allocator.h b/tensorflow/lite/micro/recording_micro_allocator.h
index b6f69264dc0..80f163240d3 100644
--- a/tensorflow/lite/micro/recording_micro_allocator.h
+++ b/tensorflow/lite/micro/recording_micro_allocator.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -33,6 +33,11 @@ enum class RecordedAllocationType {
   kTfLiteTensorVariableBufferData,
   kNodeAndRegistrationArray,
   kOpData,
+#ifdef USE_TFLM_COMPRESSION
+  kCompressionData,
+#endif  // USE_TFLM_COMPRESSION
+
+  kNumAllocationTypes,  // must be last
 };
 // Container for holding information about allocation recordings by a given
@@ -93,6 +98,13 @@ class RecordingMicroAllocator : public MicroAllocator {
                                                  int subgraph_index,
                                                  bool allocate_temp) override;
+
+#ifdef USE_TFLM_COMPRESSION
+
+  TfLiteStatus AllocateCompressedTensorsList(
+      const Model* model, SubgraphAllocations* subgraph_allocations) override;
+
+#endif  // USE_TFLM_COMPRESSION
+
  private:
   RecordingMicroAllocator(RecordingSingleArenaBufferAllocator* memory_allocator,
                           MicroMemoryPlanner* memory_planner);
@@ -113,6 +125,9 @@ class RecordingMicroAllocator : public MicroAllocator {
   RecordedAllocation recorded_persistent_buffer_data_ = {};
   RecordedAllocation recorded_tflite_tensor_variable_buffer_data_ = {};
   RecordedAllocation recorded_node_and_registration_array_data_ = {};
+#ifdef USE_TFLM_COMPRESSION
+  RecordedAllocation recorded_compression_data_ = {};
+#endif  // USE_TFLM_COMPRESSION
   // TODO(b/187993291): Re-enable OpData allocating tracking.
   RecordedAllocation recorded_op_data_ = {};
diff --git a/tensorflow/lite/micro/recording_micro_allocator_test.cc b/tensorflow/lite/micro/recording_micro_allocator_test.cc
index 9d3a5965de4..121a74c3324 100644
--- a/tensorflow/lite/micro/recording_micro_allocator_test.cc
+++ b/tensorflow/lite/micro/recording_micro_allocator_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -317,6 +317,72 @@ TF_LITE_MICRO_TEST(TestMultiSubgraphModel) {
                           num_tensors * TF_LITE_EVAL_TENSOR_STRUCT_SIZE);
 }
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(TestCompressedModel) {
+  tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr;
+  tflite::testing::TestingOpResolver ops_resolver;
+  const tflite::Model* model = tflite::testing::GetSimpleMockModelCompressed();
+  const int arena_size = 2048;
+
+  uint8_t arena[arena_size];
+
+  tflite::RecordingMicroAllocator* micro_allocator =
+      tflite::RecordingMicroAllocator::Create(arena, arena_size);
+  TF_LITE_MICRO_EXPECT(micro_allocator != nullptr);
+  TF_LITE_MICRO_CHECK_FAIL();
+
+  tflite::SubgraphAllocations* subgraph_allocations =
+      micro_allocator->StartModelAllocation(model);
+  TF_LITE_MICRO_EXPECT(nullptr != subgraph_allocations);
+  TF_LITE_MICRO_CHECK_FAIL();
+
+  TfLiteStatus status = micro_allocator->FinishModelAllocation(
+      model, subgraph_allocations, &scratch_buffer_handles);
+  TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk);
+  TF_LITE_MICRO_CHECK_FAIL();
+
+  micro_allocator->PrintAllocations();
+
+  size_t count_compression_allocations = 0;
+  size_t size_compression_allocations = 0;
+  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
+       subgraph_idx++) {
+    const tflite::CompressionTensorData** ctl =
+        subgraph_allocations[subgraph_idx].compressed.tensors;
+    if (ctl == nullptr) {
+      continue;
+    }
+    const tflite::SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx);
+    const size_t num_tensors = subgraph->tensors()->size();
+    for (size_t i = 0; i < num_tensors; i++) {
+      if (ctl[i] != nullptr) {
+        count_compression_allocations++;
+        size_compression_allocations += sizeof(tflite::CompressionTensorData);
+        count_compression_allocations++;
+        size_compression_allocations += sizeof(tflite::LookupTableData);
+      }
+    }
+    // Add the CompressionTensorData array
+    count_compression_allocations++;
+    size_compression_allocations +=
+        num_tensors * sizeof(tflite::CompressionTensorData*);
+  }
+
+  tflite::RecordedAllocation recorded_allocation =
+      micro_allocator->GetRecordedAllocation(
+          tflite::RecordedAllocationType::kCompressionData);
+
+  TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.count,
+                          count_compression_allocations);
+  TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.requested_bytes,
+                          size_compression_allocations);
+  TF_LITE_MICRO_EXPECT_GE(recorded_allocation.used_bytes,
+                          size_compression_allocations);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 // TODO(b/158124094): Find a way to audit OpData allocations on
 // cross-architectures.
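The expected figures mirror how the allocator records compression data: two persistent allocations per compressed tensor (the CompressionTensorData plus its LookupTableData) and one pointer array per subgraph. For instance (my arithmetic, for a hypothetical single subgraph with four tensors, one of them compressed):

    constexpr size_t kTensors = 4, kCompressed = 1;
    // 2 allocations per compressed tensor + 1 pointer array = 3
    const size_t count = kCompressed * 2 + 1;
    const size_t bytes = kCompressed * (sizeof(tflite::CompressionTensorData) +
                                        sizeof(tflite::LookupTableData)) +
                         kTensors * sizeof(tflite::CompressionTensorData*);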
diff --git a/tensorflow/lite/micro/tools/benchmarking/Makefile.inc b/tensorflow/lite/micro/tools/benchmarking/Makefile.inc
index 396e7016384..a79420cb982 100644
--- a/tensorflow/lite/micro/tools/benchmarking/Makefile.inc
+++ b/tensorflow/lite/micro/tools/benchmarking/Makefile.inc
@@ -20,6 +20,15 @@ endif
     $(GENERATED_SRCS_DIR)$(GENERIC_BENCHMARK_MODEL_DIR)$(GENERIC_BENCHMARK_MODEL_NAME)_model_data.h
 endif
+ifeq ($(ENABLE_COMPRESSION), yes)
+ifneq ($(GENERIC_BENCHMARK_ALT_MEM_ATTR),)
+  CXXFLAGS += -DGENERIC_BENCHMARK_ALT_MEM_ATTR=$(GENERIC_BENCHMARK_ALT_MEM_ATTR)
+endif
+ifneq ($(GENERIC_BENCHMARK_ALT_MEM_SIZE),)
+  CXXFLAGS += -DGENERIC_BENCHMARK_ALT_MEM_SIZE=$(GENERIC_BENCHMARK_ALT_MEM_SIZE)
+endif
+endif
+
 GENERIC_BENCHMARK_SRCS := \
 $(MICROLITE_BENCHMARK_ROOT_DIR)/generic_model_benchmark.cc \
 $(MICROLITE_BENCHMARK_ROOT_DIR)/metrics.cc \
diff --git a/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh b/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh
index c60bdf3ed72..424a1b8da65 100755
--- a/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh
+++ b/tensorflow/lite/micro/tools/benchmarking/collect_meta_data.sh
@@ -52,7 +52,7 @@ function substitute_strings() {
   IFS=${SAVED_IFS}
   replacement=()
   for line in "${lines_array[@]}"; do
-    line=$(sed -e 's/"/\\"/g' <<< "${line}")
+    line=$(sed -e 's/\\/\\\\/g' -e 's/"/\\"/g' <<< "${line}")
     line=$(printf '"%s",\n ' "${line}")
     replacement+=( "${line}" )
   done
diff --git a/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc b/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc
index f398963a00d..9af661fb3b8 100644
--- a/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc
+++ b/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
+#include <initializer_list>
 #include
 #include
 #include
@@ -56,19 +57,37 @@ limitations under the License.
 #endif  // defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL)
+#if defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && \
+    !defined(GENERIC_BENCHMARK_ALT_MEM_SIZE)
+#error "GENERIC_BENCHMARK_ALT_MEM_SIZE missing from CXXFLAGS"
+#endif  // defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) &&
+        // !defined(GENERIC_BENCHMARK_ALT_MEM_SIZE)
+
+#if defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && \
+    !defined(GENERIC_BENCHMARK_ALT_MEM_ATTR)
+#error "GENERIC_BENCHMARK_ALT_MEM_ATTR missing from CXXFLAGS"
+#endif  // defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) &&
+        // !defined(GENERIC_BENCHMARK_ALT_MEM_ATTR)
+
+#if defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && \
+    defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && defined(USE_TFLM_COMPRESSION)
+#define USE_ALT_DECOMPRESSION_MEM
+#endif  // defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) &&
+        // defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) &&
+        // defined(USE_TFLM_COMPRESSION)
+
 /*
- * Generic model benchmark. Evaluates runtime performance of a provided model
- * with random inputs.
+ * Generic model benchmark. Evaluates runtime performance of a provided
+ * model with random inputs.
  */
 namespace tflite {
-
 namespace {
 using Profiler = ::tflite::MicroProfiler;
-// Seed used for the random input. Input data shouldn't affect invocation timing
-// so randomness isn't really needed.
+// Seed used for the random input. Input data shouldn't affect invocation
+// timing so randomness isn't really needed.
 constexpr uint32_t kRandomSeed = 0xFB;
 #if !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL)
@@ -80,6 +99,11 @@ constexpr size_t kTensorArenaSize = GENERIC_BENCHMARK_TENSOR_ARENA_SIZE;
 constexpr size_t kTensorArenaSize = 5e6 - MODEL_SIZE;
 #endif  // !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL)
+#if defined(USE_ALT_DECOMPRESSION_MEM)
+constexpr size_t kAltMemorySize = GENERIC_BENCHMARK_ALT_MEM_SIZE;
+alignas(16) GENERIC_BENCHMARK_ALT_MEM_ATTR uint8_t g_alt_memory[kAltMemorySize];
+#endif  // defined(USE_ALT_DECOMPRESSION_MEM)
+
 constexpr int kNumResourceVariable = 100;
 void SetRandomInput(const uint32_t random_seed,
@@ -130,39 +154,145 @@ bool ReadFile(const char* file_name, void* buffer, size_t buffer_size) {
 }
 #endif  // !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL)
+uint32_t crctab[256];
+
+void GenCRC32Table() {
+  constexpr uint32_t kPolyN = 0xEDB88320;
+  for (size_t index = 0; index < 256; index++) {
+    crctab[index] = index;
+    for (int i = 0; i < 8; i++) {
+      if (crctab[index] & 1) {
+        crctab[index] = (crctab[index] >> 1) ^ kPolyN;
+      } else {
+        crctab[index] >>= 1;
+      }
+    }
+  }
+}
+
+uint32_t ComputeCRC32(const uint8_t* data, const size_t data_length) {
+  uint32_t crc32 = ~0U;
+
+  for (size_t i = 0; i < data_length; i++) {
+    // crctab is an array of 256 32-bit constants
+    const uint32_t index = (crc32 ^ data[i]) & 0xFF;
+    crc32 = (crc32 >> 8) ^ crctab[index];
+  }
+
+  // invert all bits of result
+  crc32 ^= ~0U;
+  return crc32;
+}
+
+void ShowOutputCRC32(tflite::MicroInterpreter* interpreter) {
+  GenCRC32Table();
+  for (size_t i = 0; i < interpreter->outputs_size(); ++i) {
+    TfLiteTensor* output = interpreter->output_tensor(i);
+    uint8_t* output_values = tflite::GetTensorData<uint8_t>(output);
+    uint32_t crc32_value = ComputeCRC32(output_values, output->bytes);
+    MicroPrintf("Output CRC32: 0x%X", crc32_value);
+  }
+}
+
+void ShowInputCRC32(tflite::MicroInterpreter* interpreter) {
+  GenCRC32Table();
+  for (size_t i = 0; i < interpreter->inputs_size(); ++i) {
+    TfLiteTensor* input = interpreter->input_tensor(i);
+    uint8_t* input_values = tflite::GetTensorData<uint8_t>(input);
+    uint32_t crc32_value = ComputeCRC32(input_values, input->bytes);
+    MicroPrintf("Input CRC32: 0x%X", crc32_value);
+  }
+}
+
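GenCRC32Table()/ComputeCRC32() implement the standard reflected CRC-32 (polynomial 0xEDB88320, initial value and final XOR of ~0U), so the logged values can be cross-checked with any common CRC-32 tool. A quick self-check sketch (my illustration; 0xCBF43926 is the well-known CRC-32 check value for the ASCII bytes "123456789"):

    GenCRC32Table();
    const uint8_t check[] = {'1', '2', '3', '4', '5', '6', '7', '8', '9'};
    // A correct CRC-32 implementation prints 0xCBF43926 here.
    MicroPrintf("CRC32 check: 0x%X", ComputeCRC32(check, sizeof(check)));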
 int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) {
-  Profiler profiler;
+  static Profiler profiler;
+  static Profiler profiler2;
+  TfLiteStatus status;
+
+// use this to keep the application size stable regardless of whether
+// compression is being used
+#ifdef USE_TFLM_COMPRESSION
+  constexpr bool using_compression = true;
+#else   // USE_TFLM_COMPRESSION
+  constexpr bool using_compression = false;
+#endif  // USE_TFLM_COMPRESSION
+
   alignas(16) static uint8_t tensor_arena[kTensorArenaSize];
-  uint32_t event_handle = profiler.BeginEvent("TfliteGetModel");
+#ifdef USE_ALT_DECOMPRESSION_MEM
+  std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
+      alt_memory_region = {{g_alt_memory, kAltMemorySize}};
+#endif  // USE_ALT_DECOMPRESSION_MEM
+
+  uint32_t event_handle = profiler.BeginEvent("tflite::GetModel");
   const tflite::Model* model = tflite::GetModel(model_data);
   profiler.EndEvent(event_handle);
+  event_handle = profiler.BeginEvent("tflite::CreateOpResolver");
   TflmOpResolver op_resolver;
-  TF_LITE_ENSURE_STATUS(CreateOpResolver(op_resolver));
+  status = CreateOpResolver(op_resolver);
+  if (status != kTfLiteOk) {
+    MicroPrintf("tflite::CreateOpResolver failed");
+    return -1;
+  }
+  profiler.EndEvent(event_handle);
+  event_handle = profiler.BeginEvent("tflite::RecordingMicroAllocator::Create");
   tflite::RecordingMicroAllocator* allocator(
       tflite::RecordingMicroAllocator::Create(tensor_arena, kTensorArenaSize));
+  profiler.EndEvent(event_handle);
+  event_handle = profiler.BeginEvent("tflite::MicroInterpreter instantiation");
   tflite::RecordingMicroInterpreter interpreter(
       model, op_resolver, allocator,
       tflite::MicroResourceVariables::Create(allocator, kNumResourceVariable),
      &profiler);
-  TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
+  profiler.EndEvent(event_handle);
+
+#ifdef USE_ALT_DECOMPRESSION_MEM
+  event_handle =
+      profiler.BeginEvent("tflite::MicroInterpreter::SetDecompressionMemory");
+  status = interpreter.SetDecompressionMemory(alt_memory_region);
+  if (status != kTfLiteOk) {
+    MicroPrintf("tflite::MicroInterpreter::SetDecompressionMemory failed");
+    return -1;
+  }
+  profiler.EndEvent(event_handle);
+#endif  // USE_ALT_DECOMPRESSION_MEM
+
+  event_handle =
+      profiler.BeginEvent("tflite::MicroInterpreter::AllocateTensors");
+  status = interpreter.AllocateTensors();
+  if (status != kTfLiteOk) {
+    MicroPrintf("tflite::MicroInterpreter::AllocateTensors failed");
+    return -1;
+  }
+  profiler.EndEvent(event_handle);
-  profiler.Log();
+  profiler.LogTicksPerTagCsv();
   profiler.ClearEvents();
+  if (using_compression) {
+    status = interpreter.SetAlternateProfiler(&profiler2);
+    if (status != kTfLiteOk) {
+      MicroPrintf("tflite::MicroInterpreter::SetAlternateProfiler failed");
+      return -1;
+    }
+  }
+
   MicroPrintf("");  // null MicroPrintf serves as a newline.
-  // For streaming models, the interpreter will return kTfLiteAbort if the model
-  // does not yet have enough data to make an inference. As such, we need to
-  // invoke the interpreter multiple times until we either receive an error or
-  // kTfLiteOk. This loop also works for non-streaming models, as they'll just
-  // return kTfLiteOk after the first invocation.
+  // For streaming models, the interpreter will return kTfLiteAbort if the
+  // model does not yet have enough data to make an inference. As such, we
+  // need to invoke the interpreter multiple times until we either receive an
+  // error or kTfLiteOk. This loop also works for non-streaming models, as
+  // they'll just return kTfLiteOk after the first invocation.
   uint32_t seed = kRandomSeed;
   while (true) {
     SetRandomInput(seed++, interpreter);
-    TfLiteStatus status = interpreter.Invoke();
+    ShowInputCRC32(&interpreter);
+    MicroPrintf("");  // null MicroPrintf serves as a newline.
+
+    status = interpreter.Invoke();
     if ((status != kTfLiteOk) && (static_cast<int>(status) != kTfLiteAbort)) {
       MicroPrintf("Model interpreter invocation failed: %d\n", status);
       return -1;
@@ -174,6 +304,17 @@ int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) {
   MicroPrintf("");  // null MicroPrintf serves as a newline.
   profiler.ClearEvents();
+  if (using_compression) {
+    profiler2.Log();
+    MicroPrintf("");  // null MicroPrintf serves as a newline.
+    profiler2.LogTicksPerTagCsv();
+    MicroPrintf("");  // null MicroPrintf serves as a newline.
+    profiler2.ClearEvents();
+  }
+
+  ShowOutputCRC32(&interpreter);
+  MicroPrintf("");  // null MicroPrintf serves as a newline.
+
   if (status == kTfLiteOk) {
     break;
   }
diff --git a/tensorflow/lite/micro/tools/benchmarking/metrics.cc b/tensorflow/lite/micro/tools/benchmarking/metrics.cc
index 3a4bf7e4917..f71a4cd139e 100644
--- a/tensorflow/lite/micro/tools/benchmarking/metrics.cc
+++ b/tensorflow/lite/micro/tools/benchmarking/metrics.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -46,7 +46,8 @@ struct LogAllocationRecord {
 constexpr int kArenaRows = 3;
 constexpr int kArenaColumns = 3;
-constexpr int kAllocationTypes = 7;
+constexpr int kAllocationTypes =
+    static_cast<int>(tflite::RecordedAllocationType::kNumAllocationTypes);
 constexpr int kAllocationColumns = 6;
 constexpr int kMaxBufSize = 100;
@@ -85,16 +86,25 @@ LogAllocationRecord GetLogAllocationRecord(
       tflite::RecordedAllocationType::kPersistentBufferData,
       tflite::RecordedAllocationType::kTfLiteTensorVariableBufferData,
       tflite::RecordedAllocationType::kNodeAndRegistrationArray,
-      tflite::RecordedAllocationType::kOpData};
+      tflite::RecordedAllocationType::kOpData,
+#ifdef USE_TFLM_COMPRESSION
+      tflite::RecordedAllocationType::kCompressionData,
+#endif  // USE_TFLM_COMPRESSION
+  };
   static_assert(std::extent<decltype(types)>::value == kAllocationTypes,
                 "kAllocationTypes mismatch");
-  const char* titles[] = {"Eval tensor data",
-                          "Persistent tensor data",
-                          "Persistent quantization data",
-                          "Persistent buffer data",
-                          "Tensor variable buffer data",
-                          "Node and registration array",
-                          "Operation data"};
+  const char* titles[] = {
+      "Eval tensor data",
+      "Persistent tensor data",
+      "Persistent quantization data",
+      "Persistent buffer data",
+      "Tensor variable buffer data",
+      "Node and registration array",
+      "Operation data",
+#ifdef USE_TFLM_COMPRESSION
+      "Compression data",
+#endif  // USE_TFLM_COMPRESSION
+  };
   static_assert(std::extent<decltype(titles)>::value == kAllocationTypes,
                 "kAllocationTypes mismatch");
   const size_t total_bytes =
diff --git a/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template b/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template
index a2102a48e1c..8ec4e512f7a 100644
--- a/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template
+++ b/tensorflow/lite/micro/tools/benchmarking/show_meta_data.cc.template
@@ -20,6 +20,13 @@ limitations under the License.
 #include "tensorflow/lite/micro/micro_log.h"
 #include "tensorflow/lite/micro/tools/benchmarking/show_meta_data.h"
+#ifndef XTENSA
+#undef HIFI3
+#undef HIFI4
+#undef HIFI5
+#undef VISION_P6
+#endif  // XTENSA
+
 #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
 #include "NatureDSP_Signal_id.h"
 #include "xa_nnlib_standards.h"
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
index 92527adce53..8a3ee7b6716 100644
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
@@ -97,4 +97,10 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa)
   $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc \
   $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc \
   $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
+
+  # override KERNEL_OPTIMIZATION_LEVEL to enable higher performance
+  # Xtensa intrinsics.
+$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.cc
+	@mkdir -p $(dir $@)
+	$(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@
 endif
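Deriving kAllocationTypes from the kNumAllocationTypes sentinel keeps metrics.cc from drifting out of sync with the enum, and the static_asserts pin the parallel arrays to the same length. The pattern in miniature (my illustration, not code from this change):

    #include <type_traits>

    enum class Kind : int { kA, kB, kNumKinds };  // sentinel must stay last
    constexpr int kKinds = static_cast<int>(Kind::kNumKinds);
    const char* kTitles[] = {"A", "B"};
    // Adding a Kind without a matching title fails to compile here.
    static_assert(std::extent<decltype(kTitles)>::value == kKinds,
                  "kTitles out of sync with Kind");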