[CUDA EP] Add warning logs when adding memcpy nodes #18032

Merged: 2 commits, Oct 24, 2023
40 changes: 30 additions & 10 deletions onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "transformer_memcpy.h"
#include "core/common/logging/logging.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/utils.h"
@@ -16,12 +17,12 @@ class TransformerMemcpyImpl {
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
: graph_(graph), provider_(provider) {}

-bool ModifyGraph(const KernelRegistryManager& schema_registries);
+bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);

private:
void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
-void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input);
+void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);

private:
@@ -61,11 +62,21 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st

// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
// and mainly provides the subgraph recursion functionality
-common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+const logging::Logger& logger) const {
for (auto& provider : provider_types_) {
if (!utils::ProviderIsCpuBased(provider)) {
TransformerMemcpyImpl copy_impl(graph, provider);
-auto current_modified = copy_impl.ModifyGraph(registry_manager_);

+int copy_node_counter = 0;
+auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
+if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
+LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
<< " for " << provider
<< ". It might have negative impact on performance (including unable to run CUDA graph). "
<< "Set session_options.log_severity_level=1 to see the detail logs before this message.";
}

modified = modified || current_modified;
break;
}
Expand Down Expand Up @@ -111,7 +122,9 @@ This transformer does not currently optimize copies between, e.g., two different

*/

-bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) {
+bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries,
+const logging::Logger& logger,
+int& copy_node_counter) {
bool modified = false;
InitializedTensorSet initializers_consumed;
// find defs that require copy
@@ -137,19 +150,22 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// For inputs we need to create a copy node only when the input is connected to both provider
// and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job.
if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) {
-AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true);
+AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true, logger);
+copy_node_counter++;
modified = true;
}

for (auto arg : non_provider_output_defs_)
if (provider_input_defs_.count(arg)) {
-AddCopyNode(arg, true);
+AddCopyNode(arg, true, logger);
+copy_node_counter++;
modified = true;
}

for (auto arg : provider_output_defs_)
if (non_provider_input_defs_.count(arg)) {
-AddCopyNode(arg, false);
+AddCopyNode(arg, false, logger);
+copy_node_counter++;
modified = true;
}

@@ -176,7 +192,8 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// (the name will be the same as the parent node's implicit input)
const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg);

-AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true);
+AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true, logger);
+copy_node_counter++;
modified = true;
}
}
@@ -297,7 +314,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
}
}

-void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input) {
+void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
// create unique name for new def
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);

@@ -309,6 +326,9 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input
std::string new_node_name = graph_.GenerateNodeName("Memcpy");

const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
+LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
+<< " for " << provider_;

auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
std::vector<onnxruntime::NodeArg*>{src_arg},
std::vector<onnxruntime::NodeArg*>{dst_arg});
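For context on using the new logs: the WARNING summary points users at `session_options.log_severity_level=1` so the per-node INFO lines emitted by `AddCopyNode` become visible. Below is a minimal sketch of how an application might do that from the C++ API when creating a CUDA session; it is not part of this PR, and the model path, log id, and provider options are placeholders.

```cpp
// Minimal sketch, assuming a CUDA-enabled build of onnxruntime.
// Raising the session log severity to INFO (1) surfaces the per-node
// "Add MemcpyFromHost/MemcpyToHost ..." messages alongside the WARNING summary
// added by this PR. "model.onnx" and the log id are placeholders.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "memcpy-warning-demo");

  Ort::SessionOptions session_options;
  // C++ counterpart of the hint in the warning text
  // ("Set session_options.log_severity_level=1"); 1 == INFO.
  session_options.SetLogSeverityLevel(1);

  // Register the CUDA EP. Nodes that fall back to the CPU EP force Memcpy nodes
  // at the CPU/CUDA boundaries, which is what MemcpyTransformer now warns about.
  OrtCUDAProviderOptions cuda_options{};
  session_options.AppendExecutionProvider_CUDA(cuda_options);

  // ORT_TSTR keeps the path literal portable between wide-char (Windows) and
  // narrow-char platforms.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}
```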