diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc
index 93370904e50e1..b37c0cac3cd93 100644
--- a/onnxruntime/core/optimizer/transformer_memcpy.cc
+++ b/onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -189,7 +189,8 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
                                         InitializedTensorSet& initializers_consumed) {
   auto node_provider_type = node.GetExecutionProviderType();
   if ((node_provider_type == provider_) ||
-      (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
+      (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
+      (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
     provider_nodes_.insert(&node);
     // note KernelCreateInfo might be nullptr for custom kernel
     const KernelCreateInfo* kci = nullptr;
@@ -281,7 +282,9 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
     if (arg_input_index == -1 && arg_output_index == -1)
       continue;
     auto node_provider_type = it.GetExecutionProviderType();
-    if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_)) {
+    if ((node_provider_type == provider_) ||
+        (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
+        (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
       const KernelCreateInfo* kci = nullptr;
       ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci));
       if (arg_input_index != -1) {
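
The same three-clause condition now appears in both ProcessDefs and BuildDefsMapping. The sketch below restates that predicate as a standalone function purely to illustrate the intent of the change: a node counts as "on the target provider's device" if it is assigned to that provider directly, or to the base GPU provider that shares device memory with it (CUDA for TensorRT, ROCm for MIGraphX), so no Memcpy node is needed between them. The helper name NodeRunsOnProviderDevice and the literal provider strings are illustrative stand-ins, not part of this PR or guaranteed to match the ONNX Runtime constants exactly.

```cpp
#include <string>

namespace {
// Placeholder values mirroring the ONNX Runtime provider-name constants.
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider";
constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider";
constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";

// Hypothetical helper: true when the node's assigned provider places its
// tensors on the same device memory as `provider`, so the memcpy transformer
// should treat the node as belonging to that provider.
bool NodeRunsOnProviderDevice(const std::string& node_provider_type,
                              const std::string& provider) {
  return node_provider_type == provider ||
         (node_provider_type == kCudaExecutionProvider &&
          provider == kTensorrtExecutionProvider) ||
         (node_provider_type == kRocmExecutionProvider &&
          provider == kMIGraphXExecutionProvider);
}
}  // namespace
```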