Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Sep 25, 2024
1 parent 03b094b commit 5e15b59
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 23 deletions.
9 changes: 9 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,15 @@ class IExecutionProvider {
return default_device_;
};

/**
 * Return the appropriate OrtDevice object given OrtMemType for allocating graph inputs,
 * including initializers. By default it returns the same device as GetOrtDeviceByMemType,
 * but it can be overridden by execution providers that allocate graph inputs from a
 * different memory pool than regular kernel outputs.
 */
virtual OrtDevice GetOrtDeviceByMemTypeForGraphInput(OrtMemType mem_type) const {
  return GetOrtDeviceByMemType(mem_type);
}

Check warning on line 334 in include/onnxruntime/core/framework/execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: include/onnxruntime/core/framework/execution_provider.h:334: You don't need a ; after a } [readability/braces] [4]

/**
* Create Preferred allocators for the current Execution Provider
* This function is a stateless function which creates new instances of Allocator, without storing them in EP.
Expand Down
26 changes: 4 additions & 22 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -767,16 +767,7 @@ class PlannerImpl {

if (!is_implicit_input) {
OrtMemType mem_type = p_kernel_def->InputMemoryType(arg_idx);
auto ort_device = exec_provider->GetOrtDeviceByMemType(mem_type);

#ifdef USE_DML
// DML uses a different allocator for weights and inputs that allocates unpooled memory
if (p_kernel_def->Provider() == onnxruntime::kDmlExecutionProvider && mem_type == OrtMemType::OrtMemTypeDefault) {
ort_device = OrtDevice(ort_device.Type(), OrtDevice::MemType::DML_UNPOOLED, ort_device.Id());
}
#endif

plan_.SetLocation(static_cast<size_t>(index), ort_device);
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetOrtDeviceByMemTypeForGraphInput(mem_type));

Check warning on line 770 in onnxruntime/core/framework/allocation_planner.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/allocation_planner.cc:770: Lines should be <= 120 characters long [whitespace/line_length] [2]
set_node_arg_has_explicit_consumer.insert(index);
} else { // implicit input
// Only process an implicit input if there are explicit consumers at this graph level
Expand Down Expand Up @@ -888,23 +879,14 @@ class PlannerImpl {
return Status::OK();
}

OrtDevice GetLocationForNodeInput(size_t input_index, const Node& node, const KernelCreateInfoMap& kernel_create_info_map) {
// Returns the device on which the given weight (initializer) input of `node` should be
// placed. Weights are not produced by any node, so the execution provider's graph-input
// device (which may differ from its regular kernel device, e.g. DML's unpooled heap) is
// the right location; CPU-pinned inputs go to the CPU-input device instead.
OrtDevice GetLocationForNodeWeightInput(size_t input_index, const Node& node,
                                        const KernelCreateInfoMap& kernel_create_info_map) {
  auto* p_provider = execution_providers_.Get(node);
  ORT_ENFORCE(p_provider);

  const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map, node.Index());

  // Weights are not output from any node, so it's OK to put their location on the CPU
  // provider when the kernel declares the input as CPU-resident.
  const OrtMemType mem_type =
      utils::IsInputOnCpu(node, &kernel_create_info, input_index) ? OrtMemTypeCPUInput : OrtMemTypeDefault;
  return p_provider->GetOrtDeviceByMemTypeForGraphInput(mem_type);
}

std::vector<std::pair<int, int>> GetAliasMap(const Node& node, const KernelCreateInfo& kernel_create_info) {
Expand Down Expand Up @@ -1000,7 +982,7 @@ class PlannerImpl {
// (subgraphs) is okay and utils::CopyInputsAcrossDevices() will take it to
// the right device before subgraph execution.
locations[wt_index].emplace_back(
GetLocationForNodeInput(node_input_index, node, kernel_create_info_map));
GetLocationForNodeWeightInput(node_input_index, node, kernel_create_info_map));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ namespace Dml
bool enableGraphCapture,
bool enableSyncSpinning,
bool disableMemoryArena) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)),
m_unpooledDevice(OrtDevice::GPU, OrtDevice::MemType::DML_UNPOOLED, 0)
{
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,19 @@ namespace Dml
return m_impl->ReplayGraph(graph_annotation_id);
}

// DML allocates graph inputs (including initializers) from an unpooled heap rather than
// the pooled default allocator, so default-memory graph inputs are redirected to the
// dedicated unpooled device; all other memory types fall back to the base behavior.
OrtDevice GetOrtDeviceByMemTypeForGraphInput(OrtMemType mem_type) const final
{
    if (mem_type == OrtMemTypeDefault)
    {
        return m_unpooledDevice;
    }

    return GetOrtDeviceByMemType(mem_type);
}

private:
ComPtr<ExecutionProviderImpl> m_impl;
const OrtDevice m_unpooledDevice;
};

} // namespace Dml

0 comments on commit 5e15b59

Please sign in to comment.