Commit d7b867c
disable caching in runtime.
wejoncy committed Dec 19, 2024
1 parent e49112c commit d7b867c
Showing 4 changed files with 21 additions and 7 deletions.
@@ -66,12 +66,13 @@ static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU
// If this path is not specified, the model will be saved to a temp directory and deleted after the session is closed.
// Otherwise, the model will be saved to the specified path and the user is responsible for deleting it.

// we do NOT detect if the onnx model has changed and no longer matches the cached model.
// the user should carefully manage the cache if modifying/replacing a model.
// The cache key is generated by
// 1. User provided key in metadata_props if found (preferred)
// 2. Hash of the model url the inference session was created with
// 3. Hash of the input/output names of the model
// See the onnxruntime documentation for how to set metadata_props: https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html#configuration-options
static const char* const kCoremlProviderOption_ModelCacheDirectory = "ModelCacheDirectory";

// User-provided cache key in metadata_props.
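For reference, a minimal sketch of how a user could opt into on-disk caching through the provider options. "ModelCacheDirectory" is the option string defined above; the session-options call is the standard ORT C++ API, while the cache path and model file name are illustrative:

#include <onnxruntime_cxx_api.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // Key matches kCoremlProviderOption_ModelCacheDirectory ("ModelCacheDirectory").
  std::unordered_map<std::string, std::string> coreml_options{
      {"ModelCacheDirectory", "/tmp/coreml_cache"}};  // illustrative cache path
  session_options.AppendExecutionProvider("CoreML", coreml_options);
  // Illustrative model file; the EP reuses the cached CoreML model on later runs.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}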
17 changes: 13 additions & 4 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Expand Up @@ -391,7 +391,8 @@ void CreateEmptyFile(const std::string& filename) {
#endif // defined(COREML_ENABLE_MLPROGRAM)

std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
const GraphViewer& graph_viewer) {
const GraphViewer& graph_viewer,
const logging::Logger& logger) {
const std::string& subgraph_name = graph_viewer.Name();
std::string path;
if (coreml_options.ModelCacheDirectory().empty()) {
@@ -411,7 +412,11 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
std::string_view subgraph_short_name = std::string_view(subgraph_name)
.substr(subgraph_name.find_last_of("_") + 1);
path = MakeString(std::string(coreml_options.ModelCacheDirectory()), "/", cache_key);
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
if (!Env::Default().CreateFolder(path).IsOK()) {
LOGS(logger, WARNING) << "Failed to create cache directory " << path << ". Model caching is disabled.";
coreml_options.DisableModelCache();
return GetModelOutputPath(coreml_options, graph_viewer, logger);
}
// Write the model path to a file in the cache directory.
// This lets developers identify which model a cache directory holds, since a hash is used for the directory name.
if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) {
@@ -438,7 +443,11 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
} else {
path += "_nn";
}
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
if (!Env::Default().CreateFolder(path).IsOK()) {
LOGS(logger, WARNING) << "Failed to create cache directory " << path << ". Model caching is disabled.";
coreml_options.DisableModelCache();
return GetModelOutputPath(coreml_options, graph_viewer, logger);
}
path += "/model";
}
return path;
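The recovery pattern used in both branches above (warn, clear the cache directory via DisableModelCache(), and recompute the path so it falls through to the temp-directory branch) can be sketched in isolation. The types and names below are illustrative stand-ins, not the ORT implementation:

#include <filesystem>
#include <iostream>
#include <string>

struct CacheOptions {
  // mutable so a const reference can still disable caching, mirroring CoreMLOptions.
  mutable std::string cache_dir;
  void DisableCache() const { cache_dir.clear(); }
};

std::string ResolveOutputPath(const CacheOptions& opts, const std::string& name) {
  if (opts.cache_dir.empty()) {
    return "/tmp/" + name;  // caching disabled: fall back to a temp location (illustrative)
  }
  std::error_code ec;
  std::filesystem::create_directories(opts.cache_dir + "/" + name, ec);
  if (ec) {
    std::cerr << "Failed to create cache directory " << opts.cache_dir
              << ". Model caching is disabled.\n";
    opts.DisableCache();
    return ResolveOutputPath(opts, name);  // recurse once; cache_dir is now empty
  }
  return opts.cache_dir + "/" + name;
}

Clearing the directory before recursing guarantees the second call takes the empty-directory branch, so the retry is bounded at one level.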
@@ -454,7 +463,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
coreml_version_(coreml_version),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(coreml_options, graph_viewer)),
model_output_path_(GetModelOutputPath(coreml_options_, graph_viewer, logger)), // coreml_options_ must be set before this
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
@@ -68,7 +68,8 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
if (user_provided_key.size() > 64 ||
std::any_of(user_provided_key.begin(), user_provided_key.end(),
[](unsigned char c) { return !std::isalnum(c); })) {
user_provided_key = std::to_string(std::hash<std::string>{}(user_provided_key));
LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key."
<< " It should be alphanumeric and no more than 64 characters.";
}
// invalid cache-key
if (user_provided_key.size() == 0) {
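A self-contained sketch of the validation rule enforced above (a user-provided key is accepted only if it is alphanumeric and at most 64 characters); the function name is illustrative:

#include <algorithm>
#include <cctype>
#include <string>

bool IsValidCacheKey(const std::string& key) {
  return !key.empty() && key.size() <= 64 &&
         std::all_of(key.begin(), key.end(),
                     [](unsigned char c) { return std::isalnum(c) != 0; });
}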
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/coreml/coreml_options.h
@@ -18,7 +18,8 @@ class CoreMLOptions {
bool profile_compute_plan_{false};
bool allow_low_precision_accumulation_on_gpu_{false};
// path to store the converted coreml model
std::string model_cache_directory_;
// mutable: DisableModelCache() may clear this at runtime to disable model caching
mutable std::string model_cache_directory_;

Check warning on line 22 in onnxruntime/core/providers/coreml/coreml_options.h (GitHub Actions / Optional Lint C++): [cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4]

public:
explicit CoreMLOptions(uint32_t coreml_flags);
Expand All @@ -35,6 +36,8 @@ class CoreMLOptions {
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }

std::string_view ModelCacheDirectory() const { return model_cache_directory_; }
// Marked const: model_cache_directory_ is mutable, so it can be cleared even from const member functions.
void DisableModelCache() const { model_cache_directory_.clear(); }

private:
void ValidateAndParseProviderOption(const ProviderOptions& options);
