diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc
index a1d22cfbd8e..6902728043f 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -60,7 +60,7 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
       (input->type == kTfLiteInt8 &&
        (filter->type != kTfLiteInt8 && filter->type != kTfLiteInt4)) ||
       (input->type == kTfLiteInt16 && filter->type != kTfLiteInt8)) {
-    MicroPrintf("Input type: %s with filter type : %s not supported.",
+    MicroPrintf("Input type: %s with filter type: %s not supported.",
                 TfLiteTypeGetName(input->type),
                 TfLiteTypeGetName(filter->type));
     return kTfLiteError;
@@ -79,6 +79,23 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
       context, params->activation, input->type, input, filter, bias, output,
       data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kFullyConnectedWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
   if (bias != nullptr) {
@@ -102,8 +119,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
-  TFLITE_DCHECK(node->user_data != nullptr);
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
+  TFLITE_DCHECK(node->user_data != nullptr);
   const auto& data =
       *(static_cast<const OpDataFullyConnected*>(node->user_data));
@@ -115,9 +143,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              weights_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -152,9 +189,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
                 tflite::micro::GetTensorShape(input),
                 tflite::micro::GetTensorData<int16_t>(input),
                 tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                tflite::micro::GetTensorData<int8_t>(
+                    micro_context, filter, weights_comp_td,
+                    data.weights_scratch_index),
+                tflite::micro::GetTensorShape(bias),
+                tflite::micro::GetOptionalTensorData<int64_t>(
+                    micro_context, bias, bias_comp_td,
+                    data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorData<int8_t>(filter),
                 tflite::micro::GetTensorShape(bias),
                 tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorShape(output),
                 tflite::micro::GetTensorData<int16_t>(output))
           : tflite::reference_integer_ops::FullyConnected(
@@ -162,9 +209,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
                 tflite::micro::GetTensorShape(input),
                 tflite::micro::GetTensorData<int16_t>(input),
                 tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                tflite::micro::GetTensorData<int8_t>(
+                    micro_context, filter, weights_comp_td,
+                    data.weights_scratch_index),
+                tflite::micro::GetTensorShape(bias),
+                tflite::micro::GetOptionalTensorData<int64_t>(
+                    micro_context, bias, bias_comp_td,
+                    data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorData<int8_t>(filter),
                 tflite::micro::GetTensorShape(bias),
                 tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorShape(output),
                 tflite::micro::GetTensorData<int16_t>(output));
       break;
@@ -186,9 +243,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               weights_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h
index 670488ab618..64213f0fb63 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.h
+++ b/tensorflow/lite/micro/kernels/fully_connected.h
@@ -50,6 +50,14 @@ struct OpDataFullyConnected {
   int32_t* per_channel_output_shift;
   bool is_per_channel;
 #endif
+
+#ifdef USE_TFLM_COMPRESSION
+
+  // scratch buffers for compressed tensors
+  int weights_scratch_index;
+  int bias_scratch_index;
+
+#endif  // USE_TFLM_COMPRESSION
 };
 
 extern const int kFullyConnectedInputTensor;
diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc
index 2ad132055b8..1197b105534 100644
--- a/tensorflow/lite/micro/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -42,6 +42,29 @@ const float simple_weights_data[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }; +int simple_bias_dims[] = {1, 3}; +const float simple_bias_data[] = {1, 2, 3}; + +#ifdef USE_TFLM_COMPRESSION + +// compressed filter data for kBinQuant scheme +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantWeightData[] = { + 0x01, 0x23, 0x45, 0x67, 0x89, 0x01, 0x23, 0x45, + 0x67, 0x89, 0x01, 0x23, 0x45, 0x67, 0x89}; +constexpr float kBinQuantWeightValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; +constexpr size_t kBinQuantWeightValueTableElements = + std::extent::value; +constexpr int kBinQuantWeightBitWidth = 4; +// compressed bias data for kBinQuant scheme +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18}; +constexpr int kBinQuantBiasBitWidth = 2; +constexpr size_t simple_bias_size = + std::extent::value; + +#endif // USE_TFLM_COMPRESSION + // TODO(b/258710417): INT4 isn't currently supported on Hexagon. #if !defined(HEXAGON) const float simple_int4_weights_data[] = { @@ -53,8 +76,6 @@ const float simple_golden_null_bias_int4_weights[] = { -28, -28, -28, 0, 0, 0, }; #endif -int simple_bias_dims[] = {1, 3}; -const float simple_bias_data[] = {1, 2, 3}; const float simple_golden[] = { 24, 25, 26, 58, 59, 60, }; @@ -241,11 +262,19 @@ const float representative_64x16_golden[] = { const int representative_64x16_output_size = 16; int representative_64x16_output_dims[] = {2, 1, 16}; -template +constexpr int kMaxTensors = 4; + +template TfLiteStatus ValidateFullyConnectedGoldens( TfLiteTensor* tensors, const int tensors_size, bool null_bias, const TfLiteFusedActivation activation, const float tolerance, - const int output_len, const T* golden, T* output_data) { + const int output_len, const T* golden, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* weight_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteFullyConnectedParams builtin_data = { activation, kTfLiteFullyConnectedWeightsFormatDefault, false, false, kTfLiteNoType}; @@ -272,10 +301,37 @@ TfLiteStatus ValidateFullyConnectedGoldens( TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + + if (weight_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*weight_comp_info, tensors[kFullyConnectedWeightsTensor], + kFullyConnectedWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kFullyConnectedBiasTensor], + kFullyConnectedBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + const TFLMRegistration registration = Register_FULLY_CONNECTED(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, outputs_array, - reinterpret_cast(&builtin_data)); + reinterpret_cast(&builtin_data), nullptr +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); TfLiteStatus status = runner.InitAndPrepare(); if (status != kTfLiteOk) { @@ -297,7 +353,13 @@ TfLiteStatus TestFullyConnectedFloat( int* input_dims_data, const float* input_data, int* weights_dims_data, const float* weights_data, int* bias_dims_data, const 
float* bias_data, const float* golden, int* output_dims_data, - TfLiteFusedActivation activation, float* output_data) { + TfLiteFusedActivation activation, float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* weight_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data); TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); @@ -305,16 +367,15 @@ TfLiteStatus TestFullyConnectedFloat( const int output_dims_count = ElementCount(*output_dims); bool null_bias = bias_data == nullptr ? true : false; - constexpr int array_size = 4; // Avoid variable length array warning. - const int inputs_size = bias_data == nullptr ? 2 : 3; + const int inputs_size = null_bias ? 2 : 3; constexpr int outputs_size = 1; const int tensors_size = inputs_size + outputs_size; - TfLiteTensor tensors[array_size]; + TfLiteTensor tensors[kMaxTensors]; tensors[0] = CreateTensor(input_data, input_dims); tensors[1] = CreateTensor(weights_data, weights_dims); - if (bias_data == nullptr) { + if (null_bias) { tensors[2] = CreateTensor(output_data, output_dims); } else { tensors[2] = CreateTensor(bias_data, bias_dims); @@ -323,7 +384,12 @@ TfLiteStatus TestFullyConnectedFloat( return ValidateFullyConnectedGoldens(tensors, tensors_size, null_bias, activation, 1e-4f, output_dims_count, - golden, output_data); + golden, output_data +#ifdef USE_TFLM_COMPRESSION + , + weight_comp_info, bias_comp_info +#endif // USE_TFLM_COMPRESSION + ); } template @@ -345,7 +411,7 @@ TfLiteStatus TestFullyConnectedQuantized( bool null_bias = bias_data == nullptr ? true : false; constexpr int array_size = 4; // Avoid variable length array warning. - const int inputs_size = bias_data == nullptr ? 2 : 3; + const int inputs_size = null_bias ? 
2 : 3; constexpr int outputs_size = 1; const int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[array_size]; @@ -355,7 +421,7 @@ TfLiteStatus TestFullyConnectedQuantized( tensors[1] = CreateQuantizedTensor( weights_data, weights_quantized, weights_dims, weights_scale, weights_zero_point, false, weights_packed_type); - if (bias_data == nullptr) { + if (null_bias) { tensors[2] = CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point); } else { @@ -373,6 +439,71 @@ TfLiteStatus TestFullyConnectedQuantized( golden_quantized, output_data); } +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestFullyConnectedQuantizedCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteFusedActivation activation, + const TestCompressionQuantizedInfo* weight_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* weight_dims = IntArrayFromInts(weight_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* weight_scales = + FloatArrayFromFloats(weight_comp_info->scales); + TfLiteIntArray* weight_zero_points = + IntArrayFromInts(weight_comp_info->zero_points); + + TfLiteTensor weight_tensor = CreateQuantizedTensor( + weight_comp_info->compressed, weight_dims, weight_scales->data[0], + weight_zero_points->data[0], false, kTfLiteInt8); + SymmetricQuantize(weight_comp_info->data, weight_comp_info->value_table, + weight_comp_info->value_table_stride, + weight_scales->data[0]); + + TfLiteTensor bias_tensor = {}; + if (bias_comp_info != nullptr) { + bias_tensor = CreateQuantizedTensor(bias_comp_info->compressed, bias_dims, + input_scale * weight_scales->data[0], 0, + false, typeToTfLiteType()); + SymmetricQuantize(bias_comp_info->data, bias_comp_info->value_table, + bias_comp_info->value_table_stride, + bias_tensor.params.scale); + } + + TfLiteTensor output_tensor = CreateQuantizedTensor( + output_quantized, output_dims, output_scale, output_zero_point); + + const int tensors_size = + (bias_comp_info == nullptr) ? 
kMaxTensors - 1 : kMaxTensors; + TfLiteTensor tensors[kMaxTensors] = {}; + tensors[0] = CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point); + tensors[1] = weight_tensor; + if (bias_comp_info == nullptr) { + tensors[2] = output_tensor; + } else { + tensors[2] = bias_tensor; + tensors[3] = output_tensor; + } + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateFullyConnectedGoldens( + tensors, tensors_size, bias_comp_info == nullptr, activation, 0.0f, + output_dims_count, expected_output_quantized, output_quantized, + weight_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace } // namespace testing } // namespace tflite @@ -393,6 +524,40 @@ TF_LITE_MICRO_TEST(SimpleTest) { kTfLiteOk); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestCompressed) { + float output_data[tflite::testing::simple_output_size]; + + tflite::testing::TestCompressionInfo weight_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + weight_comp_info.value_table = tflite::testing::kBinQuantWeightValueTable; + weight_comp_info.value_table_stride = + tflite::testing::kBinQuantWeightValueTableElements; + weight_comp_info.bit_width = tflite::testing::kBinQuantWeightBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::simple_bias_data; + bias_comp_info.value_table_stride = tflite::testing::simple_bias_size; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedFloat( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, + tflite::testing::simple_weights_dims, + reinterpret_cast(tflite::testing::kBinQuantWeightData), + tflite::testing::simple_bias_dims, + reinterpret_cast(tflite::testing::kBinQuantBiasData), + tflite::testing::simple_golden, tflite::testing::simple_output_dims, + kTfLiteActNone, output_data, &weight_comp_info, &bias_comp_info), + kTfLiteOk); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestNullBias) { float output_data[tflite::testing::simple_output_size]; TF_LITE_MICRO_EXPECT_EQ( @@ -434,6 +599,58 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) { kTfLiteOk); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Compressed) { + const float input_scale = 1.0f; + const int input_zero_point = -1; + constexpr float weights_scale[] = {1, 1.0f}; + constexpr int weights_zero_point[] = {1, 0}; + const float output_scale = 0.5f; + const int output_zero_point = -1; + + int8_t input_quantized[tflite::testing::simple_input_size]; + int8_t weights_quantized[tflite::testing::kBinQuantWeightValueTableElements]; + int32_t bias_quantized[tflite::testing::simple_output_size]; + int8_t golden_quantized[tflite::testing::simple_output_size]; + int8_t output_data[tflite::testing::simple_output_size]; + + tflite::testing::TestCompressionQuantizedInfo weight_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + weight_comp_info.value_table = weights_quantized; + weight_comp_info.value_table_stride = + tflite::testing::kBinQuantWeightValueTableElements; + weight_comp_info.bit_width = 
tflite::testing::kBinQuantWeightBitWidth; + weight_comp_info.compressed = tflite::testing::kBinQuantWeightData; + weight_comp_info.data = tflite::testing::kBinQuantWeightValueTable; + weight_comp_info.dims_data = tflite::testing::simple_weights_dims; + weight_comp_info.scales = weights_scale; + weight_comp_info.zero_points = weights_zero_point; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = tflite::testing::simple_bias_size; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasData; + bias_comp_info.data = tflite::testing::simple_bias_data; + bias_comp_info.dims_data = tflite::testing::simple_bias_dims; + // bias scales and bias zero_points are not used + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedQuantizedCompressed( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, input_quantized, input_scale, + input_zero_point, tflite::testing::simple_output_dims, + tflite::testing::simple_golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone, &weight_comp_info, + &bias_comp_info), + kTfLiteOk); +} + +#endif // USE_TFLM_COMPRESSION + #if !defined(HEXAGON) TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { const float input_scale = 128.0 / 65536; @@ -443,7 +660,6 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { const float output_scale = 128.0 / 65536; const int output_zero_point = 0; - const float simple_golden[] = {24, 25, 26, 58, 59, 60}; int16_t input_quantized[tflite::testing::simple_input_size]; int8_t weights_quantized[tflite::testing::simple_weights_size]; int64_t bias_quantized[tflite::testing::simple_output_size]; @@ -457,12 +673,66 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) { input_zero_point, tflite::testing::simple_weights_dims, tflite::testing::simple_weights_data, weights_quantized, weights_scale, weights_zero_point, tflite::testing::simple_bias_dims, - tflite::testing::simple_bias_data, bias_quantized, simple_golden, - golden_quantized, tflite::testing::simple_output_dims, output_scale, - output_zero_point, kTfLiteActNone, output_data), + tflite::testing::simple_bias_data, bias_quantized, + tflite::testing::simple_golden, golden_quantized, + tflite::testing::simple_output_dims, output_scale, output_zero_point, + kTfLiteActNone, output_data), kTfLiteOk); } -#endif + +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16Compressed) { + const float input_scale = 128.0 / 65536; + const int input_zero_point = 0; + constexpr float weights_scale[] = {1, 1.0f}; + constexpr int weights_zero_point[] = {1, 0}; + const float output_scale = 128.0 / 65536; + const int output_zero_point = 0; + + int16_t input_quantized[tflite::testing::simple_input_size]; + int8_t weights_quantized[tflite::testing::kBinQuantWeightValueTableElements]; + int64_t bias_quantized[tflite::testing::simple_output_size]; + int16_t golden_quantized[tflite::testing::simple_output_size]; + int16_t output_data[tflite::testing::simple_output_size]; + + tflite::testing::TestCompressionQuantizedInfo weight_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + weight_comp_info.value_table = weights_quantized; + weight_comp_info.value_table_stride = + tflite::testing::kBinQuantWeightValueTableElements; + weight_comp_info.bit_width = 
tflite::testing::kBinQuantWeightBitWidth; + weight_comp_info.compressed = tflite::testing::kBinQuantWeightData; + weight_comp_info.data = tflite::testing::kBinQuantWeightValueTable; + weight_comp_info.dims_data = tflite::testing::simple_weights_dims; + weight_comp_info.scales = weights_scale; + weight_comp_info.zero_points = weights_zero_point; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = tflite::testing::simple_bias_size; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasData; + bias_comp_info.data = tflite::testing::simple_bias_data; + bias_comp_info.dims_data = tflite::testing::simple_bias_dims; + // bias scales and bias zero_points are not used + + TF_LITE_MICRO_EXPECT_EQ( + tflite::testing::TestFullyConnectedQuantizedCompressed( + tflite::testing::simple_input_dims, + tflite::testing::simple_input_data, input_quantized, input_scale, + input_zero_point, tflite::testing::simple_output_dims, + tflite::testing::simple_golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone, &weight_comp_info, + &bias_comp_info), + kTfLiteOk); +} + +#endif // USE_TFLM_COMPRESSION + +#endif // !defined(HEXAGON) TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) { const float input_scale = 1.0f; diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h index 5f8eb2e4ec3..5cb71af7953 100644 --- a/tensorflow/lite/micro/kernels/kernel_util.h +++ b/tensorflow/lite/micro/kernels/kernel_util.h @@ -100,13 +100,13 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) { #ifdef USE_TFLM_COMPRESSION -// Overloads existing GetTensorData. If not compressed, this will return +// Overloads existing GetOptionalTensorData. If not compressed, this will return // tensor->data. template -const T* GetTensorData(MicroContext* micro_context, - const TfLiteEvalTensor* tensor, - const CompressionTensorData* compression_data, - int scratch_buffer_handle) { +const T* GetOptionalTensorData(MicroContext* micro_context, + const TfLiteEvalTensor* tensor, + const CompressionTensorData* compression_data, + int scratch_buffer_handle) { if (tensor == nullptr) { return nullptr; } @@ -128,6 +128,18 @@ const T* GetTensorData(MicroContext* micro_context, return reinterpret_cast(uncompressed_data); } +// Overloads existing GetTensorData. If not compressed, this will return +// tensor->data. +template +const T* GetTensorData(MicroContext* micro_context, + const TfLiteEvalTensor* tensor, + const CompressionTensorData* compression_data, + int scratch_buffer_handle) { + TFLITE_DCHECK(tensor != nullptr); + return GetOptionalTensorData(micro_context, tensor, compression_data, + scratch_buffer_handle); +} + #endif // USE_TFLM_COMPRESSION // Returns the shape of a TfLiteEvalTensor struct. diff --git a/tensorflow/lite/micro/kernels/xtensa/decompress.cc b/tensorflow/lite/micro/kernels/xtensa/decompress.cc new file mode 100644 index 00000000000..13d2ce2dec7 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decompress.cc @@ -0,0 +1,711 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef USE_TFLM_COMPRESSION + +#include "tensorflow/lite/micro/kernels/decompress.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/micro_common.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" +#include "tensorflow/lite/micro/micro_utils.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#endif // HIFI5 + +namespace tflite { +namespace { + +#ifdef HIFI5 + +struct DecompressionStateXtensa : DecompressionState { + DecompressionStateXtensa() = delete; + + DecompressionStateXtensa(const DecompressionState& other) + : DecompressionState(other) {} + + void DecompressToBufferWidth4_Xtensa(int8_t* buffer); + void DecompressToBufferWidth3_Xtensa(int8_t* buffer); + void DecompressToBufferWidth2_Xtensa(int8_t* buffer); + + void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer); + void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer); + void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer); + void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer); +}; + +void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + + value_table += stride; + if (elements_per_channel_t_rem) { + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 1)); /* Loading 48 bits for decoding 16 weight values */ + AE_DSEL8X8(d_out1, d_out2, d_value_0, 
d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void DecompressionStateXtensa::DecompressToBufferWidth3_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL); + ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL); + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 6); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 3); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + 
AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void DecompressionStateXtensa::DecompressToBufferWidth2_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = comp_data_.data.lut_data->value_table_channel_stride; + + int elements_per_channel_t_by_5 = elements_per_channel_ >> 5; + int elements_per_channel_t_rem = elements_per_channel_ & 0x1F; + int elements_per_channel_t_rem_minus_16 = 0; + if (elements_per_channel_t_rem > 16) { + elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16; + } + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d0, d1, d2, d3, d4, d5; + ae_int8x8 q0, q1, q2, q3; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_5; j++) { + // AE_LA8X8_IP( d_index, align_index, pIn_tmp ); /* Loading 64 bits + // for decoding 32 weight values */ + + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 8); /* Loading 64 bits for decoding 32 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... 
+ + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 2)); /* Loading 48 bits for decoding 16 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem_minus_16); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt8_Xtensa( + int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint8_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint8_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + uint32_t index_1, index_2; + uint32_t mask_bits = (1 << compressed_bit_width_) - 1; + + for (int i = 0; i < num_channels_t; i++) { + elements_per_channel_t = elements_per_channel_; + /* if output pointer is not 2 byte aligned */ + if ((unsigned int)p_out_tmp & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + elements_per_channel_t = 
elements_per_channel_t - 1; + } + for (int j = 0; j < (elements_per_channel_t >> 1); j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, 2 * bw); + index_1 = (index >> compressed_bit_width_) & mask_bits; + index_2 = (index)&mask_bits; + ae_int8x8 d_tmp1 = AE_L8_X((const ae_int8*)value_table, index_1); + ae_int8x8 d_tmp2 = AE_L8_X((const ae_int8*)value_table, index_2); + ae_int16x4 d_tmp = + AE_MOVINT16X4_FROMINT8X8(AE_SEL8X8I(d_tmp2, d_tmp1, 21)); + AE_S16_0_IP(d_tmp, (ae_int16*)p_out_tmp, 2); + } + if (elements_per_channel_t & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + } + value_table += stride; + } + } +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt16_Xtensa( + int16_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint16_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int16* __restrict p_out_tmp = (ae_int16*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint16_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + } + + value_table += stride; + } + } +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt32_Xtensa( + int32_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint32_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int32* __restrict p_out_tmp = (ae_int32*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint32_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 
0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + } + + value_table += stride; + } + } +} + +void DecompressionStateXtensa::DecompressToBufferWidthAnyInt64_Xtensa( + int64_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = comp_data_.data.lut_data->value_table_channel_stride; + const uint64_t* __restrict value_table = + static_cast(comp_data_.data.lut_data->value_table); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int64* __restrict p_out_tmp = (ae_int64*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (comp_data_.data.lut_data->use_alternate_axis) { + int count = count_indices_; + const uint64_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + } + + value_table += stride; + } + } +} + +#endif // HIFI5 + +} // namespace + +#ifdef HIFI5 + +template <> +bool* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +int8_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + if (comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + if (!(elements_per_channel_ & 0x01)) { + dsx.DecompressToBufferWidth4_Xtensa(static_cast(buffer)); + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + } else if (comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + if (!(elements_per_channel_ & 0x07)) { + dsx.DecompressToBufferWidth3_Xtensa(static_cast(buffer)); + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + } else if (comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + if (!(elements_per_channel_ & 0x03)) { + dsx.DecompressToBufferWidth2_Xtensa(static_cast(buffer)); + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + } else { + dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast(buffer)); + } + + return static_cast(buffer); +} + +template <> +int16_t* DecompressionState::DecompressToBuffer(void* buffer) { + 
TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt16_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +int32_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +float* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +template <> +int64_t* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + DecompressionStateXtensa dsx(*this); + + dsx.DecompressToBufferWidthAnyInt64_Xtensa(static_cast(buffer)); + + return static_cast(buffer); +} + +#else // HIFI5 + +template +T* DecompressionState::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 4 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth4_16(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 3 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth3_32(static_cast(buffer)); + } else if (std::is_same::value && + comp_data_.data.lut_data->compressed_bit_width == 2 && + !comp_data_.data.lut_data->use_alternate_axis) { + DecompressToBufferWidth2_16(static_cast(buffer)); + } else { + DecompressToBufferWidthAny(static_cast(buffer)); + } + + return static_cast(buffer); +} + +template bool* DecompressionState::DecompressToBuffer(void*); +template float* DecompressionState::DecompressToBuffer(void*); +template int8_t* DecompressionState::DecompressToBuffer(void*); +template int16_t* DecompressionState::DecompressToBuffer(void*); +template int32_t* DecompressionState::DecompressToBuffer(void*); +template int64_t* DecompressionState::DecompressToBuffer(void*); + +#endif // HIFI5 + +} // namespace tflite + +#endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc index df5458001b7..511335a550f 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -45,6 +45,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + TFLITE_DCHECK(node->user_data != nullptr); const auto& data = @@ -58,9 +70,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; @@ -93,9 +114,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData(output)); break; diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc index cf87c5ff1ed..91b9f40c907 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_common_xtensa.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -120,6 +120,23 @@ TfLiteStatus XtensaPrepareFullyConnected(TfLiteContext* context, context, params->activation, input->type, input, filter, bias, output, data)); +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. 
+ if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); if (bias != nullptr) { diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc index f850c0c0fca..1901ea99df6 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,12 +32,37 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8( const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { #if !defined(VISION_P6) + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + const int32_t* bias_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); +#else // USE_TFLM_COMPRESSION tflite::micro::GetOptionalTensorData(bias); +#endif // USE_TFLM_COMPRESSION + + const int8_t* filter_data = +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, data.weights_scratch_index); +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter); +#endif // USE_TFLM_COMPRESSION // P6 Vision will handle INT4 filters as a reference operation. // For all other architectures, unpack INT4 here. - const int8_t* filter_data = tflite::micro::GetTensorData(filter); if (filter->type == kTfLiteInt4) { int8_t* unpacked_filter_data = static_cast( context->GetScratchBuffer(context, data.filter_buffer_index)); @@ -47,6 +72,7 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8( tflite::micro::GetTensorShape(filter).FlatSize(), unpacked_filter_data); filter_data = unpacked_filter_data; } + #endif // !defined(VISION_P6) #if defined(HIFIMINI) diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc index 14bb9a12b15..e81855f3e8c 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -108,14 +108,68 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, filter_int8 = *filter; } +#ifdef USE_TFLM_COMPRESSION + + uint8_t* filter_data = nullptr; + int32_t* bias_data = nullptr; + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + if (filter_comp_td != nullptr) { + const size_t filter_data_size = + NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8); + filter_data = + micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t)); + if (filter_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* filter_eval = tflite::micro::GetEvalInput( + context, node, kFullyConnectedWeightsTensor); + filter_data = static_cast(micro_context->DecompressTensorToBuffer( + *filter_eval, *filter_comp_td, filter_data)); + } else { + filter_data = GetTensorData(&filter_int8); + } + + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + if (bias_comp_td != nullptr) { + TFLITE_DCHECK(bias != nullptr); + const size_t bias_data_size = + NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32); + bias_data = reinterpret_cast( + micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t))); + if (bias_data == nullptr) { + return kTfLiteError; + } + const TfLiteEvalTensor* bias_eval = + tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); + TFLITE_DCHECK(bias_eval != nullptr); + bias_data = static_cast(micro_context->DecompressTensorToBuffer( + *bias_eval, *bias_comp_td, bias_data)); + } else { + bias_data = GetTensorData(bias); + } + + if (filter_data == nullptr || (bias != nullptr && bias_data == nullptr)) { + return kTfLiteError; + } + +#else // USE_TFLM_COMPRESSION + + uint8_t* filter_data = GetTensorData(&filter_int8); + int32_t* bias_data = GetTensorData(bias); + +#endif // USE_TFLM_COMPRESSION + status = xiFullyConnectedSetContext( data->p_context, data->context_size, inputDims, outputDims, filterDims, 1, input->params.zero_point, filter->params.zero_point, output->params.zero_point, data->reference_op_data.output_multiplier, data->reference_op_data.output_shift, data->reference_op_data.output_activation_min, - data->reference_op_data.output_activation_max, - (uint8_t*)GetTensorData(&filter_int8)); + data->reference_op_data.output_activation_max, filter_data); if (status) { return kTfLiteError; @@ -139,9 +193,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, status = xiFullyConnectedDoCoeffReorder( data->p_context, data->context_size, reinterpret_cast(data->reorder_coefficient_bias), - data->reorder_coefficient_bias_size, - const_cast(GetTensorData(&filter_int8)), - const_cast(GetTensorData(bias))); + data->reorder_coefficient_bias_size, filter_data, bias_data); if (status) { return kTfLiteError; } @@ -149,6 +201,18 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, if (filter->type == kTfLiteInt4) { micro_context->DeallocateTempBuffer(GetTensorData(&filter_int8)); } + +#ifdef USE_TFLM_COMPRESSION + + if (filter_comp_td) { + micro_context->DeallocateTempBuffer(filter_data); + } + if (bias_comp_td) { + micro_context->DeallocateTempBuffer(reinterpret_cast(bias_data)); + } + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(output); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(filter); diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc 
b/tensorflow/lite/micro/memory_arena_threshold_test.cc
index 6bc23bc37d0..34c62cda412 100644
--- a/tensorflow/lite/micro/memory_arena_threshold_test.cc
+++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc
@@ -63,7 +63,6 @@ constexpr int kKeywordModelOnlyTotalSize = 14472;
 // TODO(b/207157610): replace magic number that depends on OPs
 constexpr int kKeywordModelOnlyTailSize = 13800;
 constexpr int kKeywordModelPersistentTfLiteTensorDataSize = 128;
-constexpr int kKeywordModelPersistentBufferDataSize = 832;
 #else
 // Total size contributed by the keyword model excluding the
 // RecordingMicroAllocator's overhead.
@@ -74,7 +73,6 @@ constexpr int kKeywordModelOnlyTotalSize = 14936;
 // TODO(b/207157610): replace magic number that depends on OPs
 constexpr int kKeywordModelOnlyTailSize = 14264;
 constexpr int kKeywordModelPersistentTfLiteTensorDataSize = 224;
-constexpr int kKeywordModelPersistentBufferDataSize = 840;
 #endif
 constexpr int kKeywordModelHeadSize = 672;
 constexpr int kKeywordModelTfLiteTensorVariableBufferDataSize = 10240;
@@ -87,6 +85,12 @@ uint8_t test_conv_tensor_arena[kTestConvModelArenaSize];
 constexpr int kTestConvModelTensorCount = 15;
 constexpr int kTestConvModelNodeAndRegistrationCount = 7;
 
+#if defined(USE_TFLM_COMPRESSION)
+constexpr int kKeywordModelPersistentBufferDataSize = 920;
+#else
+constexpr int kKeywordModelPersistentBufferDataSize = 840;
+#endif
+
 // NOTE: These values are measured on x86-64:
 // TODO(b/158651472): Consider auditing these values on non-64 bit systems.
 #ifdef TF_LITE_STATIC_MEMORY
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
index 92527adce53..8a3ee7b6716 100644
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
@@ -97,4 +97,10 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa)
     $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc \
     $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc \
     $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
+
+  # override KERNEL_OPTIMIZATION_LEVEL to enable higher performance
+  # Xtensa intrinsics.
+$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.cc
+	@mkdir -p $(dir $@)
+	$(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@
 endif