Skip to content

Commit

Permalink
The transformer of memcpy is needed for ROCm EP and MIGraphX EP when …
Browse files Browse the repository at this point in the history
…fallbacking CPU happens (#10522)

Co-authored-by: Weixing Zhang <[email protected]>
  • Loading branch information
weixingzhang and weixingzhang authored Feb 11, 2022
1 parent f92e47e commit 2002a96
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions onnxruntime/core/optimizer/transformer_memcpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
InitializedTensorSet& initializers_consumed) {
auto node_provider_type = node.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
Expand Down Expand Up @@ -281,7 +282,9 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
if (arg_input_index == -1 && arg_output_index == -1)
continue;
auto node_provider_type = it.GetExecutionProviderType();
if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci));
if (arg_input_index != -1) {
Expand Down

0 comments on commit 2002a96

Please sign in to comment.