-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable TRT provider option configuration for C# #7179
Changes from 11 commits
b882a42
fb79c1d
5328c03
5e6e233
2074557
f016bc4
d274fe3
767083f
7f3a544
fb5b4a3
9af54d0
ee998f8
58f2b2f
ba8b20e
478a81c
1cb12c5
532c899
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ public class SessionOptions : SafeHandle | |
{ | ||
// Delay-loaded CUDA or cuDNN DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information. | ||
private static string[] cudaDelayLoadedLibs = { }; | ||
private static string[] trtDelayLoadedLibs = { }; | ||
|
||
#region Constructor and Factory methods | ||
|
||
|
@@ -75,6 +76,71 @@ public static SessionOptions MakeSessionOptionWithCudaProvider(int deviceId = 0) | |
return options; | ||
} | ||
|
||
/// <summary> | ||
/// A helper method to construct a SessionOptions object for TensorRT execution. | ||
/// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider. | ||
/// </summary> | ||
/// <param name="deviceId"></param> | ||
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns> | ||
public static SessionOptions MakeSessionOptionWithTensorrtProvider(int deviceId = 0) | ||
{ | ||
CheckTensorrtExecutionProviderDLLs(); | ||
SessionOptions options = new SessionOptions(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(options.Handle, deviceId)); | ||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options.Handle, deviceId)); | ||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1)); | ||
return options; | ||
} | ||
|
||
/// <summary> | ||
/// A helper method to construct a SessionOptions object for TensorRT execution. | ||
/// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider. | ||
/// </summary> | ||
/// <param name="trt_options">Provider Options for TensorRT EP.</param> | ||
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns> | ||
public static SessionOptions MakeSessionOptionWithTensorrtProvider(OrtTensorRTProviderOptions trt_options) | ||
{ | ||
CheckTensorrtExecutionProviderDLLs(); | ||
SessionOptions options = new SessionOptions(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
OrtTensorRTProviderOptionsNative trt_options_native; | ||
trt_options_native.device_id = trt_options.device_id; | ||
trt_options_native.has_user_compute_stream = 0; | ||
trt_options_native.user_compute_stream = IntPtr.Zero; | ||
trt_options_native.has_trt_options = trt_options.has_trt_options; | ||
if ((ulong)trt_options.trt_max_workspace_size > (1 << 30)) | ||
{ | ||
trt_options_native.trt_max_workspace_size = (UIntPtr)(1 << 30); | ||
} | ||
else | ||
{ | ||
trt_options_native.trt_max_workspace_size = trt_options.trt_max_workspace_size; | ||
} | ||
trt_options_native.trt_fp16_enable = trt_options.trt_fp16_enable; | ||
trt_options_native.trt_int8_enable = trt_options.trt_int8_enable; | ||
var tableNamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(trt_options.trt_int8_calibration_table_name), GCHandleType.Pinned); | ||
using (var pinnedSettingsName = new PinnedGCHandle(tableNamePinned)) | ||
{ | ||
trt_options_native.trt_int8_calibration_table_name = pinnedSettingsName.Pointer; | ||
} | ||
trt_options_native.trt_int8_use_native_calibration_table = trt_options.trt_int8_use_native_calibration_table; | ||
trt_options_native.trt_max_partition_iterations = trt_options.trt_max_partition_iterations; | ||
trt_options_native.trt_min_subgraph_size = trt_options.trt_min_subgraph_size; | ||
trt_options_native.trt_dump_subgraphs = trt_options.trt_dump_subgraphs; | ||
trt_options_native.trt_engine_cache_enable = trt_options.trt_engine_cache_enable; | ||
var cachePathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(trt_options.trt_cache_path), GCHandleType.Pinned); | ||
using (var pinnedSettingsName2 = new PinnedGCHandle(cachePathPinned)) | ||
{ | ||
trt_options_native.trt_cache_path = pinnedSettingsName2.Pointer; | ||
} | ||
|
||
|
||
NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT(options.Handle, ref trt_options_native)); | ||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options.Handle, trt_options.device_id)); | ||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1)); | ||
return options; | ||
} | ||
|
||
/// <summary> | ||
/// A helper method to construct a SessionOptions object for Nuphar execution. | ||
/// Use only if you have the onnxruntime package specific to this Execution Provider. | ||
|
@@ -325,6 +391,29 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue) | |
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, pinnedDimName.Pointer, dimValue)); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Get TensorRT provider options with default setting. | ||
/// </summary> | ||
/// <returns> TRT provider options instance. </returns> | ||
public static OrtTensorRTProviderOptions GetDefaultTensorRTProviderOptions() | ||
{ | ||
OrtTensorRTProviderOptions trt_options; | ||
trt_options.device_id = 0; | ||
trt_options.has_trt_options = 0; | ||
trt_options.trt_max_workspace_size = (UIntPtr)(1 << 30); | ||
trt_options.trt_fp16_enable = 0; | ||
trt_options.trt_int8_enable = 0; | ||
trt_options.trt_int8_calibration_table_name = ""; | ||
trt_options.trt_int8_use_native_calibration_table = 0; | ||
trt_options.trt_max_partition_iterations = 1000; | ||
trt_options.trt_min_subgraph_size = 1; | ||
trt_options.trt_dump_subgraphs = 0; | ||
trt_options.trt_engine_cache_enable = 0; | ||
trt_options.trt_cache_path = ""; | ||
|
||
return trt_options; | ||
} | ||
#endregion | ||
|
||
internal IntPtr Handle | ||
|
@@ -592,6 +681,35 @@ public ExecutionMode ExecutionMode | |
} | ||
private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL; | ||
|
||
|
||
/// <summary> | ||
/// Provider options for TensorRT. | ||
/// </summary> | ||
// Example for setting: | ||
// SessionOptions.OrtTensorRTProviderOptions trt_options; | ||
// trt_options.device_id = 0; | ||
// trt_options.has_trt_options = 1; | ||
// trt_options.trt_max_workspace_size = (UIntPtr) (1<<30); | ||
// trt_options.trt_fp16_enable = 1; | ||
// trt_options.trt_int8_enable = 1; | ||
// trt_options.trt_int8_calibration_table_name = "calibration.flatbuffers"; | ||
// trt_options.trt_int8_use_native_calibration_table = 0; | ||
public struct OrtTensorRTProviderOptions | ||
jywu-msft marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
public int device_id; //!< cuda device id. Default is 0. </typeparam> | ||
public int has_trt_options; //!< override environment variables with following TensorRT settings at runtime. Default 0 = false, nonzero = true. | ||
public UIntPtr trt_max_workspace_size; //!< maximum workspace size for TensorRT. ORT C++ DLL has this field to be the type of size_t, hence using UIntPtr for conversion. | ||
public int trt_fp16_enable; //!< enable TensorRT FP16 precision. Default 0 = false, nonzero = true. | ||
public int trt_int8_enable; //!< enable TensorRT INT8 precision. Default 0 = false, nonzero = true. | ||
public String trt_int8_calibration_table_name; //!< TensorRT INT8 calibration table name. | ||
public int trt_int8_use_native_calibration_table; //!< use native TensorRT generated calibration table. Default 0 = false, nonzero = true | ||
public int trt_max_partition_iterations; //!< maximum number of iterations allowed in model partitioning for TensorRT. | ||
public int trt_min_subgraph_size; //!< minimum node size in a subgraph after partitioning. | ||
public int trt_dump_subgraphs; //!< dump the subgraphs that are transformed into TRT engines in onnx format to the filesystem. Default 0 = false, nonzero = true | ||
public int trt_engine_cache_enable; //!< enable TensorRT engine caching. Default 0 = false, nonzero = true | ||
public String trt_cache_path; //!< specify path for TensorRT engine and profile files if engine_cache_enable is enabled, or INT8 calibration table file if trt_int8_enable is enabled. | ||
} | ||
|
||
#endregion | ||
|
||
#region Private Methods | ||
|
@@ -624,6 +742,27 @@ private static bool CheckCudaExecutionProviderDLLs() | |
return true; | ||
} | ||
|
||
private static bool CheckTensorrtExecutionProviderDLLs() | ||
{ | ||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) | ||
{ | ||
foreach (var dll in trtDelayLoadedLibs) | ||
{ | ||
IntPtr handle = LoadLibrary(dll); | ||
if (handle != IntPtr.Zero) | ||
continue; | ||
var sysdir = new StringBuilder(String.Empty, 2048); | ||
GetSystemDirectory(sysdir, (uint)sysdir.Capacity); | ||
throw new OnnxRuntimeException( | ||
ErrorCode.NoSuchFile, | ||
$"kernel32.LoadLibrary():'{dll}' not found. TensorRT/CUDA are required for GPU execution. " + | ||
$". Verify it is available in the system directory={sysdir}. Else copy it to the output folder." | ||
); | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
|
||
#endregion | ||
#region SafeHandle | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -227,6 +227,54 @@ public void CanCreateAndDisposeSessionWithModelPath() | |
} | ||
} | ||
|
||
|
||
|
||
#if USE_TENSORRT | ||
[Fact] | ||
private void validateTensorRTProviderOptions() | ||
{ | ||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); | ||
string calTablPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet_calibration.flatbuffers"); | ||
//Environment.SetEnvironmentVariable("ORT_TENSORRT_ENGINE_CACHE_ENABLE", "1"); | ||
|
||
SessionOptions.OrtTensorRTProviderOptions trt_options = SessionOptions.GetDefaultTensorRTProviderOptions(); | ||
trt_options.device_id = 0; | ||
trt_options.trt_int8_calibration_table_name = calTablPath; | ||
trt_options.has_trt_options = 1; | ||
trt_options.trt_max_workspace_size = (UIntPtr)(1 << 30); | ||
trt_options.trt_fp16_enable = 1; | ||
trt_options.trt_int8_enable = 1; | ||
trt_options.trt_int8_use_native_calibration_table = 0; | ||
|
||
var session = new InferenceSession(modelPath, SessionOptions.MakeSessionOptionWithTensorrtProvider(trt_options)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Session is a disposable class. So this should either be wrapped into a using clause, tyr/finally. OR, if you have trouble managing so many disposables, you can add then to a disposable list which alone would require disposal. See examples in this file. The issue here, people copy this code as examples, and then complain about leaks. |
||
var inputMeta = session.InputMetadata; | ||
var container = new List<NamedOnnxValue>(); | ||
float[] inputData = LoadTensorFromFile(@"bench.in"); // this is the data for only one input tensor for this model | ||
foreach (var name in inputMeta.Keys) | ||
{ | ||
Assert.Equal(typeof(float), inputMeta[name].ElementType); | ||
Assert.True(inputMeta[name].IsTensor); | ||
var tensor = new DenseTensor<float>(inputData, inputMeta[name].Dimensions); | ||
container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor)); | ||
} | ||
|
||
|
||
using (var results = session.Run(container)) | ||
{ | ||
// Following code is temporarily commented. | ||
// Even though we enable fp16 or int8 through provider options, it could be disabled from TRT EP due to GPU not supporting fp16 or int8. | ||
// Once From/ToProviderOptions() has been implemented in TRT EP, better test cases will be added. | ||
/* | ||
string[] files = Directory.GetFiles(Directory.GetCurrentDirectory(), "*int8*.engine"); | ||
Assert.True(files.Any()); | ||
files = Directory.GetFiles(Directory.GetCurrentDirectory(), "*fp16*.engine"); | ||
Assert.True(files.Any()); | ||
*/ | ||
} | ||
} | ||
#endif | ||
|
||
|
||
[Theory] | ||
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, true)] | ||
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, false)] | ||
|
@@ -2361,6 +2409,7 @@ private void VerifyNativeMethodsExist() | |
#endif | ||
#if USE_TENSORRT | ||
,"OrtSessionOptionsAppendExecutionProvider_Tensorrt" | ||
,"SessionOptionsAppendExecutionProvider_TensorRT" | ||
#endif | ||
#if USE_MIGRAPHX | ||
,"OrtSessionOptionsAppendExecutionProvider_MIGraphX" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -289,15 +289,20 @@ typedef struct OrtROCMProviderOptions { | |
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT | ||
/// </summary> | ||
typedef struct OrtTensorRTProviderOptions { | ||
int device_id; // cuda device id. | ||
int has_user_compute_stream; // indicator of user specified CUDA compute stream. | ||
void* user_compute_stream; // user specified CUDA compute stream. | ||
int has_trt_options; // override environment variables with following TensorRT settings at runtime. | ||
size_t trt_max_workspace_size; // maximum workspace size for TensorRT. | ||
int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true | ||
int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true | ||
const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. | ||
int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true | ||
int device_id; // cuda device id. | ||
int has_user_compute_stream; // indicator of user specified CUDA compute stream. | ||
void* user_compute_stream; // user specified CUDA compute stream. | ||
int has_trt_options; // override environment variables with following TensorRT settings at runtime. | ||
size_t trt_max_workspace_size; // maximum workspace size for TensorRT. | ||
int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true | ||
int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true | ||
const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true | ||
int max_partition_iterations; // maximum number of iterations allowed in model partitioning for TensorRT. | ||
jywu-msft marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int min_subgraph_size; // minimum node size in a subgraph after partitioning. | ||
int dump_subgraphs; // dump the subgraphs that are transformed into TRT engines in onnx format to the filesystem. Default 0 = false, nonzero = true | ||
int engine_cache_enable; // enable TensorRT engine caching. Default 0 = false, nonzero = true | ||
const char* cache_path; // specify path for TensorRT engine and profile files if engine_cache_enable is enabled, or INT8 calibration table file if trt_int8_enable is enabled. | ||
} OrtTensorRTProviderOptions; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should watch the extensiblity of such structs and modify all languages that makes use of this C API at the same time. |
||
|
||
/// <summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
OrtSessionOptionsAppendExecutionProvider_Tensorrt | ||
SessionOptionsAppendExecutionProvider_TensorRT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to clarify how to initialize this, since paths in Windows are in UTF-16 and in Linux it is UTF-8.