Skip to content

Commit

Permalink
Delegate memory-mapping the model file to the resource system
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707080341
  • Loading branch information
MediaPipe Team authored and copybara-github committed Dec 17, 2024
1 parent bd154ba commit d12c810
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 49 deletions.
3 changes: 0 additions & 3 deletions mediapipe/util/tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,12 @@ cc_library_with_tflite(
],
visibility = ["//visibility:public"],
deps = [
":error_reporter",
"//mediapipe/framework:resources",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:resource_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
Expand Down
55 changes: 10 additions & 45 deletions mediapipe/util/tflite/tflite_model_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,75 +15,40 @@
#include "mediapipe/util/tflite/tflite_model_loader.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/resources.h"
#include "mediapipe/util/resource_util.h"
#include "mediapipe/util/tflite/error_reporter.h"
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/model_builder.h"

namespace mediapipe {

using ::mediapipe::util::tflite::ErrorReporter;
using ::tflite::Allocation;
using ::tflite::FlatBufferModel;
using ::tflite::MMAPAllocation;

absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
const Resources& resources, const std::string& path, bool try_mmap) {
std::string model_path = path;

bool file_exists = file::Exists(model_path).ok();
if (!file_exists) {
// TODO: get rid of manual resolving with PathToResourceAsFile
// as soon as it's incorporated into GetResourceContents.
absl::StatusOr<std::string> resolved_model_path =
mediapipe::PathToResourceAsFile(model_path);
if (resolved_model_path.ok()) {
VLOG(2) << "Loading the model from " << model_path;
model_path = *std::move(resolved_model_path);
file_exists = true;
}
}

// Try to memory map file if available. Falls back to loading from buffer on
// error.
if (file_exists && try_mmap && MMAPAllocation::IsSupported()) {
ErrorReporter error_reporter;
std::unique_ptr<Allocation> allocation =
std::make_unique<MMAPAllocation>(model_path.c_str(), &error_reporter);

if (!error_reporter.HasError()) {
auto model = FlatBufferModel::BuildFromAllocation(std::move(allocation));
if (model) {
return api2::MakePacket<TfLiteModelPtr>(
model.release(), [](FlatBufferModel* model) { delete model; });
}
}

ABSL_LOG(WARNING) << "Failed to memory map model from path '" << model_path
<< "'; falling back to loading from buffer. Error: "
<< error_reporter.message();
}

// Load model resource.
MP_ASSIGN_OR_RETURN(std::unique_ptr<Resource> model_resource,
resources.Get(model_path));
MP_ASSIGN_OR_RETURN(
std::unique_ptr<Resource> model_resource,
resources.Get(
model_path,
Resources::Options{
.mmap_mode = try_mmap ? std::make_optional(MMapMode::kMMapOrRead)
: std::nullopt}));
absl::string_view model_view = model_resource->ToStringView();
auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_view.data(),
model_view.size());

RET_CHECK(model) << "Failed to load model from path " << model_path;
RET_CHECK(model) << "Failed to load model from path (resource ID) "
<< model_path;
return api2::MakePacket<TfLiteModelPtr>(
model.release(), [model_resource = model_resource.release()](
FlatBufferModel* model) mutable {
Expand Down
1 change: 0 additions & 1 deletion mediapipe/util/tflite/tflite_model_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <memory>
#include <string>

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/packet.h"
Expand Down

0 comments on commit d12c810

Please sign in to comment.