diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 9594ca2dc6199..4ca959528bab9 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -182,7 +182,6 @@ class PlannerImpl {
   // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node
   // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream
   InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
-  InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map_;
   InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;

   // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@@ -200,6 +199,8 @@ class PlannerImpl {
 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
     OrtValueIndex inplace_reused_buffer_index = -1;  // index of original buffer to reuse inplace
 #endif
+    // reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan()
+    // reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan()
     OrtValueIndex reused_buffer_index_per_stream;
   };

@@ -297,7 +298,7 @@ class PlannerImpl {
   }
 #endif

-  // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
+  // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
   bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
                          bool* is_strided_tensor) {
     *is_strided_tensor = false;
@@ -359,7 +360,6 @@ class PlannerImpl {
         auto p_input_arg = input_args[pair.first];
         if (p_input_arg->Exists()) {
           auto input_arg_index = Index(p_input_arg->Name());
-          //auto original = Buffer(input_arg_index);
           auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream;
           if (1 == UseCount(original)) {
             if (SameSize(*p_input_arg, *p_output_arg)) {
@@ -1068,7 +1068,8 @@ class PlannerImpl {

     // build the consumer list for each value
     int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
-    value_consumer_map_.reserve(num_ml_values);
+    InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map;
+    value_consumer_map.reserve(num_ml_values);

     // iterate each stream from back, so the first element is the last consumer in single stream case
     for (auto& stream : stream_nodes_) {
@@ -1084,7 +1085,7 @@ class PlannerImpl {
           auto origin = Buffer(value_idx);
           if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
             // add current node as consumer for origin buffer
-            value_consumer_map_[origin].insert(node_index);
+            value_consumer_map[origin].insert(node_index);
           }
         }
         return Status::OK();
@@ -1141,8 +1142,8 @@ class PlannerImpl {
               std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
               allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
               allocation_plan[output_idx_global].reused_buffer = reusable_input;
-              value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
-                                                         value_consumer_map_[output_idx_global].end());
+              value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+                                                        value_consumer_map[output_idx_global].end());
               reused.insert(reusable_input);
               found_reusable = true;
               break;
@@ -1171,8 +1172,8 @@ class PlannerImpl {
                 allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
               allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
               allocation_plan[output_idx_global].reused_buffer = reusable_input;
-              value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
-                                                         value_consumer_map_[output_idx_global].end());
+              value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+                                                        value_consumer_map[output_idx_global].end());
               reused.insert(reusable_input);
               continue;
             }  // if
@@ -1190,11 +1191,11 @@ class PlannerImpl {
             OrtValueIndex input_arg_index{};
             if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
                 allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
-              if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
+              if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
                 allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
                 allocation_plan[output_idx_global].reused_buffer = input_arg_index;
-                value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
-                                                            value_consumer_map_[output_idx_global].end());
+                value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
+                                                           value_consumer_map[output_idx_global].end());
                 reused.insert(input_arg_index);
               }
             }
@@ -1269,7 +1270,7 @@ class PlannerImpl {
             }

             bool all_covered = true;
-            for (auto consumer : value_consumer_map_[output_idx_global]) {
+            for (auto consumer : value_consumer_map[output_idx_global]) {
               if (deps->find(consumer) == deps->end()) {
                 all_covered = false;
                 break;
@@ -1280,9 +1281,9 @@ class PlannerImpl {
               allocation_plan[downstream_value].reused_buffer = output_idx_global;
               get_reused = true;
               // add new consumer for the value to be reused
-              value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
-              value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
-                                                            value_consumer_map_[downstream_value].end());
+              value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
+              value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
+                                                           value_consumer_map[downstream_value].end());
               node_iter = size_iter->second.erase(node_iter);
               if (size_iter->second.empty()) {
                 local_iter->second.erase(size_iter);
@@ -1339,22 +1340,16 @@ class PlannerImpl {
       // use parallel execution context to generate a baseline first (no memory sharing)
       context_ = gsl::not_null(&no_mem_reuse_context);
     }
-#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
     // copy the use counts to a vector, before computing reuse
     std::vector<int> ort_value_usecount;
     ort_value_usecount.reserve(ort_value_info_.size());
-#endif
+    ORT_RETURN_IF_ERROR(ComputeReuseCount());
+    for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount);
     for (size_t i = 0; i < stream_nodes_.size(); ++i) {
-      // compute use count first
-      ORT_RETURN_IF_ERROR(ComputeReuseCount());
       for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast<OrtValueIndex>(j);
-#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
-      if (i == 0) {
-        for (auto ort_value_info : ort_value_info_) {
-          ort_value_usecount.push_back(ort_value_info.usecount);
-        }
+      if (i > 0) {
+        for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast<OrtValueIndex>(k)) = ort_value_usecount[k];
       }
-#endif
       ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i));
       ClearUseCount();
       freelist_.clear();  // DONOT share freelist across streams
@@ -1472,7 +1467,6 @@ class PlannerImpl {
       for (auto node_input : pnode->InputDefs()) {
         if (node_input->Exists()) {
           auto& sym = node_input->Name();
-          //auto original = Buffer(Index(sym));
           auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
           // The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
           // See comments in the OrtValueInfo definition.
@@ -1485,7 +1479,6 @@ class PlannerImpl {
       for (auto node_input : pnode->ImplicitInputDefs()) {
         if (node_input->Exists()) {
           auto& sym = node_input->Name();
-          //auto original = Buffer(Index(sym));
           auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
           // The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
           // See comments in the OrtValueInfo definition.
@@ -1499,7 +1492,6 @@ class PlannerImpl {
       for (auto node_output : pnode->OutputDefs()) {
         if (node_output->Exists()) {
           auto& sym = node_output->Name();
-          //auto original = Buffer(Index(sym));
           auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
           // The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
           // See comments in the OrtValueInfo definition.
diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc
index d7b1de5c930c5..b40aac8efeae1 100644
--- a/onnxruntime/test/framework/allocation_planner_test.cc
+++ b/onnxruntime/test/framework/allocation_planner_test.cc
@@ -1974,6 +1974,31 @@ TEST_F(PlannerTest, TestCpuIf) {
     ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
   }
 }
+
+TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
+  SessionOptions sess_opt;
+  sess_opt.graph_optimization_level = TransformerLevel::Default;
+
+  InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx"));
+  auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
+  status = sess.Load();
+  status = sess.Initialize();
+  ASSERT_TRUE(status.IsOK()) << "No crash";
+  const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan();
+  ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape";
+
+  int gather_count = 0;
+  for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) {
+    if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) {
+      const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());
+      if (node->OpType() == "Gather")
+        gather_count++;
+      else
+        FAIL() << "CPU stream should contain only gather ops";
+    }
+  }
+  ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream";
+}
 #endif
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx
new file mode 100644
index 0000000000000..dc7d39206dd49
Binary files /dev/null and b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx differ
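
Note on the ComputeReusePlan hunk: two changes work together there. First, value_consumer_map becomes a local of the planning routine instead of the member value_consumer_map_, so consumer sets no longer persist across planning runs. Second, the use counts are snapshotted once before the per-stream loop and restored for every stream after the first, instead of being captured only under ORT_MEMORY_PROFILE on iteration 0. The standalone C++ sketch below illustrates just that snapshot/restore pattern; all names in it (ValueInfoSketch, ComputeReusePlanSketch, use_count) are simplified stand-ins, not the real PlannerImpl API, and the per-stream planning pass itself is elided.

```cpp
#include <cstddef>
#include <vector>

// Simplified stand-in for PlannerImpl's per-value bookkeeping.
struct ValueInfoSketch {
  int use_count = 0;                        // plays the role of UseCount(i)
  int reused_buffer_index_per_stream = -1;  // per-stream reuse mapping
};

void ComputeReusePlanSketch(std::vector<ValueInfoSketch>& values, size_t num_streams) {
  // Snapshot the global use counts once, before any per-stream pass runs.
  // (Previously the snapshot was taken inside the loop, so later streams
  // planned against counts the first stream's pass had already consumed.)
  std::vector<int> saved_use_counts;
  saved_use_counts.reserve(values.size());
  for (const auto& v : values) saved_use_counts.push_back(v.use_count);

  for (size_t i = 0; i < num_streams; ++i) {
    // Every stream starts from the identity mapping: no reuse decided yet.
    for (size_t j = 0; j < values.size(); ++j)
      values[j].reused_buffer_index_per_stream = static_cast<int>(j);

    // Streams after the first restore the original counts, since the
    // previous stream's planning pass decremented them.
    if (i > 0)
      for (size_t k = 0; k < values.size(); ++k)
        values[k].use_count = saved_use_counts[k];

    // ... per-stream reuse planning would run here, decrementing
    // values[*].use_count as each consumer is visited ...
  }
}
```

With each stream planning from the same baseline, a buffer consumed in more than one stream is no longer mistakenly treated as free for reuse, which is the cross-stream reuse scenario the new ReusedInputCrossDifferentStreams test (issue_19480.onnx) exercises.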