Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream #19515

Merged
merged 10 commits on Feb 22, 2024
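The core change, in essence: the planner previously tracked buffer-reuse decisions in state that survived across streams, so a reuse computed while planning one stream could leak into the next. This PR keeps a per-value reuse index that is reset to identity (every value points at its own buffer) before the reuse pass runs for each stream. Below is a minimal standalone sketch of that pattern using plain STL containers rather than ONNX Runtime's actual planner types; `ResetReuseIndex`, `PlanAllStreams`, and the vector name are illustrative, not ORT APIs.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Reset the reuse index to identity: value i reuses buffer i, i.e. no sharing.
void ResetReuseIndex(std::vector<int>& reused_buffer_index_per_stream) {
  std::iota(reused_buffer_index_per_stream.begin(),
            reused_buffer_index_per_stream.end(), 0);
}

void PlanAllStreams(std::size_t num_values, std::size_t num_streams) {
  std::vector<int> reused_buffer_index_per_stream(num_values);
  for (std::size_t stream = 0; stream < num_streams; ++stream) {
    // Without this reset, reuse mappings computed for an earlier stream would
    // still be visible here and could incorrectly alias buffers across streams.
    ResetReuseIndex(reused_buffer_index_per_stream);
    // ... the per-stream pass then overwrites entries for values that can
    // legitimately share a buffer within this stream ...
  }
}
```

In the diff below, the same idea appears as the `AllocPlan(i).reused_buffer = i` initialization and the `Buffer(j) = j` reset at the top of each per-stream iteration.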
44 changes: 23 additions & 21 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -182,7 +182,6 @@ class PlannerImpl {
// upstream_node_0 and upstream_node_1 are the immediate upstream nodes of downstream_node
// upstream_node_2 is the immediate node ahead of downstream_node in the same logic stream
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map_;
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;
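For context on the members kept above: `dependence_graph_` maps a downstream node to the set of nodes it depends on, which is what later lets the planner check whether every consumer of a buffer is already "covered" before allowing reuse. A small hedged illustration, with standard containers standing in for ORT's inlined containers and made-up node indices:

```cpp
#include <cstddef>
#include <unordered_map>
#include <unordered_set>

using NodeIndex = std::size_t;  // stand-in for onnxruntime::NodeIndex

int main() {
  // downstream_node depends on upstream_node_0 and upstream_node_1 (immediate
  // upstream producers) and upstream_node_2 (the node just ahead of it in the
  // same logic stream), mirroring the comment above.
  std::unordered_map<NodeIndex, std::unordered_set<NodeIndex>> dependence_graph;
  const NodeIndex downstream_node = 3;
  dependence_graph[downstream_node] = {0, 1, 2};

  // Reuse is only safe if the would-be reuser transitively depends on the
  // buffer's existing consumers; this membership test is the building block.
  return dependence_graph[downstream_node].count(1) > 0 ? 0 : 1;
}
```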

// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@@ -295,7 +294,7 @@ class PlannerImpl {
}
#endif

// Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
// Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
bool* is_strided_tensor) {
*is_strided_tensor = false;
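The corrected comment above describes what `FindReusableInput` does: for the output at position `output_arg_num`, look for an input whose buffer the kernel allows to be written in place. A rough, hedged sketch of that idea follows; the `InplaceAliasInfo` struct and function name are hypothetical stand-ins, not the actual ORT kernel-definition API.

```cpp
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

using OrtValueIndex = int;

// Hypothetical description of a kernel's in-place aliasing: pairs of
// (input arg index, output arg index) the kernel declares as shareable.
struct InplaceAliasInfo {
  std::vector<std::pair<int, int>> may_inplace;
};

// Given the output position, return the OrtValue index of an input that the
// kernel says can back that output in place, if any.
std::optional<OrtValueIndex> FindReusableInputSketch(
    const InplaceAliasInfo& alias, int output_arg_num,
    const std::vector<OrtValueIndex>& input_values) {
  for (const auto& [input_arg, output_arg] : alias.may_inplace) {
    if (output_arg == output_arg_num && input_arg >= 0 &&
        static_cast<std::size_t>(input_arg) < input_values.size()) {
      return input_values[input_arg];
    }
  }
  return std::nullopt;
}
```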
@@ -530,6 +529,7 @@ class PlannerImpl {

// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
for (int i = 0; static_cast<size_t>(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i;
}

bool HasExternalOutputs(const Node& node) const {
@@ -1065,7 +1065,8 @@ class PlannerImpl {

// build the consumer list for each value
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
value_consumer_map_.reserve(num_ml_values);
InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map;
value_consumer_map.reserve(num_ml_values);

// iterate each stream from the back, so the first element is the last consumer in the single-stream case
for (auto& stream : stream_nodes_) {
@@ -1078,10 +1079,10 @@
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
auto origin = Buffer(value_idx);
if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
auto origin = AllocPlan(value_idx).reused_buffer;
if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumer_map_[origin].insert(node_index);
value_consumer_map[origin].insert(node_index);
}
}
return Status::OK();
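The hunk above switches the consumer map from the removed class member to a local `value_consumer_map`, and resolves each input through `AllocPlan(value_idx).reused_buffer` so consumers are recorded against the origin buffer rather than an alias of it. A simplified sketch of that construction, using std containers and a flattened stand-in for `stream_nodes_`; the types and function name here are illustrative only:

```cpp
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using NodeIndex = std::size_t;
using OrtValueIndex = int;

enum class AllocKind { kAllocate, kReuse };

struct ValuePlan {
  AllocKind alloc_kind = AllocKind::kAllocate;
  OrtValueIndex reused_buffer = 0;
};

// One node: its index plus the OrtValue indices of its inputs.
using NodeInputs = std::pair<NodeIndex, std::vector<OrtValueIndex>>;

std::unordered_map<OrtValueIndex, std::unordered_set<NodeIndex>> BuildConsumerMap(
    const std::vector<std::vector<NodeInputs>>& stream_nodes,
    const std::vector<ValuePlan>& allocation_plan) {
  std::unordered_map<OrtValueIndex, std::unordered_set<NodeIndex>> value_consumer_map;
  for (const auto& stream : stream_nodes) {
    for (const auto& [node_index, input_values] : stream) {
      for (OrtValueIndex value_idx : input_values) {
        // Follow the input to the buffer it ultimately reuses.
        const OrtValueIndex origin = allocation_plan[value_idx].reused_buffer;
        if (allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
          // Record the consuming node against the origin buffer.
          value_consumer_map[origin].insert(node_index);
        }
      }
    }
  }
  return value_consumer_map;
}
```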
@@ -1138,8 +1139,8 @@ class PlannerImpl {
std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
value_consumer_map_[output_idx_global].end());
value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
found_reusable = true;
break;
@@ -1168,8 +1169,8 @@ class PlannerImpl {
allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
value_consumer_map_[output_idx_global].end());
value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
@@ -1187,11 +1188,11 @@
OrtValueIndex input_arg_index{};
if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = input_arg_index;
value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
value_consumer_map_[output_idx_global].end());
value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
reused.insert(input_arg_index);
}
}
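A pattern repeated throughout this hunk: whenever an output is planned to reuse another value's buffer, the output's consumers are merged into the reused buffer's consumer set, so later liveness and coverage checks see every node that still needs that buffer. A minimal sketch of the merge step, with an invented function name for illustration:

```cpp
#include <cstddef>
#include <unordered_map>
#include <unordered_set>

using NodeIndex = std::size_t;
using OrtValueIndex = int;
using ConsumerMap = std::unordered_map<OrtValueIndex, std::unordered_set<NodeIndex>>;

// When `output_value` reuses `reused_value`'s buffer, the buffer must stay
// live until every consumer of either value has run, so the consumer sets
// are merged onto the reused (origin) buffer.
void MergeConsumersOnReuse(ConsumerMap& value_consumer_map,
                           OrtValueIndex reused_value, OrtValueIndex output_value) {
  auto& merged = value_consumer_map[reused_value];
  const auto& extra = value_consumer_map[output_value];
  merged.insert(extra.begin(), extra.end());
}
```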
@@ -1266,7 +1267,7 @@ class PlannerImpl {
}

bool all_covered = true;
for (auto consumer : value_consumer_map_[output_idx_global]) {
for (auto consumer : value_consumer_map[output_idx_global]) {
if (deps->find(consumer) == deps->end()) {
all_covered = false;
break;
@@ -1277,9 +1278,9 @@
allocation_plan[downstream_value].reused_buffer = output_idx_global;
get_reused = true;
// add new consumer for the value to be reused
value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
value_consumer_map_[downstream_value].end());
value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
value_consumer_map[downstream_value].end());
node_iter = size_iter->second.erase(node_iter);
if (size_iter->second.empty()) {
local_iter->second.erase(size_iter);
@@ -1342,8 +1343,9 @@ class PlannerImpl {
ort_value_usecount.reserve(ort_value_info_.size());
#endif
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
// compute use count first
// compute use count first. TODO(leca): calling ComputeReuseCount() only once is enough!
ORT_RETURN_IF_ERROR(ComputeReuseCount());
for (int j = 0; static_cast<size_t>(j) < ort_value_info_.size(); j++) Buffer(j) = j;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
if (i == 0) {
for (auto ort_value_info : ort_value_info_) {
@@ -1693,8 +1695,8 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
auto origin = Buffer(value_idx);
if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
auto origin = AllocPlan(value_idx).reused_buffer;
if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumers[origin].push_back(node_index);
}
@@ -1889,7 +1891,7 @@ class PlannerImpl {
// 2. the consumer is in the same stream (non-cpu device), but it consumes a CPU tensor from a non-shape op.
// for example, a resize cuda kernel consumes a tensor from a MemCpyToHost cuda kernel on the same stream.
// in this case, the FIFO can't guarantee the cpu tensor is ready when the resize kernel launches
OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type();
OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type();
WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device);
if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) {
if (node_to_notification.find(node_index) == node_to_notification.end()) {
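The comment and condition in the last hunk above decide when a cross-stream (or CPU-tensor) consumer needs an explicit wait on the producer's notification. A hedged restatement of that decision as a standalone predicate; the device constant and function name are placeholders, not ORT types:

```cpp
#include <cstddef>

using DeviceType = int;
constexpr DeviceType kCPU = 0;

// A wait step is emitted only when a wait handle exists for the
// (stream device, output device) pair and either the consumer runs on a
// different stream, or the consumed value lives on CPU, since same-stream
// FIFO ordering alone cannot guarantee a CPU tensor from a non-shape op
// is ready when the downstream kernel launches.
bool NeedsWaitStep(std::size_t producer_stream, std::size_t consumer_stream,
                   DeviceType output_arg_device, bool has_wait_handle) {
  if (!has_wait_handle) return false;
  return consumer_stream != producer_stream || output_arg_device == kCPU;
}
```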
68 changes: 68 additions & 0 deletions onnxruntime/test/framework/allocation_planner_test.cc
@@ -1974,6 +1974,74 @@
ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
}
}

// model looks like:
//                                            |-----------> Gather
//                                            |-----------> Gather
//                                            |-----------> Gather
//                                            |-----------> Gather
// Shape ----------------> Reshape --> Shape ------------------> Reshape
//                           ^                                      ^
// InstanceNormalization ----|          InstanceNormalization ------|
//
// Python script to create this model:
// def CreateModelFor19480():
// #shape->reshape->shape->reshape, 4 gather
// graphNodes = []
// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9']))
// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8']))
// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output']))
// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281']))
// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293']))
// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4']))
// g = h.make_graph(graphNodes, 'issue_19480',
// [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]),
// h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]),
// h.make_tensor_value_info('scale0', tp.FLOAT, [32]),
// h.make_tensor_value_info('B0', tp.FLOAT, [32]),
// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]),
// h.make_tensor_value_info('scale1', tp.FLOAT, [32]),
// h.make_tensor_value_info('B1', tp.FLOAT, [32]),
// h.make_tensor_value_info('indices1', tp.INT32, []),
// h.make_tensor_value_info('indices2', tp.INT32, []),
// h.make_tensor_value_info('indices3', tp.INT32, []),
// h.make_tensor_value_info('indices4', tp.INT32, [])],
// [h.make_tensor_value_info('output0', tp.FLOAT, None),
// h.make_tensor_value_info('output1', tp.INT64, None),
// h.make_tensor_value_info('output2', tp.INT64, None),
// h.make_tensor_value_info('output3', tp.INT64, None),
// h.make_tensor_value_info('output4', tp.INT64, None)])
// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name')
// onnx.save(model, 'issue_19480.onnx')
//
TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
SessionOptions sess_opt;
sess_opt.graph_optimization_level = TransformerLevel::Default;

InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx"));
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
status = sess.Load();
status = sess.Initialize();
ASSERT_TRUE(status.IsOK()) << "No crash";
const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan();
ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape";

int gather_count = 0;
for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) {
if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) {
const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());

if (node->OpType() == "Gather")
gather_count++;
else
FAIL() << "CPU stream should contain only gather ops";
}
}
ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream";
}
#endif
} // namespace test
} // namespace onnxruntime