From 81e548b95ae3489d3fec162c4a278cfab0f9b414 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Thu, 12 Dec 2024 16:23:15 -0800 Subject: [PATCH] Add GetOptionalTensorData (four parameter version). Change GetTensorData (four parameter version) to add DCHECK. Update FULLY_CONNECTED kernels for optional bias tensors when conditional compression compilation enabled. --- .../lite/micro/kernels/fully_connected.cc | 10 ++++----- tensorflow/lite/micro/kernels/kernel_util.h | 22 ++++++++++++++----- .../micro/kernels/xtensa/fully_connected.cc | 6 ++--- .../kernels/xtensa/fully_connected_int8.cc | 4 ++-- .../kernels/xtensa/fully_connected_vision.cc | 2 ++ 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index b044a4bbab2..6902728043f 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -148,8 +148,8 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { weights_comp_td, data.weights_scratch_index), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, - data.bias_scratch_index), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), #else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), @@ -194,7 +194,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { micro_context, filter, weights_comp_td, data.weights_scratch_index), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData( + tflite::micro::GetOptionalTensorData( micro_context, bias, bias_comp_td, data.bias_scratch_index), #else // USE_TFLM_COMPRESSION @@ -214,7 +214,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { micro_context, filter, weights_comp_td, data.weights_scratch_index), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData( + tflite::micro::GetOptionalTensorData( micro_context, bias, bias_comp_td, data.bias_scratch_index), #else // USE_TFLM_COMPRESSION @@ -248,7 +248,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { weights_comp_td, data.weights_scratch_index), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData( + tflite::micro::GetOptionalTensorData( micro_context, bias, bias_comp_td, data.bias_scratch_index), #else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h index 5f8eb2e4ec3..5cb71af7953 100644 --- a/tensorflow/lite/micro/kernels/kernel_util.h +++ b/tensorflow/lite/micro/kernels/kernel_util.h @@ -100,13 +100,13 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) { #ifdef USE_TFLM_COMPRESSION -// Overloads existing GetTensorData. If not compressed, this will return +// Overloads existing GetOptionalTensorData. If not compressed, this will return // tensor->data. template -const T* GetTensorData(MicroContext* micro_context, - const TfLiteEvalTensor* tensor, - const CompressionTensorData* compression_data, - int scratch_buffer_handle) { +const T* GetOptionalTensorData(MicroContext* micro_context, + const TfLiteEvalTensor* tensor, + const CompressionTensorData* compression_data, + int scratch_buffer_handle) { if (tensor == nullptr) { return nullptr; } @@ -128,6 +128,18 @@ const T* GetTensorData(MicroContext* micro_context, return reinterpret_cast(uncompressed_data); } +// Overloads existing GetTensorData. If not compressed, this will return +// tensor->data. +template +const T* GetTensorData(MicroContext* micro_context, + const TfLiteEvalTensor* tensor, + const CompressionTensorData* compression_data, + int scratch_buffer_handle) { + TFLITE_DCHECK(tensor != nullptr); + return GetOptionalTensorData(micro_context, tensor, compression_data, + scratch_buffer_handle); +} + #endif // USE_TFLM_COMPRESSION // Returns the shape of a TfLiteEvalTensor struct. diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc index 4a141784d8f..511335a550f 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc @@ -75,8 +75,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { weights_comp_td, data.weights_scratch_index), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, - data.bias_scratch_index), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), #else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), tflite::micro::GetTensorShape(bias), @@ -119,7 +119,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { weights_comp_td, data.weights_scratch_index), tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData( + tflite::micro::GetOptionalTensorData( micro_context, bias, bias_comp_td, data.bias_scratch_index), #else // USE_TFLM_COMPRESSION tflite::micro::GetTensorData(filter), diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc index 32dfba2e5a8..1901ea99df6 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc @@ -47,8 +47,8 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8( const int32_t* bias_data = #ifdef USE_TFLM_COMPRESSION - tflite::micro::GetTensorData(micro_context, bias, bias_comp_td, - data.bias_scratch_index); + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index); #else // USE_TFLM_COMPRESSION tflite::micro::GetOptionalTensorData(bias); #endif // USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc index 24fd1258277..e81855f3e8c 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc @@ -135,6 +135,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, const CompressionTensorData* bias_comp_td = micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); if (bias_comp_td != nullptr) { + TFLITE_DCHECK(bias != nullptr); const size_t bias_data_size = NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32); bias_data = reinterpret_cast( @@ -144,6 +145,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context, } const TfLiteEvalTensor* bias_eval = tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); + TFLITE_DCHECK(bias_eval != nullptr); bias_data = static_cast(micro_context->DecompressTensorToBuffer( *bias_eval, *bias_comp_td, bias_data)); } else {