Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Feb 21, 2024
1 parent 8262d0b commit 3e50a00
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 22 deletions.
47 changes: 25 additions & 22 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,6 @@ class PlannerImpl {
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace
#endif
// reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan()
// reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan()
OrtValueIndex reused_buffer_index_per_stream;
};

// ort_value_info_ is indexed by an OrtValueIndex
Expand Down Expand Up @@ -279,7 +276,6 @@ class PlannerImpl {
OrtValueIndex original = Buffer(reused);
// record that the new buffer will reuse that original buffer
Buffer(reused_for) = original;
ort_value_info_[reused_for].reused_buffer_index_per_stream = original;
// adjust original buffer's usecount
UseCount(original) += UseCount(reused_for);

Expand Down Expand Up @@ -360,7 +356,7 @@ class PlannerImpl {
auto p_input_arg = input_args[pair.first];
if (p_input_arg->Exists()) {
auto input_arg_index = Index(p_input_arg->Name());
auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream;
auto original = Buffer(input_arg_index);
if (1 == UseCount(original)) {
if (SameSize(*p_input_arg, *p_output_arg)) {
// we can reuse this input since it is its last use and permitted for in-place update
Expand Down Expand Up @@ -533,6 +529,7 @@ class PlannerImpl {

// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
for (size_t i = 0; i < num_ml_values; i++) plan_.allocation_plan[i].reused_buffer = static_cast<OrtValueIndex>(i);
}

bool HasExternalOutputs(const Node& node) const {
Expand Down Expand Up @@ -1082,8 +1079,8 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
auto origin = Buffer(value_idx);
if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
auto origin = AllocPlan(value_idx).reused_buffer;
if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumer_map[origin].insert(node_index);
}
Expand Down Expand Up @@ -1143,7 +1140,7 @@ class PlannerImpl {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
found_reusable = true;
break;
Expand Down Expand Up @@ -1173,7 +1170,7 @@ class PlannerImpl {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
Expand All @@ -1195,7 +1192,7 @@ class PlannerImpl {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = input_arg_index;
value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
value_consumer_map[output_idx_global].end());
reused.insert(input_arg_index);
}
}
Expand Down Expand Up @@ -1283,7 +1280,7 @@ class PlannerImpl {
// add new consumer for the value to be reused
value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
value_consumer_map[downstream_value].end());
value_consumer_map[downstream_value].end());
node_iter = size_iter->second.erase(node_iter);
if (size_iter->second.empty()) {
local_iter->second.erase(size_iter);
Expand Down Expand Up @@ -1340,16 +1337,22 @@ class PlannerImpl {
// use parallel execution context to generate a baseline first (no memory sharing)
context_ = gsl::not_null<const ISequentialPlannerContext*>(&no_mem_reuse_context);
}
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
// copy the use counts to a vector, before computing reuse
std::vector<int> ort_value_usecount;
ort_value_usecount.reserve(ort_value_info_.size());
ORT_RETURN_IF_ERROR(ComputeReuseCount());
for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount);
#endif
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast<OrtValueIndex>(j);
if (i > 0) {
for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast<OrtValueIndex>(k)) = ort_value_usecount[k];
// compute use count first. TODO(leca): calling ComputeReuseCount() only once is enough
ORT_RETURN_IF_ERROR(ComputeReuseCount());
for (size_t j = 0; j < ort_value_info_.size(); j++) Buffer(static_cast<OrtValueIndex>(j)) = static_cast<OrtValueIndex>(j);

Check warning on line 1348 in onnxruntime/core/framework/allocation_planner.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/allocation_planner.cc:1348: Lines should be <= 120 characters long [whitespace/line_length] [2]
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
if (i == 0) {
for (auto ort_value_info : ort_value_info_) {
ort_value_usecount.push_back(ort_value_info.usecount);
}
}
#endif
ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i));
ClearUseCount();
freelist_.clear(); // DONOT share freelist across streams
Expand Down Expand Up @@ -1467,7 +1470,7 @@ class PlannerImpl {
for (auto node_input : pnode->InputDefs()) {
if (node_input->Exists()) {
auto& sym = node_input->Name();
auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
auto original = Buffer(Index(sym));
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original))) {
Expand All @@ -1479,7 +1482,7 @@ class PlannerImpl {
for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists()) {
auto& sym = node_input->Name();
auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
auto original = Buffer(Index(sym));
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original))) {
Expand All @@ -1492,7 +1495,7 @@ class PlannerImpl {
for (auto node_output : pnode->OutputDefs()) {
if (node_output->Exists()) {
auto& sym = node_output->Name();
auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
auto original = Buffer(Index(sym));
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if (0 == DecrementUseCount(original)) {
Expand Down Expand Up @@ -1692,8 +1695,8 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
auto origin = Buffer(value_idx);
if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
auto origin = AllocPlan(value_idx).reused_buffer;
if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumers[origin].push_back(node_index);
}
Expand Down Expand Up @@ -1888,7 +1891,7 @@ class PlannerImpl {
// 2. the consumer is in the same stream (non-cpu device), but it consumes a CPU tensor from a non-shape op.
// for example, a resize cuda kernel consumes a tensor from a MemCpyToHost cuda kernel on the same stream.
// in this case, the FIFO can't guarantee the cpu tensor is ready when the resize kernel is launched
OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type();
OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type();
WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device);
if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) {
if (node_to_notification.find(node_index) == node_to_notification.end()) {
Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/test/framework/allocation_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1975,6 +1975,36 @@ TEST_F(PlannerTest, TestCpuIf) {
}
}

// model looks like:
// |-----------> Gather
// |-----------> Gather
// |-----------> Gather
// |-----------> Gather
// Shape ----------------> Reshape --> Shape ------------------> Reshape
// ^ ^
// InstanceNormalization ----| InstanceNormalization ------|

Check warning on line 1985 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:1985: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
//

Check warning on line 1986 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:1986: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
// Python script to create this model:
// def CreateModelFor19480():
// #shape->reshape->shape->reshape, 4 gather
// graphNodes = []
// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9']))
// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8']))
// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output']))
// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281']))
// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293']))
// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4']))
// g = h.make_graph(graphNodes, 'issue_19480', [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale0', tp.FLOAT, [32]), h.make_tensor_value_info('B0', tp.FLOAT, [32]),

Check warning on line 2001 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2001: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]

Check warning on line 2001 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2001: Lines should be <= 120 characters long [whitespace/line_length] [2]
// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale1', tp.FLOAT, [32]), h.make_tensor_value_info('B1', tp.FLOAT, [32]),

Check warning on line 2002 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2002: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]

Check warning on line 2002 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2002: Lines should be <= 120 characters long [whitespace/line_length] [2]
// h.make_tensor_value_info('indices1', tp.INT32, []), h.make_tensor_value_info('indices2', tp.INT32, []), h.make_tensor_value_info('indices3', tp.INT32, []), h.make_tensor_value_info('indices4', tp.INT32, [])],

Check warning on line 2003 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2003: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]

Check warning on line 2003 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2003: Lines should be <= 120 characters long [whitespace/line_length] [2]
// [h.make_tensor_value_info('output0', tp.FLOAT, None), h.make_tensor_value_info('output1', tp.INT64, None), h.make_tensor_value_info('output2', tp.INT64, None), h.make_tensor_value_info('output3', tp.INT64, None), h.make_tensor_value_info('output4', tp.INT64, None)])

Check warning on line 2004 in onnxruntime/test/framework/allocation_planner_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/framework/allocation_planner_test.cc:2004: Lines should be <= 120 characters long [whitespace/line_length] [2]
// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name')
// onnx.save(model, 'issue_19480.onnx')
//
TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
SessionOptions sess_opt;
sess_opt.graph_optimization_level = TransformerLevel::Default;
Expand Down

0 comments on commit 3e50a00

Please sign in to comment.