From 45b7c41ef04fa91cc41a1726d3ee4c0bb9aabaa8 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Fri, 2 Aug 2024 19:19:04 -0400 Subject: [PATCH] [MIGraphX EP] Set External Data Path (#21598) ### Description Changes to add in Set external data path for model weight files. Additional fixes to ensure this compiles off the latest v1.19 Onnxruntime ### Motivation and Context Separate weights used for larger models (like stable diffusion) is motivation for this change set --------- Co-authored-by: Jeff Daily Co-authored-by: Artur Wojcik Co-authored-by: Ted Themistokleous --- .../core/providers/migraphx/migraphx_execution_provider.cc | 7 +++++-- .../core/providers/migraphx/migraphx_execution_provider.h | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 097b16ecde536..314e278695c49 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -990,6 +991,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); + model_path_ = graph_viewer.ModelPath(); // dump onnx file if environment var is set if (dump_model_ops_) { @@ -1168,7 +1170,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto param_shapes = prog.get_parameter_shapes(); // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map) { + for (auto& [cal_key, cal_val] : dynamic_range_map_) { auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } @@ -1217,7 +1219,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map, + int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; *state = p.release(); @@ -1297,6 +1299,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index f34ca320d0a5a..21b582de8f86e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -11,6 +11,7 @@ #include #include +#include namespace onnxruntime { @@ -91,7 +92,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool int8_calibration_cache_available_ = false; bool int8_use_native_migraphx_calibration_table_ = false; std::string calibration_cache_path_; - std::unordered_map dynamic_range_map; + std::unordered_map dynamic_range_map_; bool save_compiled_model_ = false; std::string save_compiled_path_; bool load_compiled_model_ = false; @@ -100,6 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; + mutable std::filesystem::path model_path_; std::unordered_map map_progs_; std::unordered_map map_onnx_string_;