From d7b867cd233e430490d0029757e300251c978bc4 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Thu, 19 Dec 2024 15:26:27 +0800
Subject: [PATCH] disable caching in runtime.

---
 .../providers/coreml/coreml_provider_factory.h  |  3 ++-
 .../providers/coreml/builders/model_builder.cc  | 17 +++++++++++++----
 .../coreml/coreml_execution_provider.cc         |  3 ++-
 .../core/providers/coreml/coreml_options.h      |  5 ++++-
 4 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
index eedf76b5542b6..351eafc2a4675 100644
--- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -66,12 +66,13 @@ static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGP
 // If this path is not specified, the model will be saved to a temp directory and deleted after the session is closed.
 // otherwise, the model will be saved to the specified path and User should manage to delete the model.
-// we do NOT detect if the onnx model has changed and no longer matches the cached model. 
+// we do NOT detect if the onnx model has changed and no longer matches the cached model.
 // the user should carefully manage the cache if modifying/replacing a model.
 // The cache key is generated by
 // 1. User provided key in metadata_props if found (preferred)
 // 2. Hash of the model url the inference session was created with
 // 3. Hash of the input/output names of the model
+// Please find out how to set metadata_props in the onnxruntime API documentation. https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html#configuration-options
 static const char* const kCoremlProviderOption_ModelCacheDirectory = "ModelCacheDirectory";
 
 // User provided cache-key in metadata_props.
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index 952529a8a918d..fbd591a0e809a 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -391,7 +391,8 @@ void CreateEmptyFile(const std::string& filename) {
 #endif  // defined(COREML_ENABLE_MLPROGRAM)
 
 std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
-                               const GraphViewer& graph_viewer) {
+                               const GraphViewer& graph_viewer,
+                               const logging::Logger& logger) {
   const std::string& subgraph_name = graph_viewer.Name();
   std::string path;
   if (coreml_options.ModelCacheDirectory().empty()) {
@@ -411,7 +412,11 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
     std::string_view subgraph_short_name = std::string_view(subgraph_name)
                                                .substr(subgraph_name.find_last_of("_") + 1);
     path = MakeString(std::string(coreml_options.ModelCacheDirectory()), "/", cache_key);
-    ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
+    if (!Env::Default().CreateFolder(path).IsOK()) {
+      LOGS(logger, WARNING) << "Failed to create cache directory " << path << ". Model caching is disabled.";
+      coreml_options.DisableModelCache();
+      return GetModelOutputPath(coreml_options, graph_viewer, logger);
+    }
     // Write the model path to a file in the cache directory.
     // This is for developers to know what the cached model is as we used a hash for the directory name.
     if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) {
@@ -438,7 +443,11 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
     } else {
       path += "_nn";
     }
-    ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
+    if (!Env::Default().CreateFolder(path).IsOK()) {
+      LOGS(logger, WARNING) << "Failed to create cache directory " << path << ". Model caching is disabled.";
+      coreml_options.DisableModelCache();
+      return GetModelOutputPath(coreml_options, graph_viewer, logger);
+    }
     path += "/model";
   }
   return path;
 }
@@ -454,7 +463,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
       coreml_version_(coreml_version),
       coreml_options_(coreml_options),
       create_ml_program_(coreml_options.CreateMLProgram()),
-      model_output_path_(GetModelOutputPath(coreml_options, graph_viewer)),
+      model_output_path_(GetModelOutputPath(coreml_options_, graph_viewer, logger)),  // coreml_options_ must be set before this
       onnx_input_names_(std::move(onnx_input_names)),
       onnx_output_names_(std::move(onnx_output_names)),
       coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
index 8367bc6dbbb62..096853b9b07ee 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
@@ -68,7 +68,8 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
     if (user_provided_key.size() > 64 ||
         std::any_of(user_provided_key.begin(), user_provided_key.end(),
                     [](unsigned char c) { return !std::isalnum(c); })) {
-      user_provided_key = std::to_string(std::hash<std::string>{}(user_provided_key));
+      LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key."
+                          << " It should be alphanumeric and less than 64 characters.";
     }
     // invalid cache-key
     if (user_provided_key.size() == 0) {
diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h
index 87b6a51f5d9fb..b7e584c36fda5 100644
--- a/onnxruntime/core/providers/coreml/coreml_options.h
+++ b/onnxruntime/core/providers/coreml/coreml_options.h
@@ -18,7 +18,8 @@ class CoreMLOptions {
   bool profile_compute_plan_{false};
   bool allow_low_precision_accumulation_on_gpu_{false};
   // path to store the converted coreml model
-  std::string model_cache_directory_;
+  // we may run DisableModelCache() to disable model caching
+  mutable std::string model_cache_directory_;
 
  public:
   explicit CoreMLOptions(uint32_t coreml_flags);
@@ -35,6 +36,8 @@ class CoreMLOptions {
   bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }
 
   std::string_view ModelCacheDirectory() const { return model_cache_directory_; }
+  // mark const as model_cache_directory_ is mutable and we may update it in const functions.
+  void DisableModelCache() const { model_cache_directory_.clear(); }
 
  private:
   void ValidateAndParseProviderOption(const ProviderOptions& options);