Skip to content

Commit

Permalink
Fix MicroProfiler bug with ClearEvents().
Browse files Browse the repository at this point in the history
Add pre-inference profiling to the Generic Benchmark.
  • Loading branch information
ddavis-2015 committed Nov 2, 2024
1 parent 4dca8e7 commit 0d889e0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
19 changes: 14 additions & 5 deletions tensorflow/lite/micro/micro_profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ void MicroProfiler::LogTicksPerTagCsv() {
TFLITE_DCHECK(tags_[i] != nullptr);
int position = FindExistingOrNextPosition(tags_[i]);
TFLITE_DCHECK(position >= 0);
total_ticks_per_tag[position].tag = tags_[i];
total_ticks_per_tag[position].ticks =
total_ticks_per_tag[position].ticks + ticks;
total_ticks_per_tag_[position].tag = tags_[i];
total_ticks_per_tag_[position].ticks =
total_ticks_per_tag_[position].ticks + ticks;
total_ticks += ticks;
}

for (int i = 0; i < num_events_; ++i) {
TicksPerTag each_tag_entry = total_ticks_per_tag[i];
TicksPerTag each_tag_entry = total_ticks_per_tag_[i];
if (each_tag_entry.tag == nullptr) {
break;
}
Expand All @@ -112,12 +112,21 @@ void MicroProfiler::LogTicksPerTagCsv() {
int MicroProfiler::FindExistingOrNextPosition(const char* tag_name) {
int pos = 0;
for (; pos < num_events_; pos++) {
TicksPerTag each_tag_entry = total_ticks_per_tag[pos];
TicksPerTag each_tag_entry = total_ticks_per_tag_[pos];
if (each_tag_entry.tag == nullptr ||
strcmp(each_tag_entry.tag, tag_name) == 0) {
return pos;
}
}
return pos < num_events_ ? pos : -1;
}

void MicroProfiler::ClearEvents() {
for (int i = 0; i < num_events_; i++) {
total_ticks_per_tag_[i].tag = nullptr;
}

num_events_ = 0;
}

} // namespace tflite
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/micro_profiler.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -45,7 +45,7 @@ class MicroProfiler : public MicroProfilerInterface {
virtual void EndEvent(uint32_t event_handle) override;

// Clears all the events that have been currently profiled.
void ClearEvents() { num_events_ = 0; }
void ClearEvents();

// Returns the sum of the ticks taken across all the events. This number
// is only meaningful if all of the events are disjoint (the end time of
Expand Down Expand Up @@ -83,7 +83,7 @@ class MicroProfiler : public MicroProfilerInterface {
// In practice, the number of tags will be much lower than the number of
// events. But it is theoretically possible that each event to be unique and
// hence we allow total_ticks_per_tag to have kMaxEvents entries.
TicksPerTag total_ticks_per_tag[kMaxEvents] = {};
TicksPerTag total_ticks_per_tag_[kMaxEvents] = {};

int FindExistingOrNextPosition(const char* tag_name);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,35 @@ int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) {
constexpr bool using_compression = false;
#endif // USE_TFLM_COMPRESSION

[[maybe_unused]]
alignas(16) static uint8_t tensor_arena[kTensorArenaSize];

uint32_t event_handle = profiler.BeginEvent("TfliteGetModel");
uint32_t event_handle = profiler.BeginEvent("tflite::GetModel");
const tflite::Model* model = tflite::GetModel(model_data);
profiler.EndEvent(event_handle);

event_handle = profiler.BeginEvent("tflite::CreateOpResolver");
TflmOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(CreateOpResolver(op_resolver));
profiler.EndEvent(event_handle);

event_handle = profiler.BeginEvent("tflite::RecordingMicroAllocator::Create");
tflite::RecordingMicroAllocator* allocator(
tflite::RecordingMicroAllocator::Create(tensor_arena, kTensorArenaSize));
tflite::RecordingMicroAllocator::Create(
reinterpret_cast<uint8_t*>(0xe1000000), kTensorArenaSize));
profiler.EndEvent(event_handle);
event_handle = profiler.BeginEvent("tflite::MicroInterpreter instantiation");
tflite::RecordingMicroInterpreter interpreter(
model, op_resolver, allocator,
tflite::MicroResourceVariables::Create(allocator, kNumResourceVariable),
&profiler);
profiler.EndEvent(event_handle);
event_handle =
profiler.BeginEvent("tflite::MicroInterpreter::AllocateTensors");
TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
profiler.EndEvent(event_handle);

profiler.Log();
profiler.LogTicksPerTagCsv();
profiler.ClearEvents();

if (using_compression) {
Expand Down

0 comments on commit 0d889e0

Please sign in to comment.