From 3e50a00ce9b1915384bcc4c37242e5e76839b7bf Mon Sep 17 00:00:00 2001 From: jslhcl Date: Tue, 20 Feb 2024 18:19:24 -0800 Subject: [PATCH] fix comments --- .../core/framework/allocation_planner.cc | 47 ++++++++++--------- .../test/framework/allocation_planner_test.cc | 30 ++++++++++++ 2 files changed, 55 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 4ca959528bab9..85272100f0b58 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -199,9 +199,6 @@ class PlannerImpl { #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace #endif - // reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan() - // reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan() - OrtValueIndex reused_buffer_index_per_stream; }; // ort_value_info_ is indexed by an OrtValueIndex @@ -279,7 +276,6 @@ class PlannerImpl { OrtValueIndex original = Buffer(reused); // record that the new buffer will reuse that original buffer Buffer(reused_for) = original; - ort_value_info_[reused_for].reused_buffer_index_per_stream = original; // adjust original buffer's usecount UseCount(original) += UseCount(reused_for); @@ -360,7 +356,7 @@ class PlannerImpl { auto p_input_arg = input_args[pair.first]; if (p_input_arg->Exists()) { auto input_arg_index = Index(p_input_arg->Name()); - auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream; + auto original = Buffer(input_arg_index); if (1 == UseCount(original)) { if (SameSize(*p_input_arg, *p_output_arg)) { // we can reuse this input since it is its last use and permitted for in-place update @@ -533,6 +529,7 @@ class PlannerImpl { // Initialize allocation plan: 
plan_.allocation_plan.resize(num_ml_values); + for (size_t i = 0; i < num_ml_values; i++) plan_.allocation_plan[i].reused_buffer = static_cast(i); } bool HasExternalOutputs(const Node& node) const { @@ -1082,8 +1079,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumer_map[origin].insert(node_index); } @@ -1143,7 +1140,7 @@ class PlannerImpl { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1173,7 +1170,7 @@ class PlannerImpl { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1195,7 +1192,7 @@ class PlannerImpl { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), - value_consumer_map[output_idx_global].end()); + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1283,7 +1280,7 @@ class PlannerImpl { // add new consumer for the 
value to be reused value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), - value_consumer_map[downstream_value].end()); + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1340,16 +1337,22 @@ class PlannerImpl { // use parallel execution context to generate a baseline first (no memory sharing) context_ = gsl::not_null(&no_mem_reuse_context); } +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) // copy the use counts to a vector, before computing reuse std::vector ort_value_usecount; ort_value_usecount.reserve(ort_value_info_.size()); - ORT_RETURN_IF_ERROR(ComputeReuseCount()); - for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount); +#endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast(j); - if (i > 0) { - for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast(k)) = ort_value_usecount[k]; + // compute use count first. 
TODO(leca): calling ComputeReuseCount() only once should be enough + ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (size_t j = 0; j < ort_value_info_.size(); j++) Buffer(static_cast(j)) = static_cast(j); +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) + if (i == 0) { + for (auto ort_value_info : ort_value_info_) { + ort_value_usecount.push_back(ort_value_info.usecount); + } } +#endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); ClearUseCount(); freelist_.clear(); // DONOT share freelist across streams @@ -1467,7 +1470,7 @@ class PlannerImpl { for (auto node_input : pnode->InputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; + auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if ((original != -1) && (0 == DecrementUseCount(original))) { @@ -1479,7 +1482,7 @@ class PlannerImpl { for (auto node_input : pnode->ImplicitInputDefs()) { if (node_input->Exists()) { auto& sym = node_input->Name(); - auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; + auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. if ((original != -1) && (0 == DecrementUseCount(original))) { @@ -1492,7 +1495,7 @@ class PlannerImpl { for (auto node_output : pnode->OutputDefs()) { if (node_output->Exists()) { auto& sym = node_output->Name(); - auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream; + auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. 
if (0 == DecrementUseCount(original)) { @@ -1692,8 +1695,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1888,7 +1891,7 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. // in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index b40aac8efeae1..dc6467bf736d3 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1975,6 +1975,36 @@ TEST_F(PlannerTest, TestCpuIf) { } } +// model looks like: +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// Shape ----------------> Reshape --> Shape ------------------> Reshape +// ^ ^ +// InstanceNormalization ----| InstanceNormalization ------| +// +// Python script 
to create this model: +// def CreateModelFor19480(): +// #shape->reshape->shape->reshape, 4 gather +// graphNodes = [] +// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8'])) +// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output'])) +// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293'])) +// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4'])) +// g = h.make_graph(graphNodes, 'issue_19480', [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale0', tp.FLOAT, [32]), h.make_tensor_value_info('B0', tp.FLOAT, [32]), +// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale1', tp.FLOAT, [32]), h.make_tensor_value_info('B1', tp.FLOAT, [32]), +// h.make_tensor_value_info('indices1', tp.INT32, []), h.make_tensor_value_info('indices2', tp.INT32, []), h.make_tensor_value_info('indices3', tp.INT32, []), h.make_tensor_value_info('indices4', tp.INT32, [])], +// [h.make_tensor_value_info('output0', tp.FLOAT, None), h.make_tensor_value_info('output1', tp.INT64, None), h.make_tensor_value_info('output2', tp.INT64, None), h.make_tensor_value_info('output3', tp.INT64, None), 
h.make_tensor_value_info('output4', tp.INT64, None)]) +// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name') +// onnx.save(model, 'issue_19480.onnx') +// TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { SessionOptions sess_opt; sess_opt.graph_optimization_level = TransformerLevel::Default;