diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h
index b942c06e9540a..ed712cf00e096 100644
--- a/include/onnxruntime/core/framework/op_kernel.h
+++ b/include/onnxruntime/core/framework/op_kernel.h
@@ -98,6 +98,17 @@ class OpKernel {
     return Status::OK();
   }
 
+  // Override this function to return a list of attributes the session can safely remove
+  // after it is initialized and saved. This option is useful to reduce memory usage
+  // when the kernel does not reuse the operator attributes but copies them.
+  // All attributes returned by this method will be removed by method
+  // PruneRemovableAttributes if they exist.
+  // @param removable_attributes set of attributes the session can safely remove.
+  virtual Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
+    removable_attributes.clear();
+    return Status::OK();
+  }
+
   // Override this function to use provided pre-packed weight.
   // Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers,
   //                                  int input_idx,
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 280f4e1a603b7..133014f9897ab 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -407,6 +407,12 @@ class Node {
   bool ClearAttribute(const std::string& attr_name);
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
+  /**
+   * Clears removable attributes. These are no longer needed after the initialization
+   * of the session. The function returns the number of removed attributes.
+   */
+  int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);
+
 #if !defined(ORT_MINIMAL_BUILD)
   /** Gets the Node's mutable attributes. */
   NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }
@@ -560,7 +566,7 @@ class Node {
   // NOTE: This friendship relationship should ONLY be used for calling methods of the Node class and not accessing
   // the data members directly, so that the Node can maintain its internal invariants.
   friend class Graph;
-  Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {}
+  Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph), can_be_saved_(true) {}
 
  private:
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);
@@ -651,6 +657,9 @@ class Node {
 
   // Graph instances for subgraphs that are owned by this Node
   std::vector> subgraphs_;
+
+  // Can be saved? The node cannot be saved anymore if removable attributes have been cleared.
+  bool can_be_saved_;
 };
 
 /**
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 66699c907fc5e..1f19fd86a37b0 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -245,6 +245,42 @@ Status SessionState::CreateKernels(const KernelRegistryManager& kernel_registry_
   return Status::OK();
 }
 
+void SessionState::PruneRemovableAttributes() {
+  InlinedVector<std::string> removable_attributes;
+  for (size_t i = 0; i < session_kernels_.size(); ++i) {
+    if (session_kernels_[i].get() == nullptr)
+      continue;
+    auto status = session_kernels_[i].get()->GetRemovableAttributes(removable_attributes);
+    if (!status.IsOK()) {
+      const Node& node_const = session_kernels_[i].get()->Node();
+      LOGS(logger_, WARNING) << "failed at retrieving the removable attributes "
+                             << "for node '" << node_const.Name() << "' ('" << node_const.OpType() << "').";
+      continue;
+    }
+    if (removable_attributes.empty())
+      continue;
+    auto index = session_kernels_[i].get()->Node().Index();
+    Node* node = graph_.GetNode(index);
+    // GetNode hands back a pointer: skip this kernel rather than dereference a null node.
+    if (node == nullptr)
+      continue;
+    int n_removed = node->PruneRemovableAttributes(removable_attributes);
+    if (n_removed == 0)
+      continue;
+    LOGS(logger_, INFO) << "removed " << n_removed << " removable attributes "
+                        << "for node '" << node->Name() << "' ('" << node->OpType() << "'), "
+                        << "among attributes: " << [&removable_attributes]() -> std::string {
+                             std::ostringstream os;
+                             for (auto it = removable_attributes.cbegin(); it != removable_attributes.cend(); ++it) {
+                               if (it != removable_attributes.cbegin())
+                                 os << ", ";
+                               os << *it;
+                             }
+                             return os.str();
+                           }() << ".";
+  }
+}
+
 const SequentialExecutionPlan* SessionState::GetExecutionPlan() const {
   if (!p_seq_exec_plan_.has_value()) {
     return nullptr;
@@ -1377,10 +1413,10 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringInit(GetExecutionPlan(), GetOrtValueNameIdxMap());
diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h
index 6028c19d3c3ba..07a26ffebf7a0 100644
--- a/onnxruntime/core/framework/session_state.h
+++ b/onnxruntime/core/framework/session_state.h
@@ -307,6 +307,10 @@ class SessionState {
     return parent_;
   }
 
+  // Clear all removable attributes if they exist.
+  // Logs the removable attributes of every node from which at least one attribute was removed.
+  void PruneRemovableAttributes();
+
   size_t GetNumberOfPrepacksCounter() const {
     return number_of_prepacks_counter_;
   }
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index c47569b5ca5c2..af2ff37c4cf0a 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -614,6 +614,11 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
   if (!description_.empty())
     proto.set_doc_string(description_);
 
+  // Checks an attribute was not removed.
+  if (!can_be_saved_) {
+    ORT_THROW("Removable attributes were removed before the conversion is started.");
+  }
+
   // Set attributes.
   proto.clear_attribute();
   for (const auto& attribute : attributes_) {
@@ -666,6 +671,11 @@ Status Node::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
   auto input_arg_counts = builder.CreateVector(definitions_.input_arg_count);
   auto implicit_inputs = GetNodeArgsOrtFormat(definitions_.implicit_input_defs);
 
+  // Checks an attribute was not removed.
+  if (!can_be_saved_) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Removable attributes were removed before the node is saved.");
+  }
+
   // Node attributes
   std::vector> attributes_vec;
   attributes_vec.reserve(attributes_.size());
@@ -835,6 +845,7 @@ void Node::Init(const std::string& name,
   definitions_.input_defs = input_args;
   definitions_.output_defs = output_args;
   domain_ = domain;
+  can_be_saved_ = true;
   priority_ = 0;
   if (kOnnxDomainAlias == domain_) {
     domain_ = kOnnxDomain;
@@ -947,6 +958,17 @@ bool Node::ClearAttribute(const std::string& attr_name) {
 }
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
+int Node::PruneRemovableAttributes(gsl::span<const std::string> removable_attributes) {
+  graph_->SetGraphResolveNeeded();
+  graph_->SetGraphProtoSyncNeeded();
+  int n_removed = 0;
+  for (const auto& name : removable_attributes) {
+    n_removed += static_cast<int>(attributes_.erase(name));
+  }
+  can_be_saved_ = can_be_saved_ && n_removed == 0;
+  return n_removed;
+}
+
 #if !defined(ORT_MINIMAL_BUILD)
 Status Node::UpdateInputArgCount() {
   // The node refers to a primitive operator.
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc
index addbd81244eb7..8e5444a64bd5a 100644
--- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc
@@ -41,6 +41,22 @@ TreeEnsembleClassifier::TreeEnsembleClassifier(const OpKernelInfo& info) : Op
   ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
 }
 
+// Reports the names of the node attributes the session may remove once it is initialized.
+template <typename T>
+Status TreeEnsembleClassifier<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
+  InlinedVector<std::string> names{"base_values", "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates",
+                                   "nodes_missing_value_tracks_true", "nodes_modes", "nodes_nodeids", "nodes_treeids",
+                                   "nodes_truenodeids", "nodes_values", "class_ids", "class_treeids", "class_nodeids",
+                                   "class_weights", "classlabels_strings", "classlabels_int64s",  // comma needed: list continues below
+#if !defined(ORT_MINIMAL_BUILD)
+                                   "base_values_as_tensor", "nodes_hitrates_as_tensor", "nodes_values_as_tensor",
+                                   "class_weights_as_tensor"
+#endif
+  };
+  removable_attributes.swap(names);
+  return Status::OK();
+}
+
 template
 common::Status TreeEnsembleClassifier::Compute(OpKernelContext* context) const {
   const Tensor& X = *context->Input(0);
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.h
index 9efa0339261f0..13f2b24408f1d 100644
--- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.h
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.h
@@ -13,6 +13,7 @@ class TreeEnsembleClassifier final : public OpKernel {
  public:
   explicit TreeEnsembleClassifier(const OpKernelInfo& info);
   common::Status Compute(OpKernelContext* context) const override;
+  Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;
 
  private:
   // Following pointer holds a pointer on one instance of
diff --git a/onnxruntime/core/providers/cpu/ml/treeregressor.cc b/onnxruntime/core/providers/cpu/ml/treeregressor.cc
index 7ea0d70ef8121..fb4bcdc80b86a 100644
--- a/onnxruntime/core/providers/cpu/ml/treeregressor.cc
+++ b/onnxruntime/core/providers/cpu/ml/treeregressor.cc
@@ -47,6 +47,21 @@ TreeEnsembleRegressor::TreeEnsembleRegressor(const OpKernelInfo& info) : OpKe
   ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
 }
 
+// Reports the names of the node attributes the session may remove once it is initialized.
+template <typename T>
+Status TreeEnsembleRegressor<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
+  InlinedVector<std::string> names{"base_values", "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates",
+                                   "nodes_missing_value_tracks_true", "nodes_modes", "nodes_nodeids", "nodes_treeids",
+                                   "nodes_truenodeids", "nodes_values",
+                                   "target_ids", "target_treeids", "target_nodeids", "target_weights",  // comma needed: list continues below
+#if !defined(ORT_MINIMAL_BUILD)
+                                   "base_values_as_tensor", "nodes_hitrates_as_tensor", "nodes_values_as_tensor",
+                                   "target_weights_as_tensor"
+#endif
+  };
+  removable_attributes.swap(names);
+  return Status::OK();
+}
 
 template
 common::Status TreeEnsembleRegressor::Compute(OpKernelContext* context) const {
diff --git a/onnxruntime/core/providers/cpu/ml/treeregressor.h b/onnxruntime/core/providers/cpu/ml/treeregressor.h
index 3547089f8baab..ebc831fbe94ec 100644
--- a/onnxruntime/core/providers/cpu/ml/treeregressor.h
+++ b/onnxruntime/core/providers/cpu/ml/treeregressor.h
@@ -13,6 +13,7 @@ class TreeEnsembleRegressor final : public OpKernel {
  public:
   explicit TreeEnsembleRegressor(const OpKernelInfo& info);
   common::Status Compute(OpKernelContext* context) const override;
+  Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;
 
  private:
   // Following pointer holds a pointer on one instance of
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 11995da5be1dc..e07e6a2d6ba0f 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1522,12 +1522,16 @@ common::Status InferenceSession::Initialize() {
       std::vector().swap(ort_format_model_bytes_data_holder_);
     }
 
+    // once the model is saved, we may remove unnecessary attributes for inference
+    session_state_->PruneRemovableAttributes();
+
     // and log telemetry
     bool model_has_fp16_inputs = ModelHasFP16Inputs(graph);
     env.GetTelemetryProvider().LogSessionCreation(
         session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(),
         model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(), model_->MetaData(),
         telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs);
+
     LOGS(*session_logger_, INFO) << "Session successfully initialized.";
   }
   ORT_CATCH(const NotImplementedException& ex) {