Skip to content

Commit

Permalink
refactor decompression code into reference and platform specific
Browse files Browse the repository at this point in the history
Apply some Xtensa acceleration code changes
  • Loading branch information
ddavis-2015 committed Oct 11, 2024
1 parent b84853c commit 2388549
Show file tree
Hide file tree
Showing 6 changed files with 956 additions and 738 deletions.
65 changes: 65 additions & 0 deletions tensorflow/lite/micro/kernels/decompress.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef USE_TFLM_COMPRESSION

#include "tensorflow/lite/micro/kernels/decompress.h"

#include <cstddef>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"

namespace tflite {

template <typename T>
T* DecompressionState::DecompressToBuffer(void* buffer) {
TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
TFLITE_DCHECK(compressed_bit_width_ > 0);

if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 4 &&
!comp_data_.data.lut_data->use_alternate_axis) {
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 3 &&
!comp_data_.data.lut_data->use_alternate_axis) {
DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 2 &&
!comp_data_.data.lut_data->use_alternate_axis) {
DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
} else {
DecompressToBufferWidthAny<T>(static_cast<T*>(buffer));
}

return static_cast<T*>(buffer);
}

template bool* DecompressionState::DecompressToBuffer<bool>(void*);
template float* DecompressionState::DecompressToBuffer<float>(void*);
template int8_t* DecompressionState::DecompressToBuffer<int8_t>(void*);
template int16_t* DecompressionState::DecompressToBuffer<int16_t>(void*);
template int32_t* DecompressionState::DecompressToBuffer<int32_t>(void*);
template int64_t* DecompressionState::DecompressToBuffer<int64_t>(void*);

} // namespace tflite

#endif // USE_TFLM_COMPRESSION
82 changes: 82 additions & 0 deletions tensorflow/lite/micro/kernels/decompress.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {

#ifdef USE_TFLM_COMPRESSION

struct DecompressionState {
DecompressionState() = delete;

DecompressionState(const uint8_t* compressed_indices,
const size_t count_indices,
const CompressionTensorData& comp_data,
const size_t num_channels, MicroContext* micro_context)
: compressed_indices_(compressed_indices),
count_indices_(count_indices),
comp_data_(comp_data),
num_channels_(num_channels),
micro_context_(micro_context) {}

DecompressionState(const DecompressionState& other)
: compressed_indices_(other.compressed_indices_),
count_indices_(other.count_indices_),
comp_data_(other.comp_data_),
num_channels_(other.num_channels_),
micro_context_(other.micro_context_) {}

template <typename T>
T* DecompressToBuffer(void* buffer);

protected:
// optimized C++ for INT8, use_alt_axis == false
void DecompressToBufferWidth4_16(int8_t* buffer);
void DecompressToBufferWidth3_32(int8_t* buffer);
void DecompressToBufferWidth2_16(int8_t* buffer);

// generic C++ for any bit width and value table type
template <typename T>
void DecompressToBufferWidthAny(T* buffer);

// Optimized C++ table index fetch
inline size_t GetNextTableIndexWidth7(const size_t current_offset);
inline size_t GetNextTableIndexWidth6(const size_t current_offset);
inline size_t GetNextTableIndexWidth5(const size_t current_offset);
inline size_t GetNextTableIndexWidth4(const size_t current_offset);
inline size_t GetNextTableIndexWidth3(const size_t current_offset);
inline size_t GetNextTableIndexWidth2(const size_t current_offset);
inline size_t GetNextTableIndexWidth1(const size_t current_offset);

protected:
const uint8_t* compressed_indices_;
const size_t count_indices_;
const CompressionTensorData& comp_data_;
const size_t num_channels_;
const size_t compressed_bit_width_ =
comp_data_.data.lut_data->compressed_bit_width;
const size_t elements_per_channel_ =
comp_data_.data.lut_data->use_alternate_axis
? 1
: count_indices_ / num_channels_;
MicroContext* micro_context_;
};

#endif // USE_TFLM_COMPRESSION

} // namespace tflite
Loading

0 comments on commit 2388549

Please sign in to comment.