Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA graph support for TRT EP #16081

Merged
merged 18 commits into from
Jun 21, 2023
2 changes: 2 additions & 0 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,8 @@ if (onnxruntime_USE_TENSORRT)
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.cc"
)

source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_min_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_max_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable; // Enable CUDA graph in ORT TRT
};
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_min_shapes = info.profile_min_shapes;
profile_max_shapes = info.profile_max_shapes;
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
Expand Down Expand Up @@ -842,6 +843,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes);
profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes);
profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes);

const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable);
if (!cuda_graph_enable_env.empty()) {
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}
} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
Expand Down Expand Up @@ -895,6 +901,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty();
}

if (cuda_graph_enable_) {
cuda_graph_ = std::make_unique<CUDAGraph>();
}

/*
* Parse explicit min/max/opt profile shapes from provider options.
*
Expand Down Expand Up @@ -968,7 +978,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_tactic_sources: " << tactic_sources_
<< ", trt_profile_min_shapes: " << profile_min_shapes
<< ", trt_profile_max_shapes: " << profile_max_shapes
<< ", trt_profile_opt_shapes: " << profile_opt_shapes;
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
Expand All @@ -982,6 +993,43 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list);
}

// Returns true when CUDA graph usage was requested, either via the
// trt_cuda_graph_enable provider option or the ORT_TENSORRT_CUDA_GRAPH_ENABLE
// environment variable.
bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const {
return cuda_graph_enable_;
}

// Capture is deferred until a minimum number of regular (non-captured) runs
// have completed, so that the necessary memory allocations happen before the
// graph is recorded.
bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
}

// Starts CUDA graph capture on the stream previously handed to
// CUDAGraph::SetStream() (done in compute_func, see Compile()).
void TensorrtExecutionProvider::CaptureBegin() {
// Discard any previously captured graph so the new capture starts clean.
cuda_graph_->Reset();
cuda_graph_->CaptureBegin();
}

// Finalizes the in-progress CUDA graph capture and marks the graph as
// captured so subsequent runs take the ReplayGraph() path instead of
// re-capturing.
void TensorrtExecutionProvider::CaptureEnd() {
cuda_graph_->CaptureEnd();
is_graph_captured_ = true;
}

// True once CaptureEnd() has completed a CUDA graph capture for this provider.
bool TensorrtExecutionProvider::IsGraphCaptured() const {
return is_graph_captured_;
}

// Launches the previously captured CUDA graph.
// Precondition (enforced): a graph must already have been captured via
// CaptureBegin()/CaptureEnd().
Status TensorrtExecutionProvider::ReplayGraph() {
  ORT_ENFORCE(IsGraphCaptured());
  // Please note that CUDAGraph::Replay() is not thread safe.
  // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(),
  // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe.
  return cuda_graph_->Replay();
}

void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
// Please note that this function is not thread safe.
// ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(),
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// therefore following increment is guaranteed to be thread safe.
++regular_run_count_before_graph_capture_;
}

std::vector<AllocatorPtr> TensorrtExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, device_id_);
Expand All @@ -999,6 +1047,10 @@ std::unique_ptr<IDataTransfer> TensorrtExecutionProvider::GetDataTransfer() cons
return onnxruntime::CreateGPUDataTransfer();
}

// Intentionally a no-op. Unlike the CUDA EP, ORT TRT cannot begin CUDA graph
// capture here because it only obtains the CUDA stream at compute time, so
// capture is started inside compute_func() instead (see Compile()).
Status TensorrtExecutionProvider::OnRunStart() {
return Status::OK();
}

Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
if (sync_stream && external_stream_) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
Expand Down Expand Up @@ -2737,6 +2789,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
}

// Start CUDA graph capture.
// Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
// current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
cuda_graph_->SetStream(stream);
CaptureBegin();
}

// Run TRT inference
if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
Expand Down Expand Up @@ -2764,6 +2825,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}
}
}

// End CUDA graph capture.
// Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture
// above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc,
// which might end up with many cuda graphs being captured by multiple threads when running with multithreading.
// It's safe to start/end CUDA graph capture in compute_func() here since the whole function is protected by the lock_guard().
if (cuda_graph_enable_ && !IsGraphCaptured()) {
if (IsGraphCaptureAllowed()) {
CaptureEnd();
// CUDA work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(ReplayGraph());
} else {
IncrementRegularRunCountBeforeGraphCapture();
}
}

return Status::OK();
};

Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/cuda/cuda_graph.h"
#include "tensorrt_execution_provider_info.h"

namespace onnxruntime {
Expand Down Expand Up @@ -42,6 +43,7 @@ static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_P
static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES";
static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES";
static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES";
static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
Expand Down Expand Up @@ -133,6 +135,7 @@ struct TensorrtFuncState {
int auxiliary_streams = -1;
bool filter_tactic_sources = false;
nvinfer1::TacticSources tactic_sources;
bool cuda_graph_enable = 0;
};

// Logical device representation.
Expand All @@ -153,6 +156,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;

Status OnRunStart() override;
Status OnRunEnd(bool sync_stream) override;

ProviderOptions GetProviderOptions() const override {
Expand All @@ -167,6 +171,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {

std::vector<AllocatorPtr> CreatePreferredAllocators() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;

private:
TensorrtExecutionProviderInfo info_;
bool external_stream_ = false;
Expand Down Expand Up @@ -204,6 +212,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool timing_cache_enable_ = false;
bool force_timing_cache_match_ = false;
bool detailed_build_log_ = false;
bool cuda_graph_enable_ = false;

std::unique_ptr<CUDAGraph> cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph)
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
Expand Down Expand Up @@ -254,5 +268,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {

/**Check whether all the nodes of subgraph are supported*/
bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const;

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
void IncrementRegularRunCountBeforeGraphCapture();
};
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths";
constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes";
constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes";
constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes";
constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable";
} // namespace provider_option_names
} // namespace tensorrt

Expand Down Expand Up @@ -91,6 +92,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
.Parse(options)); // add new provider option here.

return info;
Expand Down Expand Up @@ -129,6 +131,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)},
{tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)},
{tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
};
return options;
}
Expand Down Expand Up @@ -175,6 +178,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_},
{tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_},
{tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct TensorrtExecutionProviderInfo {
std::string profile_min_shapes{""};
std::string profile_max_shapes{""};
std::string profile_opt_shapes{""};
bool cuda_graph_enable{false};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct Tensorrt_Provider : Provider {
info.profile_min_shapes = options.trt_profile_min_shapes == nullptr ? "" : options.trt_profile_min_shapes;
info.profile_max_shapes = options.trt_profile_max_shapes == nullptr ? "" : options.trt_profile_max_shapes;
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;

common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
Expand Down Expand Up @@ -229,6 +230,8 @@ struct Tensorrt_Provider : Provider {
dest[str_size] = '\0';
trt_options.trt_profile_opt_shapes = (const char*)dest;
}

trt_options.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
Loading