Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
Lei Cao committed Feb 20, 2024
1 parent e99a8cd commit 8262d0b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 29 deletions.
50 changes: 21 additions & 29 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ class PlannerImpl {
// upstream_node_0 and upstream_node_1 are the immediate upstream nodes of downstream_node
// upstream_node_2 is the immediate node ahead of downstream_node in the same logic stream
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map_;
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;

// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
Expand All @@ -200,6 +199,8 @@ class PlannerImpl {
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
OrtValueIndex inplace_reused_buffer_index = -1; // index of original buffer to reuse inplace
#endif
// reused_buffer_index_per_stream will be reset and used in each ComputeSingleStreamReusePlan()
// reused_buffer_index will be updated with reused_buffer_index_per_stream and preserved for the following GenerateDeallocationPlan()

Check warning on line 203 in onnxruntime/core/framework/allocation_planner.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/allocation_planner.cc:203: Lines should be <= 120 characters long [whitespace/line_length] [2]
OrtValueIndex reused_buffer_index_per_stream;
};

Expand Down Expand Up @@ -297,7 +298,7 @@ class PlannerImpl {
}
#endif

// Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
// Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
bool* is_strided_tensor) {
*is_strided_tensor = false;
Expand Down Expand Up @@ -359,7 +360,6 @@ class PlannerImpl {
auto p_input_arg = input_args[pair.first];
if (p_input_arg->Exists()) {
auto input_arg_index = Index(p_input_arg->Name());
//auto original = Buffer(input_arg_index);
auto original = ort_value_info_[input_arg_index].reused_buffer_index_per_stream;
if (1 == UseCount(original)) {
if (SameSize(*p_input_arg, *p_output_arg)) {
Expand Down Expand Up @@ -1068,7 +1068,8 @@ class PlannerImpl {

// build the consumer list for each value
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
value_consumer_map_.reserve(num_ml_values);
InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map;
value_consumer_map.reserve(num_ml_values);

// iterate each stream from back, so the first element is the last consumer in single stream case
for (auto& stream : stream_nodes_) {
Expand All @@ -1084,7 +1085,7 @@ class PlannerImpl {
auto origin = Buffer(value_idx);
if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumer_map_[origin].insert(node_index);
value_consumer_map[origin].insert(node_index);
}
}
return Status::OK();
Expand Down Expand Up @@ -1141,8 +1142,8 @@ class PlannerImpl {
std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
value_consumer_map_[output_idx_global].end());
value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
found_reusable = true;
break;
Expand Down Expand Up @@ -1171,8 +1172,8 @@ class PlannerImpl {
allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
value_consumer_map_[output_idx_global].end());
value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
Expand All @@ -1190,11 +1191,11 @@ class PlannerImpl {
OrtValueIndex input_arg_index{};
if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = input_arg_index;
value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
value_consumer_map_[output_idx_global].end());
value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
value_consumer_map[output_idx_global].end());
reused.insert(input_arg_index);
}
}
Expand Down Expand Up @@ -1269,7 +1270,7 @@ class PlannerImpl {
}

bool all_covered = true;
for (auto consumer : value_consumer_map_[output_idx_global]) {
for (auto consumer : value_consumer_map[output_idx_global]) {
if (deps->find(consumer) == deps->end()) {
all_covered = false;
break;
Expand All @@ -1280,9 +1281,9 @@ class PlannerImpl {
allocation_plan[downstream_value].reused_buffer = output_idx_global;
get_reused = true;
// add new consumer for the value to be reused
value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
value_consumer_map_[downstream_value].end());
value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
value_consumer_map[downstream_value].end());
node_iter = size_iter->second.erase(node_iter);
if (size_iter->second.empty()) {
local_iter->second.erase(size_iter);
Expand Down Expand Up @@ -1339,22 +1340,16 @@ class PlannerImpl {
// use parallel execution context to generate a baseline first (no memory sharing)
context_ = gsl::not_null<const ISequentialPlannerContext*>(&no_mem_reuse_context);
}
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
// copy the use counts to a vector, before computing reuse
std::vector<int> ort_value_usecount;
ort_value_usecount.reserve(ort_value_info_.size());
#endif
ORT_RETURN_IF_ERROR(ComputeReuseCount());
for (auto& ort_value_info : ort_value_info_) ort_value_usecount.push_back(ort_value_info.usecount);
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
// compute use count first
ORT_RETURN_IF_ERROR(ComputeReuseCount());
for (size_t j = 0; j < ort_value_info_.size(); j++) ort_value_info_[j].reused_buffer_index_per_stream = static_cast<OrtValueIndex>(j);

Check warning on line 1349 in onnxruntime/core/framework/allocation_planner.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/allocation_planner.cc:1349: Lines should be <= 120 characters long [whitespace/line_length] [2]
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
if (i == 0) {
for (auto ort_value_info : ort_value_info_) {
ort_value_usecount.push_back(ort_value_info.usecount);
}
if (i > 0) {
for (size_t k = 0; k < ort_value_usecount.size(); k++) UseCount(static_cast<OrtValueIndex>(k)) = ort_value_usecount[k];

Check warning on line 1351 in onnxruntime/core/framework/allocation_planner.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/allocation_planner.cc:1351: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
#endif
ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i));
ClearUseCount();
freelist_.clear(); // DONOT share freelist across streams
Expand Down Expand Up @@ -1472,7 +1467,6 @@ class PlannerImpl {
for (auto node_input : pnode->InputDefs()) {
if (node_input->Exists()) {
auto& sym = node_input->Name();
//auto original = Buffer(Index(sym));
auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
Expand All @@ -1485,7 +1479,6 @@ class PlannerImpl {
for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists()) {
auto& sym = node_input->Name();
//auto original = Buffer(Index(sym));
auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
Expand All @@ -1499,7 +1492,6 @@ class PlannerImpl {
for (auto node_output : pnode->OutputDefs()) {
if (node_output->Exists()) {
auto& sym = node_output->Name();
//auto original = Buffer(Index(sym));
auto original = ort_value_info_[Index(sym)].reused_buffer_index_per_stream;
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
Expand Down
25 changes: 25 additions & 0 deletions onnxruntime/test/framework/allocation_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1974,6 +1974,31 @@ TEST_F(PlannerTest, TestCpuIf) {
ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
}
}

// Regression test for cross-stream buffer reuse (issue #19480): a Shape output
// consumed by Reshape (CUDA stream) and several Gather nodes (CPU stream) must
// still produce a valid allocation plan, with the consumers reusing the buffer.
TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
  SessionOptions sess_opt;
  sess_opt.graph_optimization_level = TransformerLevel::Default;

  InferenceSession sess(sess_opt, GetEnvironment(),
                        ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx"));
  // Assert every setup step: previously only the final Initialize() status was
  // checked, so a failure in RegisterExecutionProvider()/Load() was silently
  // overwritten and misreported as an Initialize() failure.
  auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
  status = sess.Load();
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
  status = sess.Initialize();
  ASSERT_TRUE(status.IsOK()) << "No crash";
  const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan();
  // OrtValue #14 in this model is expected to reuse another buffer rather than
  // allocate its own (the shape output shared by reshape/gather consumers).
  ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse)
      << "The input of reshape and gather will reuse the output of shape";

  int gather_count = 0;
  // execution_plan[1] is the CPU stream; it should launch only Gather kernels.
  for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) {
    if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) {
      const Node* node =
          sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());
      if (node->OpType() == "Gather")
        gather_count++;
      else
        FAIL() << "CPU stream should contain only gather ops";
    }
  }
  ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream";
}
#endif
} // namespace test
} // namespace onnxruntime
Binary file not shown.

0 comments on commit 8262d0b

Please sign in to comment.