-
Notifications
You must be signed in to change notification settings - Fork 506
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add config for TensorRT and CUDA execution provider (#992)
Signed-off-by: [email protected] <[email protected]> Signed-off-by: [email protected] <[email protected]>
- Loading branch information
1 parent
f5e9a16
commit 55decb7
Showing
21 changed files
with
622 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
// sherpa-onnx/csrc/provider-config.cc | ||
// | ||
// Copyright (c) 2024 Uniphore (Author: Manickavela) | ||
|
||
#include "sherpa-onnx/csrc/provider-config.h" | ||
|
||
#include <sstream> | ||
|
||
#include "sherpa-onnx/csrc/file-utils.h" | ||
#include "sherpa-onnx/csrc/macros.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Registers the command-line option for the CUDA execution provider.
//
// @param po Option parser; it is handed the address of
//           cudnn_conv_algo_search under the name
//           "cuda-cudnn-conv-algo-search".
void CudaConfig::Register(ParseOptions *po) {
  po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
               "CuDNN convolution algorithm search");
}
|
||
bool CudaConfig::Validate() const { | ||
if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { | ||
SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option." | ||
"Options : [1,3]. Check OnnxRT docs", | ||
cudnn_conv_algo_search); | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
// Renders this config as "CudaConfig(cudnn_conv_algo_search=<value>)".
std::string CudaConfig::ToString() const {
  std::ostringstream out;
  out << "CudaConfig("
      << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")";
  return out.str();
}
|
||
// Registers every TensorRT execution-provider option with the parser.
//
// @param po Option parser; each option is registered with the address of the
//           matching member of this struct.
void TensorrtConfig::Register(ParseOptions *po) {
  // Integer limits.
  po->Register("trt-max-workspace-size", &trt_max_workspace_size,
               "Set TensorRT EP GPU memory usage limit.");
  po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
               "Limit partitioning iterations for model conversion.");
  po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
               "Set minimum size for subgraphs in partitioning.");

  // Boolean switches.
  po->Register("trt-fp16-enable", &trt_fp16_enable,
               "Enable FP16 precision for faster performance.");
  po->Register("trt-detailed-build-log", &trt_detailed_build_log,
               "Enable detailed logging of build steps.");
  po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
               "Enable caching of TensorRT engines.");
  po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
               "Enable use of timing cache to speed up builds.");

  // Cache locations.
  po->Register("trt-engine-cache-path", &trt_engine_cache_path,
               "Set path to store cached TensorRT engines.");
  po->Register("trt-timing-cache-path", &trt_timing_cache_path,
               "Set path for storing timing cache.");

  // Debugging aid.
  po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
               "Dump optimized subgraphs for debugging.");
}
|
||
bool TensorrtConfig::Validate() const { | ||
if (trt_max_workspace_size < 0) { | ||
SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", | ||
trt_max_workspace_size); | ||
return false; | ||
} | ||
if (trt_max_partition_iterations < 0) { | ||
SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", | ||
trt_max_partition_iterations); | ||
return false; | ||
} | ||
if (trt_min_subgraph_size < 0) { | ||
SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", | ||
trt_min_subgraph_size); | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
std::string TensorrtConfig::ToString() const { | ||
std::ostringstream os; | ||
|
||
os << "TensorrtConfig("; | ||
os << "trt_max_workspace_size=" << trt_max_workspace_size << ", "; | ||
os << "trt_max_partition_iterations=" | ||
<< trt_max_partition_iterations << ", "; | ||
os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; | ||
os << "trt_fp16_enable=\"" | ||
<< (trt_fp16_enable? "True" : "False") << "\", "; | ||
os << "trt_detailed_build_log=\"" | ||
<< (trt_detailed_build_log? "True" : "False") << "\", "; | ||
os << "trt_engine_cache_enable=\"" | ||
<< (trt_engine_cache_enable? "True" : "False") << "\", "; | ||
os << "trt_engine_cache_path=\"" | ||
<< trt_engine_cache_path.c_str() << "\", "; | ||
os << "trt_timing_cache_enable=\"" | ||
<< (trt_timing_cache_enable? "True" : "False") << "\", "; | ||
os << "trt_timing_cache_path=\"" | ||
<< trt_timing_cache_path.c_str() << "\","; | ||
os << "trt_dump_subgraphs=\"" | ||
<< (trt_dump_subgraphs? "True" : "False") << "\" )"; | ||
return os.str(); | ||
} | ||
|
||
// Registers the provider-level options together with the options of the
// nested CUDA and TensorRT configs.
//
// @param po Option parser; the nested configs register first, then "device"
//           and "provider".
void ProviderConfig::Register(ParseOptions *po) {
  cuda_config.Register(po);
  trt_config.Register(po);

  po->Register("device", &device, "GPU device index for CUDA and Trt EP");
  // "trt" is listed because Validate() treats it as a provider name.
  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml, trt");
}
|
||
bool ProviderConfig::Validate() const { | ||
if (device < 0) { | ||
SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); | ||
return false; | ||
} | ||
|
||
if (provider == "cuda" && !cuda_config.Validate()) { | ||
return false; | ||
} | ||
|
||
if (provider == "trt" && !trt_config.Validate()) { | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
// Renders this config, including both nested configs, as
// "ProviderConfig(device=..., provider="...", cuda_config=..., trt_config=...)".
std::string ProviderConfig::ToString() const {
  std::ostringstream out;

  out << "ProviderConfig("
      << "device=" << device << ", "
      << "provider=\"" << provider << "\", "
      << "cuda_config=" << cuda_config.ToString() << ", "
      << "trt_config=" << trt_config.ToString() << ")";

  return out.str();
}
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// sherpa-onnx/csrc/provider-config.h | ||
// | ||
// Copyright (c) 2024 Uniphore (Author: Manickavela) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ | ||
#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/parse-options.h" | ||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "onnxruntime_cxx_api.h" // NOLINT | ||
|
||
namespace sherpa_onnx { | ||
|
||
struct CudaConfig { | ||
int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; | ||
|
||
CudaConfig() = default; | ||
explicit CudaConfig(int32_t cudnn_conv_algo_search) | ||
: cudnn_conv_algo_search(cudnn_conv_algo_search) {} | ||
|
||
void Register(ParseOptions *po); | ||
bool Validate() const; | ||
|
||
std::string ToString() const; | ||
}; | ||
|
||
struct TensorrtConfig { | ||
int32_t trt_max_workspace_size = 2147483647; | ||
int32_t trt_max_partition_iterations = 10; | ||
int32_t trt_min_subgraph_size = 5; | ||
bool trt_fp16_enable = true; | ||
bool trt_detailed_build_log = false; | ||
bool trt_engine_cache_enable = true; | ||
bool trt_timing_cache_enable = true; | ||
std::string trt_engine_cache_path = "."; | ||
std::string trt_timing_cache_path = "."; | ||
bool trt_dump_subgraphs = false; | ||
|
||
TensorrtConfig() = default; | ||
TensorrtConfig(int32_t trt_max_workspace_size, | ||
int32_t trt_max_partition_iterations, | ||
int32_t trt_min_subgraph_size, | ||
bool trt_fp16_enable, | ||
bool trt_detailed_build_log, | ||
bool trt_engine_cache_enable, | ||
bool trt_timing_cache_enable, | ||
const std::string &trt_engine_cache_path, | ||
const std::string &trt_timing_cache_path, | ||
bool trt_dump_subgraphs) | ||
: trt_max_workspace_size(trt_max_workspace_size), | ||
trt_max_partition_iterations(trt_max_partition_iterations), | ||
trt_min_subgraph_size(trt_min_subgraph_size), | ||
trt_fp16_enable(trt_fp16_enable), | ||
trt_detailed_build_log(trt_detailed_build_log), | ||
trt_engine_cache_enable(trt_engine_cache_enable), | ||
trt_timing_cache_enable(trt_timing_cache_enable), | ||
trt_engine_cache_path(trt_engine_cache_path), | ||
trt_timing_cache_path(trt_timing_cache_path), | ||
trt_dump_subgraphs(trt_dump_subgraphs) {} | ||
|
||
void Register(ParseOptions *po); | ||
bool Validate() const; | ||
|
||
std::string ToString() const; | ||
}; | ||
|
||
struct ProviderConfig { | ||
TensorrtConfig trt_config; | ||
CudaConfig cuda_config; | ||
std::string provider = "cpu"; | ||
int32_t device = 0; | ||
// device only used for cuda and trt | ||
|
||
ProviderConfig() = default; | ||
ProviderConfig(const std::string &provider, | ||
int32_t device) | ||
: provider(provider), device(device) {} | ||
ProviderConfig(const TensorrtConfig &trt_config, | ||
const CudaConfig &cuda_config, | ||
const std::string &provider, | ||
int32_t device) | ||
: trt_config(trt_config), cuda_config(cuda_config), | ||
provider(provider), device(device) {} | ||
|
||
void Register(ParseOptions *po); | ||
bool Validate() const; | ||
|
||
std::string ToString() const; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.