Skip to content

Commit

Permalink
Introduce RemovableAttributes (microsoft#14868)
Browse files Browse the repository at this point in the history
### Description
TreeEnsemble* kernels fully copy all the parameters from the onnx
graph. Even if they are no longer needed or unused (hitrates), they
remain in memory. For big models (>= 200 trees, max_depth > 10), the model
usually weighs more than 10 MB. This change offers a kernel the
possibility to remove all unneeded attributes after they were used to
create the session. Attributes are deleted after the model was possibly
saved, at the end of the session creation.

The current design is open to debate:
* it stores the list of removable attributes in class
`onnxruntime::Node`,
* the node is marked as `const` every time this implementation needs to
register the name of a removable attribute or to remove them.

The current implementation is just a POC as it needs to cast
`onnxruntime::Node*` into `const onnxruntime::Node*`.

Should we keep the list of removable attributes in `onnxruntime::Node`?

### Motivation and Context
Motivation is mostly to reduce memory consumption.

---------

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Mar 7, 2023
1 parent be1416d commit 5930e7e
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 4 deletions.
11 changes: 11 additions & 0 deletions include/onnxruntime/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ class OpKernel {
return Status::OK();
}

// Override this function to return the list of attributes the session can safely remove
// after it is initialized and saved. This option is useful to reduce memory usage
// when the kernel does not reuse the operator attributes but copies them.
// Every attribute returned by this method is removed from the node by
// PruneRemovableAttributes if it exists.
// The default implementation reports no removable attribute: it empties the output list
// and returns an OK status.
// @param removable_attributes output list of the attribute names the session can safely remove.
// @return Status::OK() in the default implementation.
virtual Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
removable_attributes.clear();
return Status::OK();
}

// Override this function to use provided pre-packed weight.
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
// int input_idx,
Expand Down
11 changes: 10 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,12 @@ class Node {
bool ClearAttribute(const std::string& attr_name);
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

/**
 * Clears removable attributes. These are no longer needed after the initialization
 * of the session.
 * Once at least one attribute has been removed, the node can no longer be saved.
 * @param removable_attributes names of the attributes to erase; names the node does not have are ignored.
 * @return the number of attributes that were actually removed.
 */
int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);

#if !defined(ORT_MINIMAL_BUILD)
/** Gets the Node's mutable attributes. */
NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }
Expand Down Expand Up @@ -560,7 +566,7 @@ class Node {
// NOTE: This friendship relationship should ONLY be used for calling methods of the Node class and not accessing
// the data members directly, so that the Node can maintain its internal invariants.
friend class Graph;
Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {}
Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph), can_be_saved_(true) {}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);
Expand Down Expand Up @@ -651,6 +657,9 @@ class Node {

// Graph instances for subgraphs that are owned by this Node
std::vector<std::unique_ptr<Graph>> subgraphs_;

// True while the node still holds all of its attributes and can therefore be saved.
// It becomes false once removable attributes have been cleared from the node.
bool can_be_saved_;
};

/**
Expand Down
39 changes: 36 additions & 3 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,39 @@ Status SessionState::CreateKernels(const KernelRegistryManager& kernel_registry_
return Status::OK();
}

// Asks every kernel of the session which attributes of its node are no longer needed
// and erases them from the corresponding graph node to reduce memory usage.
// A kernel that fails to report its removable attributes is skipped with a warning;
// the other kernels are still processed. For every node that lost at least one
// attribute, the number of removed attributes and the candidate list are logged.
void SessionState::PruneRemovableAttributes() {
  InlinedVector<std::string> removable_attributes;
  for (const auto& kernel_holder : session_kernels_) {
    const auto* kernel = kernel_holder.get();
    if (kernel == nullptr)
      continue;

    // Start from an empty list so that a kernel which only appends to the list
    // cannot inherit the attribute names reported for a previous node.
    removable_attributes.clear();
    auto status = kernel->GetRemovableAttributes(removable_attributes);
    if (!status.IsOK()) {
      const Node& node_const = kernel->Node();
      LOGS(logger_, WARNING) << "failed at retrieving the removable attributes "
                             << "for node '" << node_const.Name() << "' ('" << node_const.OpType() << "').";
      continue;
    }
    if (removable_attributes.empty())
      continue;

    // The kernel only exposes a const view of its node: fetch the mutable node from the graph.
    auto index = kernel->Node().Index();
    Node* node = graph_.GetNode(index);
    // NOTE(review): the lookup is assumed to return nullptr when the graph no longer
    // holds a node at this index; skip instead of dereferencing a null pointer.
    if (node == nullptr)
      continue;

    int n_removed = node->PruneRemovableAttributes(removable_attributes);
    if (n_removed == 0)
      continue;

    // The list is captured by reference: the lambda runs immediately, no copy is needed.
    LOGS(logger_, INFO) << "removed " << n_removed << " removable attributes "
                        << "for node '" << node->Name() << "' ('" << node->OpType() << "'), "
                        << "among attributes: " << [&removable_attributes]() -> std::string {
                             std::ostringstream os;
                             bool first = true;
                             for (const auto& name : removable_attributes) {
                               if (!first)
                                 os << ", ";
                               os << name;
                               first = false;
                             }
                             return os.str();
                           }() << ".";
  }
}

const SequentialExecutionPlan* SessionState::GetExecutionPlan() const {
if (!p_seq_exec_plan_.has_value()) {
return nullptr;
Expand Down Expand Up @@ -1377,10 +1410,10 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
p_seq_exec_plan_);
ORT_RETURN_IF_ERROR(status);

// Record the allocation plan
// Record the allocation plan

// Uncomment the below to dump the allocation plan to std::cout
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
// Uncomment the below to dump the allocation plan to std::cout
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);

#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
GetMemoryProfiler()->Init(GetExecutionPlan(), GetOrtValueNameIdxMap());
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ class SessionState {
return parent_;
}

// Clears all removable attributes reported by the kernels, if they exist.
// The function logs the list of removable attributes for every node
// from which at least one attribute was removed.
void PruneRemovableAttributes();

size_t GetNumberOfPrepacksCounter() const {
return number_of_prepacks_counter_;
}
Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
if (!description_.empty())
proto.set_doc_string(description_);

// Checks an attribute was not removed.
if (!can_be_saved_) {
ORT_THROW("Removable attributes were removed before the conversion is started.");
}

// Set attributes.
proto.clear_attribute();
for (const auto& attribute : attributes_) {
Expand Down Expand Up @@ -666,6 +671,11 @@ Status Node::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
auto input_arg_counts = builder.CreateVector(definitions_.input_arg_count);
auto implicit_inputs = GetNodeArgsOrtFormat(definitions_.implicit_input_defs);

// Checks an attribute was not removed.
if (!can_be_saved_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Removable attributes were removed before the node is saved.");
}

// Node attributes
std::vector<flatbuffers::Offset<fbs::Attribute>> attributes_vec;
attributes_vec.reserve(attributes_.size());
Expand Down Expand Up @@ -835,6 +845,7 @@ void Node::Init(const std::string& name,
definitions_.input_defs = input_args;
definitions_.output_defs = output_args;
domain_ = domain;
can_be_saved_ = true;
priority_ = 0;
if (kOnnxDomainAlias == domain_) {
domain_ = kOnnxDomain;
Expand Down Expand Up @@ -947,6 +958,17 @@ bool Node::ClearAttribute(const std::string& attr_name) {
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// Erases the given attributes from this node and returns how many were actually present.
// The owning graph is flagged as needing a resolve and a proto sync before anything is erased.
// As soon as one attribute has been erased, the node is marked as no longer savable.
int Node::PruneRemovableAttributes(gsl::span<const std::string> removable_attributes) {
  graph_->SetGraphResolveNeeded();
  graph_->SetGraphProtoSyncNeeded();

  int erased_count = 0;
  for (auto it = removable_attributes.begin(); it != removable_attributes.end(); ++it) {
    // erase() reports how many entries matched the key; unknown names contribute nothing.
    const auto erased_here = attributes_.erase(*it);
    erased_count += static_cast<int>(erased_here);
  }

  if (erased_count != 0) {
    // The node lost part of its definition: serializing it would be incomplete.
    can_be_saved_ = false;
  }
  return erased_count;
}

#if !defined(ORT_MINIMAL_BUILD)
Status Node::UpdateInputArgCount() {
// The node refers to a primitive operator.
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ TreeEnsembleClassifier<T>::TreeEnsembleClassifier(const OpKernelInfo& info) : Op
ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
}

// Reports the node attributes describing the trees and the class labels as removable
// once the session is initialized.
// NOTE(review): this presumes the kernel keeps its own copy of everything it needs from
// these attributes after construction - confirm against the tree ensemble initialization.
// @param removable_attributes output list, replaced by the names of the removable attributes.
// @return always Status::OK().
template <typename T>
Status TreeEnsembleClassifier<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
  // Every literal must be followed by a comma, including the last one before the
  // preprocessor block: adjacent string literals are otherwise concatenated into a
  // single bogus attribute name and neither attribute would be removed.
  InlinedVector<std::string> names{"base_values", "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates",
                                   "nodes_missing_value_tracks_true", "nodes_modes", "nodes_nodeids", "nodes_treeids",
                                   "nodes_truenodeids", "nodes_values", "class_ids", "class_treeids", "class_nodeids",
                                   "class_weights", "classlabels_strings", "classlabels_int64s",
#if !defined(ORT_MINIMAL_BUILD)
                                   "base_values_as_tensor", "nodes_hitrates_as_tensor", "nodes_values_as_tensor",
                                   "class_weights_as_tensor",
#endif
  };
  removable_attributes.swap(names);
  return Status::OK();
}


template <typename T>
common::Status TreeEnsembleClassifier<T>::Compute(OpKernelContext* context) const {
const Tensor& X = *context->Input<Tensor>(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TreeEnsembleClassifier final : public OpKernel {
public:
explicit TreeEnsembleClassifier(const OpKernelInfo& info);
common::Status Compute(OpKernelContext* context) const override;
Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;

private:
// Following pointer holds a pointer on one instance of
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/cpu/ml/treeregressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ TreeEnsembleRegressor<T>::TreeEnsembleRegressor(const OpKernelInfo& info) : OpKe
ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
}

// Reports the node attributes describing the trees and the targets as removable
// once the session is initialized.
// NOTE(review): this presumes the kernel keeps its own copy of everything it needs from
// these attributes after construction - confirm against the tree ensemble initialization.
// @param removable_attributes output list, replaced by the names of the removable attributes.
// @return always Status::OK().
template <typename T>
Status TreeEnsembleRegressor<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
  // Every literal must be followed by a comma, including the last one before the
  // preprocessor block: adjacent string literals are otherwise concatenated into a
  // single bogus attribute name and neither attribute would be removed.
  InlinedVector<std::string> names{"base_values", "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates",
                                   "nodes_missing_value_tracks_true", "nodes_modes", "nodes_nodeids", "nodes_treeids",
                                   "nodes_truenodeids", "nodes_values",
                                   "target_ids", "target_treeids", "target_nodeids", "target_weights",
#if !defined(ORT_MINIMAL_BUILD)
                                   // The regressor stores its weights in target_weights(_as_tensor);
                                   // class_weights_as_tensor belongs to the classifier.
                                   "base_values_as_tensor", "nodes_hitrates_as_tensor", "nodes_values_as_tensor",
                                   "target_weights_as_tensor",
#endif
  };
  removable_attributes.swap(names);
  return Status::OK();
}

template <typename T>
common::Status TreeEnsembleRegressor<T>::Compute(OpKernelContext* context) const {
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cpu/ml/treeregressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TreeEnsembleRegressor final : public OpKernel {
public:
explicit TreeEnsembleRegressor(const OpKernelInfo& info);
common::Status Compute(OpKernelContext* context) const override;
Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;

private:
// Following pointer holds a pointer on one instance of
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1522,12 +1522,16 @@ common::Status InferenceSession::Initialize() {
std::vector<uint8_t>().swap(ort_format_model_bytes_data_holder_);
}

// once the model is saved, we may remove unnecessary attributes for inference
session_state_->PruneRemovableAttributes();

// and log telemetry
bool model_has_fp16_inputs = ModelHasFP16Inputs(graph);
env.GetTelemetryProvider().LogSessionCreation(
session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(),
model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(), model_->MetaData(),
telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs);

LOGS(*session_logger_, INFO) << "Session successfully initialized.";
}
ORT_CATCH(const NotImplementedException& ex) {
Expand Down

0 comments on commit 5930e7e

Please sign in to comment.