Skip to content

Commit

Permalink
[VitisAI] update graph_save (microsoft#20979)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Fix the threshold limit


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
We use this to debug the graph after each graph transformation. But, the
const data inside the model is useless.
So, we decided to remove that information to save disk space.

Co-authored-by: Chunye Wang <[email protected]>
  • Loading branch information
BoarQing and Chunye Wang authored Jun 17, 2024
1 parent 11e7a1b commit 0babc33
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions onnxruntime/core/providers/vitisai/imp/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,18 @@ void graph_remove_node(Graph& graph, const NodeInput& node_input) {
}

void graph_save(const Graph& graph, const std::string& filename, const std::string& filename_dat, size_t initializer_size_threshold) {
auto& model = const_cast<Model&>(graph.GetModel());
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto;

auto model_proto = const_cast<onnxruntime::Model&>(graph.GetModel()).ToProto();
auto graph_proto_subgraph = graph.ToGraphProto();
*model_proto->mutable_graph() = *graph_proto_subgraph;
auto& logger = logging::LoggingManager::DefaultLogger();
auto filename_data_relative_path = std::filesystem::path();
auto model = Model::Create(std::move(*model_proto), ToPathString(filename), nullptr, logger);
if (initializer_size_threshold == std::numeric_limits<size_t>::max()) {
model_proto = model.ToProto();
model_proto = model->ToProto();
} else {
model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, graph.ModelPath().ToPathString(), initializer_size_threshold);
model_proto = model->ToGraphProtoWithExternalInitializers(filename_dat, ToPathString(filename), initializer_size_threshold);
}
auto& metadata = model.MetaData();
auto& metadata = model->MetaData();
if (!metadata.empty()) {
auto metadata_props = model_proto->mutable_metadata_props();
metadata_props->Clear();
Expand All @@ -121,18 +124,6 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
*prop->mutable_value() = m.second;
}
}
// use relative path as data storage.
auto graph_proto = model_proto->mutable_graph();
*graph_proto = *graph.ToGraphProto();
for (int i = 0; i < graph_proto->mutable_initializer()->size(); i++) {
auto mutable_external_data = graph_proto->mutable_initializer()->at(i).mutable_external_data();
for (int j = 0; j < mutable_external_data->size(); j++) {
auto& external_data = mutable_external_data->at(j);
if (*external_data.mutable_key() == "location")
*external_data.mutable_value() = std::filesystem::path(*external_data.mutable_value()).filename().u8string();
}
}

std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary);
bool result = model_proto->SerializeToOstream(output);
output << std::flush;
Expand Down

0 comments on commit 0babc33

Please sign in to comment.