Commit 81e548b

Add GetOptionalTensorData (four-parameter version).
Change GetTensorData (four-parameter version) to add a DCHECK.
Update FULLY_CONNECTED kernels to handle optional bias tensors when conditional compression compilation is enabled.
ddavis-2015 committed Dec 13, 2024
1 parent 4d07fef commit 81e548b
Showing 5 changed files with 29 additions and 15 deletions.
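
The change splits tensor access by contract: GetTensorData now DCHECKs that the tensor exists, while GetOptionalTensorData returns nullptr for an absent tensor. A minimal caller-side sketch of the float case (hedged: names such as bias, bias_comp_td, and data.bias_scratch_index follow the kernel diffs below, and the surrounding FULLY_CONNECTED scaffolding is omitted):

    #ifdef USE_TFLM_COMPRESSION
      // The FULLY_CONNECTED bias may legitimately be absent: the optional
      // accessor returns nullptr instead of tripping the new DCHECK, and
      // otherwise decompresses the bias into its scratch buffer.
      const float* bias_data = tflite::micro::GetOptionalTensorData<float>(
          micro_context, bias, bias_comp_td, data.bias_scratch_index);
    #else   // USE_TFLM_COMPRESSION
      const float* bias_data = tflite::micro::GetOptionalTensorData<float>(bias);
    #endif  // USE_TFLM_COMPRESSION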
10 changes: 5 additions & 5 deletions tensorflow/lite/micro/kernels/fully_connected.cc
@@ -148,8 +148,8 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
              weights_comp_td,
              data.weights_scratch_index),
          tflite::micro::GetTensorShape(bias),
-         tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
-                                             data.bias_scratch_index),
+         tflite::micro::GetOptionalTensorData<float>(
+             micro_context, bias, bias_comp_td, data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
@@ -194,7 +194,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
              micro_context, filter, weights_comp_td,
              data.weights_scratch_index),
          tflite::micro::GetTensorShape(bias),
-         tflite::micro::GetTensorData<int32_t>(
+         tflite::micro::GetOptionalTensorData<int32_t>(
              micro_context, bias, bias_comp_td,
              data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
@@ -214,7 +214,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
              micro_context, filter, weights_comp_td,
              data.weights_scratch_index),
          tflite::micro::GetTensorShape(bias),
-         tflite::micro::GetTensorData<int32_t>(
+         tflite::micro::GetOptionalTensorData<int32_t>(
              micro_context, bias, bias_comp_td,
              data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
@@ -248,7 +248,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
              weights_comp_td,
              data.weights_scratch_index),
          tflite::micro::GetTensorShape(bias),
-         tflite::micro::GetTensorData<int64_t>(
+         tflite::micro::GetOptionalTensorData<int64_t>(
              micro_context, bias, bias_comp_td, data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
22 changes: 17 additions & 5 deletions tensorflow/lite/micro/kernels/kernel_util.h
@@ -100,13 +100,13 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
 
 #ifdef USE_TFLM_COMPRESSION
 
-// Overloads existing GetTensorData. If not compressed, this will return
+// Overloads existing GetOptionalTensorData. If not compressed, this will return
 // tensor->data.
 template <typename T>
-const T* GetTensorData(MicroContext* micro_context,
-                       const TfLiteEvalTensor* tensor,
-                       const CompressionTensorData* compression_data,
-                       int scratch_buffer_handle) {
+const T* GetOptionalTensorData(MicroContext* micro_context,
+                               const TfLiteEvalTensor* tensor,
+                               const CompressionTensorData* compression_data,
+                               int scratch_buffer_handle) {
   if (tensor == nullptr) {
     return nullptr;
   }
@@ -128,6 +128,18 @@ const T* GetTensorData(MicroContext* micro_context,
   return reinterpret_cast<const T*>(uncompressed_data);
 }
 
+// Overloads existing GetTensorData. If not compressed, this will return
+// tensor->data.
+template <typename T>
+const T* GetTensorData(MicroContext* micro_context,
+                       const TfLiteEvalTensor* tensor,
+                       const CompressionTensorData* compression_data,
+                       int scratch_buffer_handle) {
+  TFLITE_DCHECK(tensor != nullptr);
+  return GetOptionalTensorData<T>(micro_context, tensor, compression_data,
+                                  scratch_buffer_handle);
+}
+
 #endif  // USE_TFLM_COMPRESSION
 
 // Returns the shape of a TfLiteEvalTensor struct.
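
The forwarding relationship above, reduced to a self-contained sketch (illustrative only, not the TFLM declarations: a stand-in EvalTensor type and assert replace TfLiteEvalTensor and TFLITE_DCHECK):

    #include <cassert>
    #include <cstdint>

    struct EvalTensor {  // stand-in for TfLiteEvalTensor
      const void* data;
    };

    // Optional accessor: a null tensor is a valid state and yields nullptr.
    template <typename T>
    const T* GetOptionalTensorData(const EvalTensor* tensor) {
      return tensor == nullptr ? nullptr : static_cast<const T*>(tensor->data);
    }

    // Checked accessor: a null tensor is a programming error, caught in
    // debug builds, then forwarded to the optional accessor.
    template <typename T>
    const T* GetTensorData(const EvalTensor* tensor) {
      assert(tensor != nullptr);  // TFLITE_DCHECK in the real header
      return GetOptionalTensorData<T>(tensor);
    }

    int main() {
      const int32_t bias_values[2] = {7, 9};
      const EvalTensor bias{bias_values};
      const int32_t* present = GetOptionalTensorData<int32_t>(&bias);
      const int32_t* absent = GetOptionalTensorData<int32_t>(nullptr);
      return (present[0] == 7 && absent == nullptr) ? 0 : 1;
    }

The same split now exists for the four-parameter compressed path: required tensors (the weights) go through the checked overload, optional tensors (the FULLY_CONNECTED bias) through the optional one.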
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -75,8 +75,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
             weights_comp_td,
             data.weights_scratch_index),
         tflite::micro::GetTensorShape(bias),
-        tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
-                                            data.bias_scratch_index),
+        tflite::micro::GetOptionalTensorData<float>(
+            micro_context, bias, bias_comp_td, data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
         tflite::micro::GetTensorData<float>(filter),
         tflite::micro::GetTensorShape(bias),
@@ -119,7 +119,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
             weights_comp_td,
             data.weights_scratch_index),
         tflite::micro::GetTensorShape(bias),
-        tflite::micro::GetTensorData<int64_t>(
+        tflite::micro::GetOptionalTensorData<int64_t>(
            micro_context, bias, bias_comp_td, data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
         tflite::micro::GetTensorData<int8_t>(filter),
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/xtensa/fully_connected_int8.cc
@@ -47,8 +47,8 @@ TfLiteStatus XtensaEvalFullyConnectedQuantizedInt8(
 
   const int32_t* bias_data =
 #ifdef USE_TFLM_COMPRESSION
-      tflite::micro::GetTensorData<int32_t>(micro_context, bias, bias_comp_td,
-                                            data.bias_scratch_index);
+      tflite::micro::GetOptionalTensorData<int32_t>(
+          micro_context, bias, bias_comp_td, data.bias_scratch_index);
 #else   // USE_TFLM_COMPRESSION
       tflite::micro::GetOptionalTensorData<int32_t>(bias);
 #endif  // USE_TFLM_COMPRESSION
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/kernels/xtensa/fully_connected_vision.cc
@@ -135,6 +135,7 @@ TfLiteStatus FullyConnectedPrepareVision(TfLiteContext* context,
   const CompressionTensorData* bias_comp_td =
       micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
   if (bias_comp_td != nullptr) {
+    TFLITE_DCHECK(bias != nullptr);
     const size_t bias_data_size =
         NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
     bias_data = reinterpret_cast<int32_t*>(
@@ -144,6 +145,7 @@
     }
     const TfLiteEvalTensor* bias_eval =
         tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
+    TFLITE_DCHECK(bias_eval != nullptr);
     bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
         *bias_eval, *bias_comp_td, bias_data));
   } else {
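
Both DCHECKs added in this file guard the dereferences that immediately follow (NumElements(bias) and *bias_eval): once a compression descriptor exists for the bias slot, the tensor itself must be present before its element count can size the decompression scratch buffer. A condensed sketch of that invariant (assert standing in for TFLITE_DCHECK; helper names as in the diff above):

    if (bias_comp_td != nullptr) {
      // A compression descriptor implies the tensor is present.
      assert(bias != nullptr);
      // ...size the scratch buffer from NumElements(bias), then decompress
      // the bias into it via DecompressTensorToBuffer...
    }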
