diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.cc b/tensorflow/lite/micro/kernels/depthwise_conv.cc
index fa55a705606..489e83f94f2 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -52,6 +52,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
           : nullptr;

+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
       tflite::reference_ops::DepthwiseConv(
@@ -59,9 +71,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              filter_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
@@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int16_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               filter_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int16_t>(output));
       break;
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
index 52804de3315..0813d2b028e 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
-  micro_context->DeallocateTempTfLiteTensor(bias);
+  if (has_bias) {
+    micro_context->DeallocateTempTfLiteTensor(bias);
+  }
   micro_context->DeallocateTempTfLiteTensor(output);

   return kTfLiteOk;
@@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
       context, node, params, input_width, input_height, filter_width,
       filter_height, output_width, output_height, input->type, data));

+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kDepthwiseConvWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
index b50b40ae6d6..adedcaeb04e 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
@@ -1,5 +1,5 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <type_traits>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -32,17 +34,99 @@ constexpr int kOutputTensorIndex = 3;
 constexpr int kMaxFilterChannels = 64;
 constexpr int kMaxBiasChannels = 64;

+#ifdef USE_TFLM_COMPRESSION
+
+constexpr size_t kDepthwiseConvMaxTensors = 4;
+constexpr size_t kDepthwiseConvMaxInputTensors = 3;
+
+// Common inputs and outputs (quantized multi channel).
+// data from TfLite test:
+// PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+static int kInputShapeQ1[] = {4, 1, 2, 3, 2};
+static constexpr float kInputDataQ1[] = {
+    // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
+    3, 2,    // batch = 0, y = 0, x = 0
+    1, -1,   // batch = 0, y = 0, x = 1
+    -2, -3,  // batch = 0, y = 0, x = 2
+    4, 3,    // batch = 0, y = 1, x = 0
+    2, -2,   // batch = 0, y = 1, x = 1
+    -3, -4,  // batch = 0, y = 1, x = 2
+};
+constexpr size_t kInputElementsQ1 = std::extent<decltype(kInputDataQ1)>::value;
+
+constexpr int kNumChannelsQ1 = 4;
+static int kFilterShapeQ1[] = {4, 1, 2, 2, 4};
+static constexpr float kFilterDataQ1[] = {
+    // This is a compact value table. Original data is:
+    // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel]
+    // depth multiplier = 2
+    // 1, 2, 3, 4,   y = 0, x = 0
+    // 3, 4, 5, 6,   y = 0, x = 1
+    // 7, 8, 5, 6,   y = 1, x = 0
+    // 3, 4, 1, 2,   y = 1, x = 1
+    1, 3, 7, 8, 2, 4, 1, 3, 5, 2, 4, 6,
+};
+constexpr size_t kFilterElementsQ1 =
+    std::extent<decltype(kFilterDataQ1)>::value;
+
+static int kBiasShapeQ1[] = {1, 4};
+static constexpr float kBiasDataQ1[] = {3, -2, 4, 6};
+constexpr size_t kBiasElementsQ1 = std::extent<decltype(kBiasDataQ1)>::value;
+
+static int kOutputShapeQ1[] = {4, 1, 1, 2, 4};
+static constexpr float kGoldenDataQ1[] = {43, 48, 21, 22, 3, -4, -30, -36};
+constexpr int kOutputElementsQ1 = std::extent<decltype(kGoldenDataQ1)>::value;
+
+// compressed filter data for kBinQuant scheme, matches kFilterDataQ1
+// Align the tensor data the same as a Buffer in the schema
+alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = {0x15, 0x6A, 0x8A,
+                                                         0x60};
+constexpr int kBinQuantFilterBitWidthQ1 = 2;
+// compressed bias data for kBinQuant scheme, matches kBiasDataQ1
+// Align the tensor data the same as a Buffer in the schema
+alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00};
+constexpr int kBinQuantBiasBitWidthQ1 = 1;
+
+#endif  // USE_TFLM_COMPRESSION
+
 // Creates a DepthwiseConv operator, calls it with the provided input tensors
 // and some default parameters, and compares the output with
 // expected_output_data.
 //
 // The tensors parameter contains both the input tensors as well as a
 // preallocated output tensor into which the output is stored.
-template <typename T>
+template <typename T, typename TF = void, typename TB = void>
 TfLiteStatus ValidateDepthwiseConvGoldens(
     const T* expected_output_data, int output_length,
     TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size,
-    TfLiteTensor* tensors) {
+    TfLiteTensor* tensors
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo<TF>* filter_comp_info = nullptr,
+    const TestCompressionInfo<TB>* bias_comp_info = nullptr
+#endif  // USE_TFLM_COMPRESSION
+) {
+#ifdef USE_TFLM_COMPRESSION
+
+  TestCompressedList<kDepthwiseConvMaxInputTensors> tcl;
+  if (filter_comp_info != nullptr) {
+    TF_LITE_MICRO_EXPECT_EQ(
+        tcl.AddInput(*filter_comp_info, tensors[kDepthwiseConvWeightsTensor],
+                     kDepthwiseConvWeightsTensor),
+        kTfLiteOk);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+  if (bias_comp_info != nullptr) {
+    TF_LITE_MICRO_EXPECT_EQ(
+        tcl.AddInput(*bias_comp_info, tensors[kDepthwiseConvBiasTensor],
+                     kDepthwiseConvBiasTensor),
+        kTfLiteOk);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+  const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList();
+
+#endif  // USE_TFLM_COMPRESSION
+
   int inputs_array_data[] = {3, 0, 1, 2};
   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
   int outputs_array_data[] = {1, 3};
@@ -50,8 +134,12 @@ TfLiteStatus ValidateDepthwiseConvGoldens(
   const TFLMRegistration registration = Register_DEPTHWISE_CONV_2D();
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
-                             outputs_array,
-                             reinterpret_cast<void*>(conv_params));
+                             outputs_array, reinterpret_cast<void*>(conv_params)
+#ifdef USE_TFLM_COMPRESSION
+                             ,
+                             nullptr, comp_list_p
+#endif  // USE_TFLM_COMPRESSION
+  );

   int input_depth = tensors[0].dims->data[3];
   int output_depth = tensors[1].dims->data[3];
@@ -183,18 +271,93 @@ void TestDepthwiseConvQuantizedPerChannel(
       output_scale, output_zero_point, conv_params, filter_packed_type);
 }

+#ifdef USE_TFLM_COMPRESSION
+
+template <typename TIO, typename TBIAS>
+TfLiteStatus TestDepthwiseConvQuantizedCompressed(
+    int* input_dims_data, const float* input_data, TIO* input_quantized,
+    float input_scale, int input_zero_point, int* output_dims_data,
+    const float* expected_output_data, TIO* expected_output_quantized,
+    TIO* output_quantized, float output_scale, int output_zero_point,
+    TfLiteDepthwiseConvParams* conv_params, const unsigned int tolerance,
+    const TestCompressionQuantizedInfo<int8_t>* filter_comp_info,
+    const TestCompressionQuantizedInfo<TBIAS>* bias_comp_info) {
+  // TODO(b/360169306): account for optional bias tensor
+  // bool null_bias = comp_info->bias_data == nullptr ? true : false;
+
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  TfLiteFloatArray* filter_scales =
+      FloatArrayFromFloats(filter_comp_info->scales);
+  TfLiteIntArray* filter_zero_points =
+      IntArrayFromInts(filter_comp_info->zero_points);
+  TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales);
+  TfLiteIntArray* bias_zero_points =
+      IntArrayFromInts(bias_comp_info->zero_points);
+
+  TfLiteAffineQuantization filter_quant = {};
+  TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor(
+      filter_comp_info->compressed, filter_dims, filter_scales,
+      filter_zero_points, &filter_quant, kDepthwiseConvQuantizedDimension,
+      false /* is_variable */, kTfLiteInt8);
+  // Value tables are always in channel order, therefore do not use the
+  // quantized dimension.
+  SymmetricPerChannelQuantize(
+      filter_comp_info->data, filter_comp_info->value_table,
+      filter_scales->size * filter_comp_info->value_table_stride,
+      filter_scales->size, filter_scales->data, 0 /* see comment above */);
+
+  TfLiteAffineQuantization bias_quant = {};
+  TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor(
+      bias_comp_info->compressed, bias_dims, input_scale, filter_scales,
+      bias_scales, bias_zero_points, &bias_quant,
+      0 /* quantized dimension for bias tensor */, false /* is_variable */,
+      typeToTfLiteType<TBIAS>());
+  SymmetricPerChannelQuantize(
+      bias_comp_info->data, bias_comp_info->value_table,
+      bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size,
+      bias_scales->data);
+
+  constexpr int tensors_size = kDepthwiseConvMaxTensors;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_quantized, input_dims,
+                            input_scale, input_zero_point),
+      filter_tensor,
+      bias_tensor,
+      CreateQuantizedTensor(output_quantized, output_dims, output_scale,
+                            output_zero_point),
+  };
+
+  const int output_dims_count = ElementCount(*output_dims);
+  Quantize(expected_output_data, expected_output_quantized, output_dims_count,
+           output_scale, output_zero_point);
+  return ValidateDepthwiseConvGoldens(
+      expected_output_quantized, output_dims_count, conv_params, tolerance,
+      tensors_size, tensors, filter_comp_info, bias_comp_info);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
+// TODO(ddavis-2015): is this still valid?
 // Xtensa kernels do not support float activations, and the corresponding tests
 // are disabled. As a result, helper functions that are only needed for float
 // kernel tests also need to be ifdef'd out to avoid build errors due to unused
 // functions.
 #if !defined(XTENSA)
-void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data,
-                            int* filter_dims_data, const float* filter_data,
-                            int* bias_dims_data, const float* bias_data,
-                            const float* expected_output_data,
-                            int* output_dims_data,
-                            TfLiteDepthwiseConvParams* conv_params,
-                            float* output_data) {
+void TestDepthwiseConvFloat(
+    int* input_dims_data, const float* input_data, int* filter_dims_data,
+    const float* filter_data, int* bias_dims_data, const float* bias_data,
+    const float* expected_output_data, int* output_dims_data,
+    TfLiteDepthwiseConvParams* conv_params, float* output_data
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo<const float>* filter_comp_info = nullptr,
+    const TestCompressionInfo<const float>* bias_comp_info = nullptr
+#endif  // USE_TFLM_COMPRESSION
+) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
@@ -212,7 +375,12 @@ void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data,
   };

   ValidateDepthwiseConvGoldens(expected_output_data, output_dims_count,
-                               conv_params, 1e-5, tensors_size, tensors);
+                               conv_params, 1e-5, tensors_size, tensors
+#ifdef USE_TFLM_COMPRESSION
+                               ,
+                               filter_comp_info, bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+  );
 }

 #endif  // !defined(XTENSA)
@@ -253,6 +421,60 @@ TF_LITE_MICRO_TEST(SimpleTest) {
       bias_values, golden, output_shape, &conv_params, output_data);
 }

+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestCompressed) {
+  int input_shape[] = {4, 1, 3, 2, 2};
+  const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12};
+  int filter_shape[] = {4, 1, 2, 2, 4};
+  // Filter values:
+  // {1, 2, 3, 4, -9, 10, -11, 12, 5, 6, 7, 8, 13, -14, 15, -16}
+  // Align the tensor data the same as a Buffer in the schema
+  alignas(16) const uint8_t kBinQuantFilterData[] = {0x01, 0x23, 0xF8, 0xE9,
+                                                     0x45, 0x67, 0xAD, 0xBC};
+  const float kBinQuantFilterValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8,
+                                             10, 12, 13, 15, -16, -14, -11, -9};
+  int bias_shape[] = {4, 1, 1, 1, 4};
+  const float bias_values[] = {1, 2, 3, 4};
+  // Align the tensor data the same as a Buffer in the schema
+  alignas(16) const uint8_t kBinQuantBiasData[] = {0x1B};
+  const float golden[] = {
+      71, -34, 99, -20, 91, -26, 127, -4,
+  };
+  int output_shape[] = {4, 1, 2, 1, 4};
+  const int output_dims_count = std::extent<decltype(golden)>::value;
+  float output_data[output_dims_count];
+
+  tflite::testing::TestCompressionInfo<const float> filter_comp_info = {};
+  tflite::testing::TestCompressionInfo<const float> bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = kBinQuantFilterValueTable;
+  filter_comp_info.value_table_stride =
+      std::extent<decltype(kBinQuantFilterValueTable)>::value;
+  filter_comp_info.bit_width = 4;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_values;
+  bias_comp_info.value_table_stride = std::extent<decltype(bias_values)>::value;
+  bias_comp_info.bit_width = 2;
+
+  TfLiteDepthwiseConvParams conv_params;
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  tflite::testing::TestDepthwiseConvFloat(
+      input_shape, input_values, filter_shape,
+      reinterpret_cast<const float*>(kBinQuantFilterData), bias_shape,
+      reinterpret_cast<const float*>(kBinQuantBiasData), golden, output_shape,
+      &conv_params, output_data, &filter_comp_info, &bias_comp_info);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 TF_LITE_MICRO_TEST(SimpleTestRelu) {
   int input_shape[] = {4, 1, 3, 2, 2};
   const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12};
@@ -1068,4 +1290,144 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16InputInt8Filter) {
       bias_quantized, output_shape, golden, golden_quantized, output_data,
       output_scale, output_zero_point, &conv_params);
 }
+
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt8Compressed) {
+  // data from TfLite test:
+  // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+  const float input_scale = 0.5f;
+  const float output_scale = 0.5f;
+  const int input_zero_point = -1;
+  const int output_zero_point = -1;
+  constexpr float filter_scales[] = {
+      tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f,
+  };
+  constexpr int filter_zero_points[] = {
+      tflite::testing::kNumChannelsQ1, 0, 0, 0, 0,
+  };
+  // bias scales and zero points will be computed
+  float bias_scales[std::extent<decltype(filter_scales)>::value] = {};
+  int bias_zero_points[std::extent<decltype(filter_scales)>::value] = {};
+
+  int8_t input_quantized[tflite::testing::kInputElementsQ1];
+  int8_t filter_quantized[tflite::testing::kFilterElementsQ1];
+  int32_t bias_quantized[tflite::testing::kBiasElementsQ1];
+  int8_t golden_quantized[tflite::testing::kOutputElementsQ1];
+  int8_t output_quantized[tflite::testing::kOutputElementsQ1];
+
+  tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo<int32_t> bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = filter_quantized;
+  filter_comp_info.value_table_stride =
+      tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1;
+  filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
+  filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
+  filter_comp_info.data = tflite::testing::kFilterDataQ1;
+  filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
+  filter_comp_info.scales = filter_scales;
+  filter_comp_info.zero_points = filter_zero_points;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride =
+      tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
+  bias_comp_info.data = tflite::testing::kBiasDataQ1;
+  bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
+  bias_comp_info.scales = bias_scales;
+  bias_comp_info.zero_points = bias_zero_points;
+
+  TfLiteDepthwiseConvParams conv_params = {};
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  // tolerance of 3 is approx. 2.0f
+  // TODO(ddavis-2015): why does the tolerance differ from TfLite test???
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      tflite::testing::TestDepthwiseConvQuantizedCompressed(
+          tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1,
+          input_quantized, input_scale, input_zero_point,
+          tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1,
+          golden_quantized, output_quantized, output_scale, output_zero_point,
+          &conv_params, 3, &filter_comp_info, &bias_comp_info));
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16Compressed) {
+  // data from TfLite test:
+  // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+  const float input_scale =
+      tflite::testing::SymmetricScaleFromMinMax<int16_t>(-4.0f, 4.0f);
+  const float output_scale =
+      tflite::testing::SymmetricScaleFromMinMax<int16_t>(-63.5f, 64.0f);
+  const int input_zero_point = 0;
+  const int output_zero_point = 0;
+  constexpr float filter_scales[] = {
+      tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f,
+  };
+  constexpr int filter_zero_points[] = {
+      tflite::testing::kNumChannelsQ1, 0, 0, 0, 0,
+  };
+  // bias scales and zero points will be computed
+  float bias_scales[std::extent<decltype(filter_scales)>::value] = {};
+  int bias_zero_points[std::extent<decltype(filter_scales)>::value] = {};
+
+  int16_t input_quantized[tflite::testing::kInputElementsQ1];
+  int8_t filter_quantized[tflite::testing::kFilterElementsQ1];
+  int64_t bias_quantized[tflite::testing::kBiasElementsQ1];
+  int16_t golden_quantized[tflite::testing::kOutputElementsQ1];
+  int16_t output_quantized[tflite::testing::kOutputElementsQ1];
+
+  tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo<int64_t> bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = filter_quantized;
+  filter_comp_info.value_table_stride =
+      tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1;
+  filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
+  filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
+  filter_comp_info.data = tflite::testing::kFilterDataQ1;
+  filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
+  filter_comp_info.scales = filter_scales;
+  filter_comp_info.zero_points = filter_zero_points;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride =
+      tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
+  bias_comp_info.data = tflite::testing::kBiasDataQ1;
+  bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
+  bias_comp_info.scales = bias_scales;
+  bias_comp_info.zero_points = bias_zero_points;
+
+  TfLiteDepthwiseConvParams conv_params = {};
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  // tolerance of 512 is approx. 1.0f
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      tflite::testing::TestDepthwiseConvQuantizedCompressed(
+          tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1,
+          input_quantized, input_scale, input_zero_point,
+          tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1,
+          golden_quantized, output_quantized, output_scale, output_zero_point,
+          &conv_params, 512, &filter_comp_info, &bias_comp_info));
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
index 8536ff79507..838fdc0944e 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
@@ -1,5 +1,5 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -93,6 +93,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
       context, op_data.reference_op_data.filter_buffer_index, filter);

+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteInt8: {
       switch (filter_int8.type) {
@@ -111,9 +123,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
               tflite::micro::GetTensorShape(input),
               tflite::micro::GetTensorData<int8_t>(input),
               tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int8_t>(
+                  micro_context, &filter_int8, filter_comp_td,
+                  op_data.reference_op_data.weights_scratch_index),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<int32_t>(
+                  micro_context, bias, bias_comp_td,
+                  op_data.reference_op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorData<int8_t>(&filter_int8),
               tflite::micro::GetTensorShape(bias),
               tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorShape(output),
               tflite::micro::GetTensorData<int8_t>(output));
 #endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
@@ -136,9 +158,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
               tflite::micro::GetTensorShape(input),
               tflite::micro::GetTensorData<int8_t>(input),
               tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int8_t>(
+                  micro_context, &filter_int8, filter_comp_td,
+                  op_data.reference_op_data.weights_scratch_index),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<int32_t>(
+                  micro_context, bias, bias_comp_td,
+                  op_data.reference_op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorData<int8_t>(&filter_int8),
               tflite::micro::GetTensorShape(bias),
               tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorShape(output),
               tflite::micro::GetTensorData<int8_t>(output));
           break;
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
index 8c2052b23e7..09e84dee936 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -97,10 +97,22 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
                                    const TfLiteEvalTensor* filter,
                                    const TfLiteEvalTensor* bias,
                                    TfLiteEvalTensor* output) {
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   // If dilation is not required use the optimized NN Library kernel.
   // Otherwise call the reference implementation.
   if ((params.dilation_width_factor == 1) &&
-      (params.dilation_height_factor == 1)) {
+      (params.dilation_height_factor == 1) && bias != nullptr) {
     const int stride_width = params.stride_width;
     const int stride_height = params.stride_height;
     const int pad_width = data.reference_op_data.padding.width;
@@ -133,8 +145,17 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
     TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);

     const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+#ifdef USE_TFLM_COMPRESSION
+    const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
+        micro_context, filter, filter_comp_td,
+        data.reference_op_data.weights_scratch_index);
+    const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(
+        micro_context, bias, bias_comp_td,
+        data.reference_op_data.bias_scratch_index);
+#else  // USE_TFLM_COMPRESSION
     const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
     const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
     int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);

     int32_t input_data_format = 0;
@@ -178,9 +199,19 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
       tflite::micro::GetTensorShape(input),
       tflite::micro::GetTensorData<int8_t>(input),
       tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(
+          micro_context, filter, filter_comp_td,
+          data.reference_op_data.weights_scratch_index),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetOptionalTensorData<int32_t>(
+          micro_context, bias, bias_comp_td,
+          data.reference_op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorData<int8_t>(filter),
       tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorShape(output),
       tflite::micro::GetTensorData<int8_t>(output));
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
index 35fa8cf1c1a..23e18dc8342 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context,
   TF_LITE_ENSURE(context, filter != nullptr);
   TfLiteTensor* bias =
       micro_context->AllocateTempInputTensor(node, kDepthwiseConvBiasTensor);
-  TF_LITE_ENSURE(context, filter != nullptr);
+  TF_LITE_ENSURE(context, bias != nullptr);

   // Dynamically allocate per-channel quantization parameters.
   const int num_channels =
@@ -135,18 +135,81 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context,
     filter_int8 = *filter;
   }

+#ifdef USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = nullptr;
+  int32_t* bias_data = nullptr;
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  if (filter_comp_td != nullptr) {
+    const size_t filter_data_size =
+        NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8);
+    filter_data =
+        micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t));
+    if (filter_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* filter_eval =
+        tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
+    filter_data = static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
+        *filter_eval, *filter_comp_td, filter_data));
+  } else {
+    filter_data = GetTensorData<uint8_t>(&filter_int8);
+  }
+
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+  if (bias_comp_td != nullptr) {
+    const size_t bias_data_size =
+        NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
+    bias_data = reinterpret_cast<int32_t*>(
+        micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t)));
+    if (bias_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* bias_eval =
+        tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor);
+    bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
+        *bias_eval, *bias_comp_td, bias_data));
+  } else {
+    bias_data = GetTensorData<int32_t>(bias);
+  }
+
+  if (filter_data == nullptr || bias_data == nullptr) {
+    return kTfLiteError;
+  }
+
+#else  // USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = GetTensorData<uint8_t>(&filter_int8);
+  int32_t* bias_data = GetTensorData<int32_t>(bias);
+
+#endif  // USE_TFLM_COMPRESSION
+
   status = xiDepthwiseConvDoCoeffReorder(
       data->p_context, data->context_size,
       reinterpret_cast<uint8_t*>(data->reorder_coefficient_bias),
-      data->reorder_coefficient_bias_size,
-      const_cast<uint8_t*>(GetTensorData<uint8_t>(&filter_int8)),
-      const_cast<int32_t*>(GetTensorData<int32_t>(bias)));
+      data->reorder_coefficient_bias_size, filter_data, bias_data);
   if (status) {
     return kTfLiteError;
   }

   if (filter->type == kTfLiteInt4) {
     micro_context->DeallocateTempBuffer(GetTensorData<uint8_t>(&filter_int8));
   }
+
+#ifdef USE_TFLM_COMPRESSION
+
+  if (filter_comp_td) {
+    micro_context->DeallocateTempBuffer(filter_data);
+  }
+  if (bias_comp_td) {
+    micro_context->DeallocateTempBuffer(reinterpret_cast<uint8_t*>(bias_data));
+  }
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
diff --git a/tensorflow/lite/micro/micro_utils.h b/tensorflow/lite/micro/micro_utils.h
index 98ef81dc8ed..b362d3402bb 100644
--- a/tensorflow/lite/micro/micro_utils.h
+++ b/tensorflow/lite/micro/micro_utils.h
@@ -90,12 +90,19 @@ void SymmetricQuantize(const float* input, T* output, int num_elements,
 template <typename T>
 void SymmetricPerChannelQuantize(const float* input, T* output,
                                  int num_elements, int num_channels,
-                                 float* scales) {
+                                 float* scales,
+                                 size_t quantized_dimension = 0) {
   int elements_per_channel = num_elements / num_channels;
   for (int i = 0; i < num_channels; i++) {
     for (int j = 0; j < elements_per_channel; j++) {
-      output[i * elements_per_channel + j] = FloatToSymmetricQuantizedType<T>(
-          input[i * elements_per_channel + j], scales[i]);
+      size_t offset;
+      if (quantized_dimension == 0) {
+        offset = i * elements_per_channel + j;
+      } else {
+        offset = i + elements_per_channel * j;
+      }
+      output[offset] =
+          FloatToSymmetricQuantizedType<T>(input[offset], scales[i]);
     }
   }
 }
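Note on the micro_utils.h change above: the tests in this patch always leave the new quantized_dimension argument at its default of 0, because compression value tables are stored in channel order. The following is a minimal usage sketch, not part of the patch; the wrapper function name, the example values, and the per-channel scales are illustrative assumptions only.

#include <cstdint>

#include "tensorflow/lite/micro/micro_utils.h"

namespace {

// Sketch: quantize a two-channel, channel-major value table with one scale
// per channel, using the default quantized_dimension of 0.
void SymmetricPerChannelQuantizeSketch() {
  // Channel 0 holds {1, 2, 3}; channel 1 holds {10, 20, 30}.
  const float channel_major[6] = {1.0f, 2.0f, 3.0f, 10.0f, 20.0f, 30.0f};
  // One symmetric scale per channel (largest magnitude maps to 127).
  float scales[2] = {3.0f / 127.0f, 30.0f / 127.0f};
  int8_t quantized[6] = {};

  // With quantized_dimension == 0, element (i, j) is read from and written to
  // offset i * elements_per_channel + j, so the layout is left unchanged.
  tflite::SymmetricPerChannelQuantize<int8_t>(channel_major, quantized,
                                              /*num_elements=*/6,
                                              /*num_channels=*/2, scales);
  // quantized now holds roughly {42, 85, 127, 42, 85, 127}.
}

}  // namespace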