From 0c9d031a8955d83ad20797890fabf6c9ca9666fa Mon Sep 17 00:00:00 2001 From: jslhcl Date: Tue, 13 Feb 2024 17:10:00 -0800 Subject: [PATCH 1/9] Compute reuse count only once and do not clear during every stream --- onnxruntime/core/framework/allocation_planner.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ea7a6432a7507..6a4e755557f97 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1341,9 +1341,9 @@ class PlannerImpl { std::vector ort_value_usecount; ort_value_usecount.reserve(ort_value_info_.size()); #endif + ORT_RETURN_IF_ERROR(ComputeReuseCount()); for (size_t i = 0; i < stream_nodes_.size(); ++i) { // compute use count first - ORT_RETURN_IF_ERROR(ComputeReuseCount()); #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1352,7 +1352,7 @@ class PlannerImpl { } #endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); - ClearUseCount(); +// ClearUseCount(); freelist_.clear(); // DONOT share freelist across streams } #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) From 20734bac43a1071c29469c7bbe373ac1709ab613 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Wed, 14 Feb 2024 06:46:17 -0800 Subject: [PATCH 2/9] DONOT decrease use count for the reused buffer if it is already 0 --- onnxruntime/core/framework/allocation_planner.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 6a4e755557f97..6512c6916cbf7 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1341,9 +1341,9 @@ class PlannerImpl { std::vector ort_value_usecount; ort_value_usecount.reserve(ort_value_info_.size()); #endif - ORT_RETURN_IF_ERROR(ComputeReuseCount()); for (size_t i = 0; i < stream_nodes_.size(); ++i) { // compute use count first + ORT_RETURN_IF_ERROR(ComputeReuseCount()); #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1352,7 +1352,7 @@ class PlannerImpl { } #endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); -// ClearUseCount(); + ClearUseCount(); freelist_.clear(); // DONOT share freelist across streams } #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) @@ -1471,6 +1471,7 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. + if (original != -1 && UseCount(original) == 0) continue; if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); } @@ -1483,6 +1484,7 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. + if (original != -1 && UseCount(original) == 0) continue; if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); } @@ -1496,6 +1498,7 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. + if (UseCount(original) == 0) continue; if (0 == DecrementUseCount(original)) { freelist_.push_front(FreeBufferInfo(original, program_counter)); } From 52c28b80bde60a150db052e905fd2cc0750475be Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 14 Feb 2024 13:53:48 -0800 Subject: [PATCH 3/9] reset reused_buffer_index every time computing next stream's reuse plan --- onnxruntime/core/framework/allocation_planner.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 6512c6916cbf7..ded06e95c86f7 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1353,6 +1353,7 @@ class PlannerImpl { #endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); ClearUseCount(); + for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index = static_cast(j); freelist_.clear(); // DONOT share freelist across streams } #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) @@ -1471,7 +1472,6 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. - if (original != -1 && UseCount(original) == 0) continue; if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); } @@ -1484,7 +1484,6 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. - if (original != -1 && UseCount(original) == 0) continue; if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); } @@ -1498,7 +1497,6 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. - if (UseCount(original) == 0) continue; if (0 == DecrementUseCount(original)) { freelist_.push_front(FreeBufferInfo(original, program_counter)); } From 02b7043b6002c56cbb003ab6cd54efd5660d4e04 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Mon, 19 Feb 2024 16:27:40 -0800 Subject: [PATCH 4/9] introduce reused_buffer_index_per_stream --- onnxruntime/core/framework/allocation_planner.cc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 6512c6916cbf7..a7a4d674512f2 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -200,6 +200,7 @@ class PlannerImpl { #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace #endif + OrtValueIndex reused_buffer_index_per_stream; }; // ort_value_info_ is indexed by an OrtValueIndex @@ -277,6 +278,7 @@ class PlannerImpl { OrtValueIndex original = Buffer(reused); // record that the new buffer will reuse that original buffer Buffer(reused_for) = original; + ort_value_info_[reused_for].reused_buffer_index_per_stream = original; // adjust original buffer's usecount UseCount(original) += UseCount(reused_for); @@ -357,7 +359,8 @@ class PlannerImpl { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { auto input_arg_index = Index(p_input_arg->Name()); - auto original = Buffer(input_arg_index); + //auto original = Buffer(input_arg_index); + auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream; if (1 == UseCount(original)) { if (SameSize(*p_input_arg, *p_output_arg)) { // we can reuse this input since it is its last use and permitted for in-place update @@ -1344,6 +1347,7 @@ class PlannerImpl { for (size_t i = 0; i < stream_nodes_.size(); ++i) { // compute use count first ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast(j); #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1468,7 +1472,8 @@ class PlannerImpl { for (auto node_input : pnode->InputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - auto original = Buffer(Index(sym)); + //auto original = Buffer(Index(sym)); + auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if (original != -1 && UseCount(original) == 0) continue; @@ -1481,7 +1486,8 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - auto original = Buffer(Index(sym)); + //auto original = Buffer(Index(sym)); + auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if (original != -1 && UseCount(original) == 0) continue; @@ -1495,7 +1501,8 @@ class PlannerImpl { for (auto node_output : pnode->OutputDefs()) { if (node_output->Exists()) { auto& sym = node_output->Name(); - auto original = Buffer(Index(sym)); + //auto original = Buffer(Index(sym)); + auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if (UseCount(original) == 0) continue; From e99a8cd5cf9c6e8ad0d1d1d88f3d288db7073bc3 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Mon, 19 Feb 2024 16:32:02 -0800 Subject: [PATCH 5/9] undo previous commit for the changes on reused_buffer_index --- onnxruntime/core/framework/allocation_planner.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 91e0467d1c1bd..9594ca2dc6199 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1357,7 +1357,6 @@ class PlannerImpl { #endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); ClearUseCount(); - for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index = static_cast(j); freelist_.clear(); // DONOT share freelist across streams } #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) From 8262d0b43694e4b06be0126df2da8713e2eae8de Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 20 Feb 2024 10:38:40 -0800 Subject: [PATCH 6/9] add test case --- .../core/framework/allocation_planner.cc | 50 ++++++++---------- .../test/framework/allocation_planner_test.cc | 25 +++++++++ .../multi_stream_models/issue_19480.onnx | Bin 0 -> 760 bytes 3 files changed, 46 insertions(+), 29 deletions(-) create mode 100644 onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 9594ca2dc6199..4ca959528bab9 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -182,7 +182,6 @@ class PlannerImpl { // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream InlinedHashMap> dependence_graph_; - InlinedHashMap> value_consumer_map_; InlinedHashMap value_node_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -200,6 +199,8 @@ class PlannerImpl { #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace #endif + // reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan() + // reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan() OrtValueIndex reused_buffer_index_per_stream; }; @@ -297,7 +298,7 @@ class PlannerImpl { } #endif - // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. + // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { *is_strided_tensor = false; @@ -359,7 +360,6 @@ class PlannerImpl { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { auto input_arg_index = Index(p_input_arg->Name()); - //auto original = Buffer(input_arg_index); auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream; if (1 == UseCount(original)) { if (SameSize(*p_input_arg, *p_output_arg)) { @@ -1068,7 +1068,8 @@ class PlannerImpl { // build the consumer list for each value int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumer_map_.reserve(num_ml_values); + InlinedHashMap> value_consumer_map; + value_consumer_map.reserve(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1084,7 +1085,7 @@ class PlannerImpl { auto origin = Buffer(value_idx); if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer - value_consumer_map_[origin].insert(node_index); + value_consumer_map[origin].insert(node_index); } } return Status::OK(); @@ -1141,8 +1142,8 @@ class PlannerImpl { std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl; allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1171,8 +1172,8 @@ class PlannerImpl { allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1190,11 +1191,11 @@ class PlannerImpl { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1269,7 +1270,7 @@ class PlannerImpl { } bool all_covered = true; - for (auto consumer : value_consumer_map_[output_idx_global]) { + for (auto consumer : value_consumer_map[output_idx_global]) { if (deps->find(consumer) == deps->end()) { all_covered = false; break; @@ -1280,9 +1281,9 @@ class PlannerImpl { allocation_plan[downstream_value].reused_buffer = output_idx_global; get_reused = true; // add new consumer for the value to be reused - value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]); - value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(), - value_consumer_map_[downstream_value].end()); + value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); + value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1339,22 +1340,16 @@ class PlannerImpl { // use parallel execution context to generate a baseline first (no memory sharing) context_ = gsl::not_null(&no_mem_reuse_context); } -#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) // copy the use counts to a vector, before computing reuse std::vector ort_value_usecount; ort_value_usecount.reserve(ort_value_info_.size()); -#endif + ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount); for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first - ORT_RETURN_IF_ERROR(ComputeReuseCount()); for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast(j); -#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) - if (i == 0) { - for (auto ort_value_info : ort_value_info_) { - ort_value_usecount.push_back(ort_value_info.usecount); - } + if (i > 0) { + for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast(k)) = ort_value_usecount[k]; } -#endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); ClearUseCount(); freelist_.clear(); // DONOT share freelist across streams @@ -1472,7 +1467,6 @@ class PlannerImpl { for (auto node_input : pnode->InputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - //auto original = Buffer(Index(sym)); auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. @@ -1485,7 +1479,6 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - //auto original = Buffer(Index(sym)); auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. @@ -1499,7 +1492,6 @@ class PlannerImpl { for (auto node_output : pnode->OutputDefs()) { if (node_output->Exists()) { auto& sym = node_output->Name(); - //auto original = Buffer(Index(sym)); auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index d7b1de5c930c5..b40aac8efeae1 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1974,6 +1974,31 @@ TEST_F(PlannerTest, TestCpuIf) { ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep); } } + +TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { + SessionOptions sess_opt; + sess_opt.graph_optimization_level = TransformerLevel::Default; + + InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx")); + auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); + status = sess.Load(); + status = sess.Initialize(); + ASSERT_TRUE(status.IsOK()) << "No crash"; + const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan(); + ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; + + int gather_count = 0; + for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { + if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { + const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); + if (node->OpType() == "Gather") + gather_count++; + else + FAIL() << "CPU stream should contain only gather ops"; + } + } + ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream"; +} #endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc7d39206dd49f4ef6daf65b7d58c5b456ecf331 GIT binary patch literal 760 zcmaixKTm@|7>Bw3f%9#m_0_6_F_p#1gab^#v5RqW(2b?J(o1>?g{IKO$uH`6@t{2^ z#m0q%=Y9D7j(aJ^(>&%f;j=_M&c!l&{_evy4DtnEiK$Fin*vE__dm*aU~nQ+XN$p9 zA11aA3GP#64zs z+VGAUzBYVq;6Ud2Mod}g2Tt_Ry!IQoq685v?9X@+FQ7}m2pC{Q_TCzB1Q$v>tF;at zE9X-02LY%OdZ2hTthTjJs;u4R{+GpCSxtg_7ivO}nrK8dbFt05KbWuCO#ReubJg)l W4Oj)N8n}nRI|Tj~OnP7p&wl_GGP8{U literal 0 HcmV?d00001 From 3e50a00ce9b1915384bcc4c37242e5e76839b7bf Mon Sep 17 00:00:00 2001 From: jslhcl Date: Tue, 20 Feb 2024 18:19:24 -0800 Subject: [PATCH 7/9] fix comments --- .../core/framework/allocation_planner.cc | 47 ++++++++++--------- .../test/framework/allocation_planner_test.cc | 30 ++++++++++++ 2 files changed, 55 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 4ca959528bab9..85272100f0b58 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -199,9 +199,6 @@ class PlannerImpl { #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace #endif - // reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan() - // reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan() - OrtValueIndex reused_buffer_index_per_stream; }; // ort_value_info_ is indexed by an OrtValueIndex @@ -279,7 +276,6 @@ class PlannerImpl { OrtValueIndex original = Buffer(reused); // record that the new buffer will reuse that original buffer Buffer(reused_for) = original; - ort_value_info_[reused_for].reused_buffer_index_per_stream = original; // adjust original buffer's usecount UseCount(original) += UseCount(reused_for); @@ -360,7 +356,7 @@ class PlannerImpl { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { auto input_arg_index = Index(p_input_arg->Name()); - auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream; + auto original = Buffer(input_arg_index); if (1 == UseCount(original)) { if (SameSize(*p_input_arg, *p_output_arg)) { // we can reuse this input since it is its last use and permitted for in-place update @@ -533,6 +529,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); + for (size_t i = 0; i < num_ml_values; i++) plan_.allocation_plan[i].reused_buffer = static_cast(i); } bool HasExternalOutputs(const Node& node) const { @@ -1082,8 +1079,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumer_map[origin].insert(node_index); } @@ -1143,7 +1140,7 @@ class PlannerImpl { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1173,7 +1170,7 @@ class PlannerImpl { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1195,7 +1192,7 @@ class PlannerImpl { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1283,7 +1280,7 @@ class PlannerImpl { // add new consumer for the value to be reused value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), - value_consumer_map[downstream_value].end()); + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1340,16 +1337,22 @@ class PlannerImpl { // use parallel execution context to generate a baseline first (no memory sharing) context_ = gsl::not_null(&no_mem_reuse_context); } +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) // copy the use counts to a vector, before computing reuse std::vector ort_value_usecount; ort_value_usecount.reserve(ort_value_info_.size()); - ORT_RETURN_IF_ERROR(ComputeReuseCount()); - for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount); +#endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast(j); - if (i > 0) { - for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast(k)) = ort_value_usecount[k]; + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough + ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (size_t j = 0; j < ort_value_info_.size(); j++) Buffer(static_cast(j)) = static_cast(j); +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) + if (i == 0) { + for (auto ort_value_info : ort_value_info_) { + ort_value_usecount.push_back(ort_value_info.usecount); + } } +#endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); ClearUseCount(); freelist_.clear(); // DONOT share freelist across streams @@ -1467,7 +1470,7 @@ class PlannerImpl { for (auto node_input : pnode->InputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; + auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if ((original != -1) && (0 == DecrementUseCount(original))) { @@ -1479,7 +1482,7 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; + auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if ((original != -1) && (0 == DecrementUseCount(original))) { @@ -1492,7 +1495,7 @@ class PlannerImpl { for (auto node_output : pnode->OutputDefs()) { if (node_output->Exists()) { auto& sym = node_output->Name(); - auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; + auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if (0 == DecrementUseCount(original)) { @@ -1692,8 +1695,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1888,7 +1891,7 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. // in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index b40aac8efeae1..dc6467bf736d3 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1975,6 +1975,36 @@ TEST_F(PlannerTest, TestCpuIf) { } } +// model looks like: +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// Shape ----------------> Reshape --> Shape ------------------> Reshape +// ^ ^ +// InstanceNormalization ----| InstanceNormalization ------| +// +// Python script to create this model: +// def CreateModelFor19480(): +// #shape->reshape->shape->reshape, 4 gather +// graphNodes = [] +// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8'])) +// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output'])) +// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293'])) +// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4'])) +// g = h.make_graph(graphNodes, 'issue_19480', [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale0', tp.FLOAT, [32]), h.make_tensor_value_info('B0', tp.FLOAT, [32]), +// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale1', tp.FLOAT, [32]), h.make_tensor_value_info('B1', tp.FLOAT, [32]), +// h.make_tensor_value_info('indices1', tp.INT32, []), h.make_tensor_value_info('indices2', tp.INT32, []), h.make_tensor_value_info('indices3', tp.INT32, []), h.make_tensor_value_info('indices4', tp.INT32, [])], +// [h.make_tensor_value_info('output0', tp.FLOAT, None), h.make_tensor_value_info('output1', tp.INT64, None), h.make_tensor_value_info('output2', tp.INT64, None), h.make_tensor_value_info('output3', tp.INT64, None), h.make_tensor_value_info('output4', tp.INT64, None)]) +// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name') +// onnx.save(model, 'issue_19480.onnx') +// TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { SessionOptions sess_opt; sess_opt.graph_optimization_level = TransformerLevel::Default; From 45747fd4a2e7215977a1d4ae3dececb824230d2d Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 21 Feb 2024 09:14:44 -0800 Subject: [PATCH 8/9] fix lint --- .../core/framework/allocation_planner.cc | 4 +-- .../test/framework/allocation_planner_test.cc | 25 ++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 85272100f0b58..3bd965e66db8d 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -529,7 +529,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); - for (size_t i = 0; i < num_ml_values; i++) plan_.allocation_plan[i].reused_buffer = static_cast(i); + for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i; } bool HasExternalOutputs(const Node& node) const { @@ -1345,7 +1345,7 @@ class PlannerImpl { for (size_t i = 0; i < stream_nodes_.size(); ++i) { // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough ORT_RETURN_IF_ERROR(ComputeReuseCount()); - for (size_t j = 0; j < ort_value_info_.size(); j++) Buffer(static_cast(j)) = static_cast(j); + for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index dc6467bf736d3..3e0d94e94e48c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1982,8 +1982,8 @@ TEST_F(PlannerTest, TestCpuIf) { // |-----------> Gather // Shape ----------------> Reshape --> Shape ------------------> Reshape // ^ ^ -// InstanceNormalization ----| InstanceNormalization ------| -// +// InstanceNormalization ----| InstanceNormalization ------| +// // Python script to create this model: // def CreateModelFor19480(): // #shape->reshape->shape->reshape, 4 gather @@ -1998,10 +1998,23 @@ TEST_F(PlannerTest, TestCpuIf) { // graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2'])) // graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3'])) // graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4'])) -// g = h.make_graph(graphNodes, 'issue_19480', [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale0', tp.FLOAT, [32]), h.make_tensor_value_info('B0', tp.FLOAT, [32]), -// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale1', tp.FLOAT, [32]), h.make_tensor_value_info('B1', tp.FLOAT, [32]), -// h.make_tensor_value_info('indices1', tp.INT32, []), h.make_tensor_value_info('indices2', tp.INT32, []), h.make_tensor_value_info('indices3', tp.INT32, []), h.make_tensor_value_info('indices4', tp.INT32, [])], -// [h.make_tensor_value_info('output0', tp.FLOAT, None), h.make_tensor_value_info('output1', tp.INT64, None), h.make_tensor_value_info('output2', tp.INT64, None), h.make_tensor_value_info('output3', tp.INT64, None), h.make_tensor_value_info('output4', tp.INT64, None)]) +// g = h.make_graph(graphNodes, 'issue_19480', +// [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), +// h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale0', tp.FLOAT, [32]), +// h.make_tensor_value_info('B0', tp.FLOAT, [32]), +// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale1', tp.FLOAT, [32]), +// h.make_tensor_value_info('B1', tp.FLOAT, [32]), +// h.make_tensor_value_info('indices1', tp.INT32, []), +// h.make_tensor_value_info('indices2', tp.INT32, []), +// h.make_tensor_value_info('indices3', tp.INT32, []), +// h.make_tensor_value_info('indices4', tp.INT32, [])], +// [h.make_tensor_value_info('output0', tp.FLOAT, None), +// h.make_tensor_value_info('output1', tp.INT64, None), +// h.make_tensor_value_info('output2', tp.INT64, None), +// h.make_tensor_value_info('output3', tp.INT64, None), +// h.make_tensor_value_info('output4', tp.INT64, None)]) // model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name') // onnx.save(model, 'issue_19480.onnx') // From 16507c3302535208f6492e846b30af43fa2ba884 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 21 Feb 2024 17:38:15 -0800 Subject: [PATCH 9/9] reset lint --- onnxruntime/core/framework/allocation_planner.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 3bd965e66db8d..158ab8ed610f4 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1343,7 +1343,7 @@ class PlannerImpl { ort_value_usecount.reserve(ort_value_info_.size()); #endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough! ORT_RETURN_IF_ERROR(ComputeReuseCount()); for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)