Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream #19515

Merged: 10 commits, Feb 22, 2024
57 changes: 28 additions & 29 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -182,7 +182,6 @@
// upstream_node_0 and upstream_node_1 are the immediate upstream nodes of downstream_node
// upstream_node_2 is the immediate node ahead of downstream_node in the same logic stream
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
-  InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map_;
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;

// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@@ -200,6 +199,9 @@
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace
#endif
+  // reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan()
+  // reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan()
+  OrtValueIndex reused_buffer_index_per_stream;
};

// ort_value_info_ is indexed by an OrtValueIndex
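The new field gives every OrtValue two reuse records: the existing global index, which survives planning and feeds GenerateDeallocationPlan(), and a per-stream index that is reset to the identity mapping before each stream is planned. A minimal sketch of that two-book pattern, using hypothetical names (ValueInfo, reset_stream_aliases) rather than the actual onnxruntime types:

```cpp
#include <cstdio>
#include <vector>

using ValueIndex = int;

// Hypothetical stand-in for onnxruntime's OrtValueInfo bookkeeping.
struct ValueInfo {
  ValueIndex reused_buffer_index = -1;        // global alias, kept for deallocation planning
  ValueIndex reused_buffer_index_per_stream;  // scratch alias, valid only while one stream is planned
};

// Reset the per-stream aliases to the identity mapping (every value owns itself).
void reset_stream_aliases(std::vector<ValueInfo>& infos) {
  for (size_t j = 0; j < infos.size(); ++j)
    infos[j].reused_buffer_index_per_stream = static_cast<ValueIndex>(j);
}

int main() {
  std::vector<ValueInfo> infos(3);

  // Stream 0: value 2 reuses value 0's buffer; record it in both books.
  reset_stream_aliases(infos);
  infos[2].reused_buffer_index = 0;
  infos[2].reused_buffer_index_per_stream = 0;

  // Stream 1: the per-stream book starts clean again.
  reset_stream_aliases(infos);
  std::printf("global alias of value 2: %d\n", infos[2].reused_buffer_index);                  // still 0
  std::printf("per-stream alias of value 2: %d\n", infos[2].reused_buffer_index_per_stream);   // back to 2
}
```

Because the reset happens per stream, a reuse recorded while planning one stream is invisible to the next stream's reuse pass, while the global record still tells the deallocation planner which buffer each value ultimately lives in.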
@@ -277,6 +279,7 @@
OrtValueIndex original = Buffer(reused);
// record that the new buffer will reuse that original buffer
Buffer(reused_for) = original;
+  ort_value_info_[reused_for].reused_buffer_index_per_stream = original;
// adjust original buffer's usecount
UseCount(original) += UseCount(reused_for);
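The added line keeps the per-stream book in step with the global one at the moment Reuse() records an alias, while the surrounding code folds the reusing value's use count into the original buffer. A sketch of that bookkeeping under the same hypothetical structures as above (not the real onnxruntime API):

```cpp
#include <cassert>
#include <vector>

using ValueIndex = int;

struct ValueInfo {
  ValueIndex reused_buffer_index;
  ValueIndex reused_buffer_index_per_stream;
  int usecount = 0;
};

// Record that `reused_for` will live in `reused`'s buffer, mirroring Reuse():
// resolve to the root buffer, write both alias books, and move the use count.
void reuse(std::vector<ValueInfo>& infos, ValueIndex reused, ValueIndex reused_for) {
  ValueIndex original = infos[reused].reused_buffer_index;  // root of the alias chain
  infos[reused_for].reused_buffer_index = original;
  infos[reused_for].reused_buffer_index_per_stream = original;
  infos[original].usecount += infos[reused_for].usecount;
}

int main() {
  std::vector<ValueInfo> infos(3);
  for (int i = 0; i < 3; ++i) {
    infos[i].reused_buffer_index = i;
    infos[i].reused_buffer_index_per_stream = i;
  }
  infos[0].usecount = 1;
  infos[2].usecount = 2;

  reuse(infos, /*reused=*/0, /*reused_for=*/2);
  assert(infos[2].reused_buffer_index == 0);
  assert(infos[0].usecount == 3);  // original buffer now carries both values' uses
}
```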

@@ -295,7 +298,7 @@
}
#endif

-  // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
+  // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
bool* is_strided_tensor) {
*is_strided_tensor = false;
@@ -357,7 +360,7 @@
auto p_input_arg = input_args[pair.first];
if (p_input_arg->Exists()) {
auto input_arg_index = Index(p_input_arg->Name());
-  auto original = Buffer(input_arg_index);
+  auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream;
if (1 == UseCount(original)) {
if (SameSize(*p_input_arg, *p_output_arg)) {
// we can reuse this input since it is its last use and permitted for in-place update
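Resolving through reused_buffer_index_per_stream instead of the global Buffer() chain confines the "is this the input's last use?" test to aliases created while planning the current stream, so a reuse recorded by another stream can no longer make an input look reusable here. A rough sketch of the guarded lookup; the helper name and structures are hypothetical, and the real FindReusableInput() additionally checks shapes, types, and operator allowlists:

```cpp
#include <optional>
#include <vector>

using ValueIndex = int;

struct ValueInfo {
  ValueIndex reused_buffer_index_per_stream;
  int usecount;
};

// Resolve the input through the per-stream alias book and only offer it
// for in-place reuse if this is the underlying buffer's last use.
std::optional<ValueIndex> reusable_input(const std::vector<ValueInfo>& infos,
                                         ValueIndex input_arg_index) {
  ValueIndex original = infos[input_arg_index].reused_buffer_index_per_stream;
  if (original != -1 && infos[original].usecount == 1)
    return original;  // last use within this stream: safe to write over it
  return std::nullopt;
}

int main() {
  // Value 1 aliases buffer 0 in this stream, but buffer 0 still has 2 uses.
  std::vector<ValueInfo> infos = {{0, 2}, {0, 1}};
  return reusable_input(infos, 1).has_value() ? 1 : 0;  // exits 0: not reusable yet
}
```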
@@ -1065,7 +1068,8 @@

// build the consumer list for each value
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
-  value_consumer_map_.reserve(num_ml_values);
+  InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map;
+  value_consumer_map.reserve(num_ml_values);

// iterate each stream from back, so the first element is the last consumer in single stream case
for (auto& stream : stream_nodes_) {
@@ -1081,7 +1085,7 @@
auto origin = Buffer(value_idx);
if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
-  value_consumer_map_[origin].insert(node_index);
+  value_consumer_map[origin].insert(node_index);
}
}
return Status::OK();
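Demoting value_consumer_map_ from a class member to a local keeps the consumer sets scoped to the single planning pass that builds and mutates them. A self-contained sketch of how such a value-to-consumers map is assembled from per-stream node lists (the Node layout here is hypothetical):

```cpp
#include <cstdio>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using ValueIndex = int;
using NodeIndex = int;

// Hypothetical node: the values it reads.
struct Node { std::vector<ValueIndex> inputs; };

int main() {
  // Two streams of nodes, indexed into `nodes`.
  std::vector<Node> nodes = {{{0}}, {{0, 1}}, {{1}}};
  std::vector<std::vector<NodeIndex>> stream_nodes = {{0, 1}, {2}};

  // Build value -> consumer-node-set, as the planner does before computing reuse.
  std::unordered_map<ValueIndex, std::unordered_set<NodeIndex>> value_consumer_map;
  for (const auto& stream : stream_nodes)
    for (NodeIndex n : stream)
      for (ValueIndex v : nodes[n].inputs)
        value_consumer_map[v].insert(n);

  std::printf("value 1 has %zu consumers\n", value_consumer_map[1].size());  // prints 2
}
```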
@@ -1138,8 +1142,8 @@
std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
-  value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
-                                             value_consumer_map_[output_idx_global].end());
+  value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+                                            value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
found_reusable = true;
break;
@@ -1168,8 +1172,8 @@
allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
-  value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
-                                             value_consumer_map_[output_idx_global].end());
+  value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+                                            value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
@@ -1187,11 +1191,11 @@
OrtValueIndex input_arg_index{};
if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
-  if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
+  if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = input_arg_index;
-  value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
-                                              value_consumer_map_[output_idx_global].end());
+  value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
+                                             value_consumer_map[output_idx_global].end());
reused.insert(input_arg_index);
}
}
@@ -1266,7 +1270,7 @@
}

bool all_covered = true;
-  for (auto consumer : value_consumer_map_[output_idx_global]) {
+  for (auto consumer : value_consumer_map[output_idx_global]) {
if (deps->find(consumer) == deps->end()) {
all_covered = false;
break;
@@ -1277,9 +1281,9 @@
allocation_plan[downstream_value].reused_buffer = output_idx_global;
get_reused = true;
// add new consumer for the value to be reused
-  value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
-  value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
-                                                value_consumer_map_[downstream_value].end());
+  value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
+  value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
+                                               value_consumer_map[downstream_value].end());
node_iter = size_iter->second.erase(node_iter);
if (size_iter->second.empty()) {
local_iter->second.erase(size_iter);
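The all_covered loop above is the safety condition for handing a buffer to a downstream value: every node still recorded as a consumer of the buffer must already be a dependency of the reusing node, otherwise some consumer could read the buffer after it is overwritten. A distilled sketch of that check (hypothetical names):

```cpp
#include <cassert>
#include <unordered_set>

using NodeIndex = int;

// A buffer may be handed to a later value only if every node that still
// consumes it is already an upstream dependency of the reusing node;
// otherwise some consumer could read the buffer after it is overwritten.
bool all_consumers_covered(const std::unordered_set<NodeIndex>& consumers,
                           const std::unordered_set<NodeIndex>& deps_of_reuser) {
  for (NodeIndex c : consumers)
    if (deps_of_reuser.find(c) == deps_of_reuser.end()) return false;
  return true;
}

int main() {
  std::unordered_set<NodeIndex> consumers = {1, 2};
  assert(all_consumers_covered(consumers, {0, 1, 2, 3}));  // safe to reuse
  assert(!all_consumers_covered(consumers, {0, 1}));       // node 2 may still read it
}
```

On success the planner also merges the downstream value's node and consumers into the buffer's consumer set, so later reuse decisions see the extended lifetime.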
@@ -1336,21 +1340,16 @@
// use parallel execution context to generate a baseline first (no memory sharing)
context_ = gsl::not_null<const ISequentialPlannerContext*>(&no_mem_reuse_context);
}
-#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
// copy the use counts to a vector, before computing reuse
std::vector<int> ort_value_usecount;
ort_value_usecount.reserve(ort_value_info_.size());
-#endif
+  ORT_RETURN_IF_ERROR(ComputeReuseCount());
+  for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount);
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
-    // compute use count first
-    ORT_RETURN_IF_ERROR(ComputeReuseCount());
-#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
-    if (i == 0) {
-      for (auto ort_value_info : ort_value_info_) {
-        ort_value_usecount.push_back(ort_value_info.usecount);
-      }
+    for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast<OrtValueIndex>(j);
+    if (i > 0) {
+      for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast<OrtValueIndex>(k)) = ort_value_usecount[k];
}
-#endif
ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i));
ClearUseCount();
freelist_.clear(); // DONOT share freelist across streams
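After this change the planning loop computes use counts once, snapshots them, and then, per stream, resets the per-stream alias book to identity and restores the snapshot before re-running the single-stream reuse pass (stream 0 still holds the freshly computed counts, hence the i > 0 guard). A skeleton of that loop with hypothetical stand-ins for the planner's state:

```cpp
#include <vector>

using ValueIndex = int;

struct ValueInfo {
  int usecount = 0;
  ValueIndex reused_buffer_index_per_stream = -1;
};

// Skeleton of the per-stream planning loop: use counts are computed once,
// snapshotted, and restored before every stream after the first; the
// per-stream alias book is reset to identity for every stream.
void plan_all_streams(std::vector<ValueInfo>& infos, size_t num_streams) {
  std::vector<int> saved_usecounts;
  saved_usecounts.reserve(infos.size());
  for (const auto& info : infos) saved_usecounts.push_back(info.usecount);

  for (size_t i = 0; i < num_streams; ++i) {
    for (size_t j = 0; j < infos.size(); ++j)
      infos[j].reused_buffer_index_per_stream = static_cast<ValueIndex>(j);
    if (i > 0)  // stream 0 still has the freshly computed counts
      for (size_t k = 0; k < infos.size(); ++k) infos[k].usecount = saved_usecounts[k];
    // ... ComputeSingleStreamReusePlan(i) would run here, decrementing use counts ...
    for (auto& info : infos) info.usecount = 0;  // mirrors ClearUseCount(); the freelist is likewise per stream
  }
}

int main() {
  std::vector<ValueInfo> infos(4);
  infos[0].usecount = 2;
  plan_all_streams(infos, 2);
}
```

Restoring the snapshot each iteration is what lets every stream start from the same use-count baseline instead of inheriting the decrements made while planning the previous stream.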
@@ -1468,7 +1467,7 @@
for (auto node_input : pnode->InputDefs()) {
if (node_input->Exists()) {
auto& sym = node_input->Name();
-  auto original = Buffer(Index(sym));
+  auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original))) {
@@ -1480,7 +1479,7 @@
for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists()) {
auto& sym = node_input->Name();
-  auto original = Buffer(Index(sym));
+  auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original))) {
@@ -1493,7 +1492,7 @@
for (auto node_output : pnode->OutputDefs()) {
if (node_output->Exists()) {
auto& sym = node_output->Name();
-  auto original = Buffer(Index(sym));
+  auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if (0 == DecrementUseCount(original)) {
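GenerateDeallocationPlan() now resolves each input, implicit input, and output through reused_buffer_index_per_stream before decrementing the shared buffer's use count, so a buffer is scheduled for release exactly when its last resolved use is consumed. A sketch of that countdown (hypothetical structures; the real code walks the plan's execution steps):

```cpp
#include <cstdio>
#include <vector>

using ValueIndex = int;

struct ValueInfo {
  ValueIndex reused_buffer_index_per_stream;
  int usecount;
};

// Decrement the resolved buffer's use count; reaching zero means this was
// the final use, so the deallocation plan frees the buffer at this step.
bool release_use(std::vector<ValueInfo>& infos, ValueIndex value) {
  ValueIndex original = infos[value].reused_buffer_index_per_stream;
  if (original == -1) return false;  // e.g. an initializer removed by a temporary workaround
  return --infos[original].usecount == 0;
}

int main() {
  // Value 1 shares value 0's buffer; buffer 0 has two outstanding uses.
  std::vector<ValueInfo> infos = {{0, 2}, {0, 0}};
  std::printf("free after first use? %d\n", release_use(infos, 1));   // 0: one use left
  std::printf("free after second use? %d\n", release_use(infos, 0));  // 1: deallocate here
}
```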
25 changes: 25 additions & 0 deletions onnxruntime/test/framework/allocation_planner_test.cc
@@ -1974,6 +1974,31 @@
ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
}
}

+TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
+  SessionOptions sess_opt;
+  sess_opt.graph_optimization_level = TransformerLevel::Default;
+
+  InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx"));
+  auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
+  status = sess.Load();
+  status = sess.Initialize();
+  ASSERT_TRUE(status.IsOK()) << "No crash";
+  const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan();
+  ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape";
+
+  int gather_count = 0;
+  for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) {
+    if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) {
+      const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());
+      if (node->OpType() == "Gather")
+        gather_count++;
+      else
+        FAIL() << "CPU stream should contain only gather ops";
+    }
+  }
+  ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream";
+}
#endif
} // namespace test
} // namespace onnxruntime
Binary file not shown.