Commit 3a13794

Merge branch 'main' into test-helper-include

suleshahid authored Dec 13, 2024
2 parents 16160f9 + a9c6e6a commit 3a13794
Showing 11 changed files with 1,252 additions and 38 deletions.
70 changes: 68 additions & 2 deletions tensorflow/lite/micro/kernels/fully_connected.cc
@@ -60,7 +60,7 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
(input->type == kTfLiteInt8 &&
(filter->type != kTfLiteInt8 && filter->type != kTfLiteInt4)) ||
(input->type == kTfLiteInt16 && filter->type != kTfLiteInt8)) {
MicroPrintf("Input type: %s with filter type : %s not supported.",
MicroPrintf("Input type: %s with filter type: %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(filter->type));
return kTfLiteError;
@@ -79,6 +79,23 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
context, params->activation, input->type,
input, filter, bias, output, data));

#ifdef USE_TFLM_COMPRESSION

// Compression scratch buffers.
// These will only be allocated if the tensor is compressed.
if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) &&
filter->type == kTfLiteInt4) {
MicroPrintf("Compression not supported with INT4 tensors");
return kTfLiteError;
}
data->weights_scratch_index =
micro_context->AllocateDecompressionScratchBuffer(
node, kFullyConnectedWeightsTensor);
data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
node, kFullyConnectedBiasTensor);

#endif // USE_TFLM_COMPRESSION

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(filter);
if (bias != nullptr) {
@@ -102,8 +119,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);

- TFLITE_DCHECK(node->user_data != nullptr);
#ifdef USE_TFLM_COMPRESSION

MicroContext* micro_context = GetMicroContext(context);

const CompressionTensorData* weights_comp_td =
micro_context->GetTensorCompressionData(node,
kFullyConnectedWeightsTensor);
const CompressionTensorData* bias_comp_td =
micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);

#endif // USE_TFLM_COMPRESSION

+ TFLITE_DCHECK(node->user_data != nullptr);
const auto& data =
*(static_cast<const OpDataFullyConnected*>(node->user_data));

@@ -115,9 +143,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(micro_context, filter,
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
@@ -152,19 +189,39 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output))
: tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
@@ -186,9 +243,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
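Note on the accessors used above: each compressed path passes three extra
arguments (the MicroContext, the tensor's CompressionTensorData, and the
scratch-buffer index) to GetTensorData / GetOptionalTensorData. A rough
sketch of what that overload pair plausibly looks like, inferred from these
call sites only and not from the actual TFLM headers:

    // Existing accessor: returns the tensor's raw data pointer.
    template <typename T>
    const T* GetTensorData(const TfLiteEvalTensor* tensor);

    // Compression-aware overload (signature inferred from the call sites
    // above): when compression_data is non-null, decompress the tensor into
    // the scratch buffer identified by scratch_buffer_index and return a
    // pointer to the decompressed data.
    template <typename T>
    const T* GetTensorData(MicroContext* micro_context,
                           const TfLiteEvalTensor* tensor,
                           const CompressionTensorData* compression_data,
                           int scratch_buffer_index);
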
8 changes: 8 additions & 0 deletions tensorflow/lite/micro/kernels/fully_connected.h
@@ -50,6 +50,14 @@ struct OpDataFullyConnected {
int32_t* per_channel_output_shift;
bool is_per_channel;
#endif

#ifdef USE_TFLM_COMPRESSION

// scratch buffers for compressed tensors
int weights_scratch_index;
int bias_scratch_index;

#endif // USE_TFLM_COMPRESSION
};

extern const int kFullyConnectedInputTensor;
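
Taken together, the two diffs follow a single pattern: Prepare() reserves a
decompression scratch buffer per compressible tensor and stores the returned
indices in OpDataFullyConnected, and Eval() looks up the compression metadata
and hands both back to the data accessor. A minimal sketch of that flow for
the weights tensor alone, using only the calls shown in the diff above (error
handling and the bias tensor elided; the fallback behavior for uncompressed
tensors is an assumption, not something this commit confirms):

    #ifdef USE_TFLM_COMPRESSION
      // In Prepare(): reserve scratch memory once, at init time, and record
      // where it lives via the returned index.
      data->weights_scratch_index =
          micro_context->AllocateDecompressionScratchBuffer(
              node, kFullyConnectedWeightsTensor);

      // In Eval(): fetch the tensor's compression metadata (presumably null
      // when the tensor is stored uncompressed)...
      const CompressionTensorData* weights_comp_td =
          micro_context->GetTensorCompressionData(node,
                                                  kFullyConnectedWeightsTensor);
      // ...then the accessor decompresses into the scratch buffer and returns
      // a pointer to the usable weights.
      const int8_t* weights = tflite::micro::GetTensorData<int8_t>(
          micro_context, filter, weights_comp_td, data.weights_scratch_index);
    #endif  // USE_TFLM_COMPRESSION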