Skip to content

Commit

Permalink
TFLM compression changes (3rd) (tensorflow#2658)
Browse files Browse the repository at this point in the history
@tensorflow/micro

Updates to support TFLM compression:

MicroContext
MicroInterpreterContext
FakeMicroContext
KernelRunner

bug=tensorflow#2657
  • Loading branch information
ddavis-2015 authored Aug 14, 2024
1 parent 95da7a8 commit 2b127fd
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 11 deletions.
79 changes: 74 additions & 5 deletions tensorflow/lite/micro/fake_micro_context.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -23,10 +23,23 @@ limitations under the License.

namespace tflite {

FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph)
: graph_(*micro_graph), tensors_(tensors), allocator_(allocator) {}
FakeMicroContext::FakeMicroContext(
TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph
#ifdef USE_TFLM_COMPRESSION
,
const CompressedTensorList* compressed_tensors
#endif // USE_TFLM_COMPRESSION
)
: graph_(*micro_graph),
tensors_(tensors),
allocator_(allocator)
#ifdef USE_TFLM_COMPRESSION
,
compressed_tensors_(compressed_tensors)
#endif // USE_TFLM_COMPRESSION
{
}

TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
allocated_temp_count_++;
Expand Down Expand Up @@ -112,4 +125,60 @@ void* FakeMicroContext::external_context() { return nullptr; }

MicroGraph& FakeMicroContext::graph() { return graph_; }

#ifdef USE_TFLM_COMPRESSION

// Available during Prepare & Eval. Returns false if tensor is not
// compressed.
bool FakeMicroContext::IsTensorCompressed(const TfLiteNode* node,
int tensor_idx) {
if (compressed_tensors_ != nullptr && tensor_idx < node->inputs->size) {
int index = node->inputs->data[tensor_idx];
if (index >= 0 && compressed_tensors_->tensors[index] != nullptr) {
return true;
}
}

return false;
}

// Only available during Prepare. The kernel is responsible for storing the
// scratch buffer handle.
int FakeMicroContext::AllocateDecompressionScratchBuffer(const TfLiteNode* node,
int tensor_idx) {
if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
return -1;
}
int index = node->inputs->data[tensor_idx];
if (index < 0 || compressed_tensors_->tensors[index] == nullptr) {
return -1;
}
TfLiteTensor* tensor = &tensors_[index];
int scratch_index = -1;
TfLiteStatus result =
RequestScratchBufferInArena(tensor->bytes, &scratch_index);
if (result != kTfLiteOk) {
return -1;
}

return scratch_index;
}

// Available during Prepare & Eval. Returns nullptr if tensor is not
// compressed.
const CompressionTensorData* FakeMicroContext::GetTensorCompressionData(
const TfLiteNode* node, int tensor_idx) {
if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
return nullptr;
}

int index = node->inputs->data[tensor_idx];
if (index < 0) {
return nullptr;
}

return compressed_tensors_->tensors[index];
}

#endif // USE_TFLM_COMPRESSION

} // namespace tflite
17 changes: 13 additions & 4 deletions tensorflow/lite/micro/kernels/kernel_runner.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/test_helpers.h"

namespace tflite {
namespace micro {
Expand All @@ -38,12 +37,22 @@ KernelRunner::KernelRunner(const TFLMRegistration& registration,
TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
const void* builtin_data,
TfLiteIntArray* intermediates)
TfLiteIntArray* intermediates
#ifdef USE_TFLM_COMPRESSION
,
const CompressedTensorList* compressed_tensors
#endif // USE_TFLM_COMPRESSION
)
: registration_(registration),
allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_,
kKernelRunnerBufferSize_)),
mock_micro_graph_(allocator_),
fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
fake_micro_context_(tensors, allocator_, &mock_micro_graph_
#ifdef USE_TFLM_COMPRESSION
,
compressed_tensors
#endif // USE_TFLM_COMPRESSION
) {
// Prepare TfLiteContext:
context_.impl_ = static_cast<void*>(&fake_micro_context_);
context_.ReportError = MicroContextReportOpError;
Expand Down
153 changes: 152 additions & 1 deletion tensorflow/lite/micro/micro_context.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -18,8 +18,10 @@ limitations under the License.
#include <cstdarg>
#include <cstddef>

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace {
Expand All @@ -34,6 +36,103 @@ int GetTensorIndex(int index, int max_size, const int* tensor_indices) {
return -1;
}

#ifdef USE_TFLM_COMPRESSION

struct DecompressionState {
DecompressionState() = delete;

DecompressionState(const uint8_t* compressed_indices,
const size_t count_indices,
const CompressionTensorData& comp_data,
const size_t num_channels)
: compressed_indices_(compressed_indices),
count_indices_(count_indices),
comp_data_(comp_data),
num_channels_(num_channels) {}

template <typename T>
T* DecompressToBuffer(void* buffer);

size_t GetNextTableIndex();
void UpdateBufferAndChannelIndex();

private:
const uint8_t* compressed_indices_;
const size_t count_indices_;
const CompressionTensorData& comp_data_;
const size_t num_channels_;
const size_t compressed_bit_width_ =
comp_data_.data.lut_data->compressed_bit_width;
size_t channel_ = 0;
size_t index_in_channel_ = 0;
const size_t elements_per_channel_ =
comp_data_.data.lut_data->use_alternate_axis
? 1
: count_indices_ / num_channels_;
size_t buffer_index_ = 0;
size_t current_offset_ = 0;
size_t current_bits_remaining_ = 8;
uint8_t current_byte_ = compressed_indices_[0];
};

template <typename T>
T* DecompressionState::DecompressToBuffer(void* buffer) {
while (buffer_index_ < count_indices_) {
const size_t table_index = GetNextTableIndex();
static_cast<T*>(buffer)[buffer_index_] =
static_cast<const T*>(comp_data_.data.lut_data->value_table)
[table_index +
(channel_ * comp_data_.data.lut_data->value_table_channel_stride)];
UpdateBufferAndChannelIndex();
}

return static_cast<T*>(buffer);
}

size_t DecompressionState::GetNextTableIndex() {
TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
TFLITE_DCHECK(compressed_bit_width_ > 0);

size_t table_index_bits_to_fill = compressed_bit_width_;
size_t table_index = 0;

while (table_index_bits_to_fill > 0) {
if (current_bits_remaining_ == 0) {
current_offset_++;
current_byte_ = compressed_indices_[current_offset_];
current_bits_remaining_ = 8;
}

const uint8_t mask_bit_count =
std::min(table_index_bits_to_fill,
std::min(compressed_bit_width_, current_bits_remaining_));
const uint8_t current_byte_mask = (1 << mask_bit_count) - 1;
table_index <<= mask_bit_count;
table_index |=
(current_byte_ >> (current_bits_remaining_ - mask_bit_count)) &
current_byte_mask;

table_index_bits_to_fill -= mask_bit_count;
current_bits_remaining_ -= mask_bit_count;
}

return table_index;
}

void DecompressionState::UpdateBufferAndChannelIndex() {
buffer_index_++;
index_in_channel_++;
if (index_in_channel_ == elements_per_channel_) {
index_in_channel_ = 0;
channel_++;
if (channel_ == num_channels_) {
channel_ = 0;
}
}
}

#endif // USE_TFLM_COMPRESSION

} // namespace

TfLiteTensor* MicroContext::AllocateTempInputTensor(const TfLiteNode* node,
Expand Down Expand Up @@ -74,4 +173,56 @@ void MicroContextReportOpError(struct TfLiteContext* context,
va_end(args);
}

#ifdef USE_TFLM_COMPRESSION

void* MicroContext::DecompressTensorToScratchBuffer(
const TfLiteEvalTensor& tensor,
const CompressionTensorData& compression_data, int scratch_buffer_handle) {
TFLITE_DCHECK(compression_data.scheme == CompressionScheme::kBinQuant);
TFLITE_DCHECK(scratch_buffer_handle != -1);
void* scratch_buffer = GetScratchBuffer(scratch_buffer_handle);
TFLITE_DCHECK(scratch_buffer != nullptr);
size_t count = ElementCount(*tensor.dims);
size_t num_channels = 1;

if (compression_data.data.lut_data->is_per_channel_quantized) {
const size_t channel_axis =
compression_data.data.lut_data->use_alternate_axis
? tensor.dims->size - 1
: 0;
num_channels = tensor.dims->data[channel_axis];
}

DecompressionState ds(static_cast<uint8_t*>(tensor.data.data), count,
compression_data, num_channels);

switch (tensor.type) {
case kTfLiteBool: {
return ds.DecompressToBuffer<bool>(scratch_buffer);
} break;
case kTfLiteInt8: {
return ds.DecompressToBuffer<int8_t>(scratch_buffer);
} break;
case kTfLiteInt16: {
return ds.DecompressToBuffer<int16_t>(scratch_buffer);
} break;
case kTfLiteInt32: {
return ds.DecompressToBuffer<int32_t>(scratch_buffer);
} break;
case kTfLiteInt64: {
return ds.DecompressToBuffer<int64_t>(scratch_buffer);
} break;
case kTfLiteFloat32: {
return ds.DecompressToBuffer<float>(scratch_buffer);
} break;
default: {
MicroPrintf("Unsupported decompression tensor type %d", tensor.type);
} break;
}

return nullptr;
}

#endif // USE_TFLM_COMPRESSION

} // namespace tflite
Loading

0 comments on commit 2b127fd

Please sign in to comment.