fix(tflite): Fix memory leaks in tflite integration #2842

Merged: 1 commit, Feb 2, 2024
39 changes: 2 additions & 37 deletions code/components/jomjol_tfliteclass/CTfLiteClass.cpp
@@ -12,17 +12,9 @@

static const char *TAG = "TFLITE";

/// The static resolver must be loaded with all operators that are required - BUT only once --> separate function /////////////////////////////
static bool MakeStaticResolverDone = false;
static tflite::MicroMutableOpResolver<15> resolver;

void MakeStaticResolver()
void CTfLiteClass::MakeStaticResolver()
{
if (MakeStaticResolverDone)
return;

MakeStaticResolverDone = true;

resolver.AddFullyConnected();
resolver.AddReshape();
resolver.AddSoftmax();
@@ -34,7 +26,6 @@ void MakeStaticResolver()
resolver.AddLeakyRelu();
resolver.AddDequantize();
}
////////////////////////////////////////////////////////////////////////////////////////


float CTfLiteClass::GetOutputValue(int nr)
@@ -207,23 +198,19 @@ bool CTfLiteClass::LoadInputImageBasis(CImageBasis *rs)

bool CTfLiteClass::MakeAllocate()
{

MakeStaticResolver();

MakeStaticResolver();

#ifdef DEBUG_DETAIL_ON
LogFile.WriteHeapInfo("CTLiteClass::Alloc start");
#endif

LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::MakeAllocate");
this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize);
// this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize, this->error_reporter);

if (this->interpreter)
{
TfLiteStatus allocate_status = this->interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "AllocateTensors() failed");

this->GetInputDimension();
@@ -313,13 +300,6 @@ bool CTfLiteClass::ReadFileToModel(std::string _fn)

bool CTfLiteClass::LoadModel(std::string _fn)
{
#ifdef SUPRESS_TFLITE_ERRORS
// this->error_reporter = new tflite::ErrorReporter;
this->error_reporter = new tflite::OwnMicroErrorReporter;
#else
this->error_reporter = new tflite::MicroErrorReporter;
#endif

LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::LoadModel");

if (!ReadFileToModel(_fn.c_str())) {
@@ -350,21 +330,6 @@ CTfLiteClass::CTfLiteClass()
CTfLiteClass::~CTfLiteClass()
{
delete this->interpreter;
// delete this->error_reporter;

psram_free_shared_tensor_arena_and_model_memory();
}

#ifdef SUPRESS_TFLITE_ERRORS
namespace tflite
{
//tflite::ErrorReporter
// int OwnMicroErrorReporter::Report(const char* format, va_list args)

int OwnMicroErrorReporter::Report(const char* format, va_list args)
{
return 0;
}
}
#endif
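
For context on the leak this file change addresses: LoadModel() allocated an error reporter with new on every call, while the matching delete in the destructor was commented out, so each model load left one reporter object behind. A condensed sketch of the pre-PR pattern, simplified from the removed lines above (surrounding logic omitted):

// Pre-PR pattern (simplified): the reporter is new-ed in LoadModel() ...
bool CTfLiteClass::LoadModel(std::string _fn)
{
    this->error_reporter = new tflite::MicroErrorReporter;   // heap allocation on every load
    // ... read the model file, etc. ...
    return true;
}

// ... but the matching delete was commented out, so it was never freed.
CTfLiteClass::~CTfLiteClass()
{
    delete this->interpreter;
    // delete this->error_reporter;
    psram_free_shared_tensor_arena_and_model_memory();
}

The PR drops the error_reporter member entirely and constructs the MicroInterpreter through the overload without a reporter argument, so there is no longer anything to free.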

28 changes: 3 additions & 25 deletions code/components/jomjol_tfliteclass/CTfLiteClass.h
@@ -5,39 +5,18 @@

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"

#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"

#include "esp_err.h"
#include "esp_log.h"

#include "CImageBasis.h"



#ifdef SUPRESS_TFLITE_ERRORS
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/debug_log.h"
///// OwnErrorReporter to prevent printing of errors (especially the unavoidable ones in CalculateActivationRangeQuantized@kernel_util.cc)
namespace tflite {
class OwnMicroErrorReporter : public ErrorReporter {
public:
int Report(const char* format, va_list args) override;
};
} // namespace tflite
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#endif


class CTfLiteClass
{
protected:
tflite::ErrorReporter *error_reporter;
tflite::MicroMutableOpResolver<10> resolver;
const tflite::Model* model;
tflite::MicroInterpreter* interpreter;
TfLiteTensor* output = nullptr;
@@ -54,6 +33,7 @@ class CTfLiteClass

long GetFileSize(std::string filename);
bool ReadFileToModel(std::string _fn);
void MakeStaticResolver();

public:
CTfLiteClass();
@@ -74,6 +54,4 @@ class CTfLiteClass
int ReadInputDimenstion(int _dim);
};

void MakeStaticResolver();

#endif //CTFLITECLASS_H
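
The header change mirrors the .cpp refactor: the ErrorReporter member and the SUPRESS_TFLITE_ERRORS-guarded OwnMicroErrorReporter class are dropped, and MakeStaticResolver() moves from a free function declared at the end of the header to a protected member of CTfLiteClass. A condensed before/after sketch of that move, using only the names visible in the diff:

// Before: free function with a file-scope one-shot guard (CTfLiteClass.cpp)
static bool MakeStaticResolverDone = false;

void MakeStaticResolver()
{
    if (MakeStaticResolverDone)
        return;
    MakeStaticResolverDone = true;

    resolver.AddFullyConnected();
    resolver.AddReshape();
    // ... remaining operators ...
}

// After: protected member, called from CTfLiteClass::MakeAllocate()
void CTfLiteClass::MakeStaticResolver()
{
    resolver.AddFullyConnected();
    resolver.AddReshape();
    // ... remaining operators ...
}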
9 changes: 0 additions & 9 deletions code/include/defines.h
@@ -167,15 +167,6 @@
#define LWT_DISCONNECTED "connection lost"


//CTfLiteClass
#define TFLITE_MINIMAL_CHECK(x) \
if (!(x)) { \
fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
exit(1); \
}
// #define SUPRESS_TFLITE_ERRORS // use, to avoid error messages from TFLITE


// connect_wlan.cpp
//******************************
/* WIFI roaming functionalities 802.11k+v (uses ca. 6kB - 8kB internal RAM; if SCAN CACHE activated: + 1kB / beacon)