Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable TRT provider option configuration for C# (updated version) #7808

Merged
merged 44 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7c2d4c6
prepare for C# to configure provider options
chilo-ms May 17, 2021
1921ec4
add c# code
chilo-ms May 17, 2021
9ec0f35
revert modification
chilo-ms May 18, 2021
6ac01ca
Add update provider info configuration in trt ep side
chilo-ms May 18, 2021
de6cb01
fix bugs
chilo-ms May 18, 2021
c80a02a
fix bug for compiler error C2259
chilo-ms May 18, 2021
a2b8984
Add c# test
chilo-ms May 19, 2021
681c319
fix bug
chilo-ms May 19, 2021
a67fcf3
fix bug
chilo-ms May 19, 2021
ace27d2
Properly deal with string
chilo-ms May 20, 2021
12b7cdc
Add c# api for accepting trt provider options
chilo-ms May 20, 2021
a763a64
fix bug
chilo-ms May 21, 2021
557724b
Merge branch 'c_sharp_trt_provider_options' of https://github.com/mic…
chilo-ms May 21, 2021
b628bf5
Modify C# test
chilo-ms May 21, 2021
9677dfc
add shared lib test
chilo-ms May 21, 2021
3e9d013
Add get provider options functionality
chilo-ms May 24, 2021
d8c18aa
clean up
chilo-ms May 24, 2021
69d37e8
clean up
chilo-ms May 24, 2021
d40122b
fix bug
chilo-ms May 24, 2021
b456e6c
Merge branch 'master' into c_sharp_trt_provider_options
chilo-ms May 24, 2021
2a87b40
fix bugs for CI
chilo-ms May 25, 2021
99774ae
Fix bugs for CI and documentation
chilo-ms May 25, 2021
9a0b07c
Move TRT EP provider options related functions out of C API
chilo-ms May 25, 2021
96851de
revert
chilo-ms May 25, 2021
30cc55c
fix bug
chilo-ms May 25, 2021
2576645
refactor
chilo-ms May 26, 2021
497550d
add check for provider options string
chilo-ms May 26, 2021
748cb95
Merge branch 'master' into c_sharp_trt_provider_options
chilo-ms Jun 4, 2021
1f815e4
code refactor
chilo-ms Jun 8, 2021
0a173c4
fix CI bug
chilo-ms Jun 8, 2021
ae45fe8
Fix CI bugs
chilo-ms Jun 8, 2021
1126559
clean up
chilo-ms Jun 8, 2021
e6953e8
fix bug
chilo-ms Jun 8, 2021
5fff868
Fix bug for Post Analysis
chilo-ms Jun 8, 2021
7a5f903
fix accidental bug
chilo-ms Jun 8, 2021
5e3f600
Add API_IMPL_BEGIN/API_IMPL_END
chilo-ms Jun 9, 2021
3ee5b20
clean up
chilo-ms Jun 9, 2021
b312090
code refactor
chilo-ms Jun 15, 2021
1f6280e
code refactor
chilo-ms Jun 18, 2021
c48c9ba
Merge branch 'master' into c_sharp_trt_provider_options
chilo-ms Jun 18, 2021
41a8dee
fix CI fail
chilo-ms Jun 18, 2021
3649fa9
fix bug
chilo-ms Jun 18, 2021
2ca2a19
use string append
chilo-ms Jun 22, 2021
a2456af
Change the code to better handle strncpy and string append
chilo-ms Jun 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ public struct OrtApi
public IntPtr ReleasePrepackedWeightsContainer;
public IntPtr CreateSessionWithPrepackedWeightsContainer;
public IntPtr CreateSessionFromArrayWithPrepackedWeightsContainer;
public IntPtr CreateTensorRTProviderOptions;
public IntPtr UpdateTensorRTProviderOptions;
public IntPtr GetTensorRTProviderOptions;
public IntPtr ReleaseTensorRTProviderOptions;
}

internal static class NativeMethods
Expand Down Expand Up @@ -271,6 +275,8 @@ static NativeMethods()
OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary));
OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry));
OrtAddInitializer = (DOrtAddInitializer)Marshal.GetDelegateForFunctionPointer(api_.AddInitializer, typeof(DOrtAddInitializer));
SessionOptionsAppendExecutionProvider_TensorRT = (DSessionOptionsAppendExecutionProvider_TensorRT)Marshal.GetDelegateForFunctionPointer(
api_.SessionOptionsAppendExecutionProvider_TensorRT, typeof(DSessionOptionsAppendExecutionProvider_TensorRT));

OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions));
OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions));
Expand Down Expand Up @@ -354,6 +360,10 @@ static NativeMethods()
OrtCreatePrepackedWeightsContainer = (DOrtCreatePrepackedWeightsContainer)Marshal.GetDelegateForFunctionPointer(api_.CreatePrepackedWeightsContainer, typeof(DOrtCreatePrepackedWeightsContainer));
OrtReleasePrepackedWeightsContainer = (DOrtReleasePrepackedWeightsContainer)Marshal.GetDelegateForFunctionPointer(api_.ReleasePrepackedWeightsContainer, typeof(DOrtReleasePrepackedWeightsContainer));

OrtCreateTensorRTProviderOptions = (DOrtCreateTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorRTProviderOptions, typeof(DOrtCreateTensorRTProviderOptions));
OrtUpdateTensorRTProviderOptions = (DOrtUpdateTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateTensorRTProviderOptions, typeof(DOrtUpdateTensorRTProviderOptions));
OrtGetTensorRTProviderOptions = (DOrtGetTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.GetTensorRTProviderOptions, typeof(DOrtGetTensorRTProviderOptions));
OrtReleaseTensorRTProviderOptions = (DOrtReleaseTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTensorRTProviderOptions, typeof(DOrtReleaseTensorRTProviderOptions));
}

[DllImport(nativeLib, CharSet = charSet)]
Expand All @@ -376,6 +386,49 @@ static NativeMethods()

#endregion Runtime/Environment API

#region Provider Options API

/// <summary>
/// Creates native OrtTensorRTProviderOptions instance
/// </summary>
/// <param name="trtProviderOptionsInstance">(output) native instance of OrtTensorRTProviderOptions</param>
/// <returns>Native OrtStatus*; callers pass it to NativeApiStatus.VerifySuccess</returns>
public delegate IntPtr /* OrtStatus* */DOrtCreateTensorRTProviderOptions(
out IntPtr /*(OrtTensorRTProviderOptions**)*/ trtProviderOptionsInstance);
public static DOrtCreateTensorRTProviderOptions OrtCreateTensorRTProviderOptions;

/// <summary>
/// Updates native OrtTensorRTProviderOptions instance using given key/value pairs
/// </summary>
/// <param name="trtProviderOptionsInstance">native instance of OrtTensorRTProviderOptions</param>
/// <param name="providerOptionsKeys">configuration keys of OrtTensorRTProviderOptions (pointers to zero terminated UTF-8 strings)</param>
/// <param name="providerOptionsValues">configuration values of OrtTensorRTProviderOptions (pointers to zero terminated UTF-8 strings)</param>
/// <param name="numKeys">number of configuration keys</param>
/// <returns>Native OrtStatus*; callers pass it to NativeApiStatus.VerifySuccess</returns>
public delegate IntPtr /* OrtStatus* */DOrtUpdateTensorRTProviderOptions(
IntPtr /*(OrtTensorRTProviderOptions*)*/ trtProviderOptionsInstance,
IntPtr[] /*(const char* const *)*/ providerOptionsKeys,
IntPtr[] /*(const char* const *)*/ providerOptionsValues,
UIntPtr /*(size_t)*/ numKeys);
public static DOrtUpdateTensorRTProviderOptions OrtUpdateTensorRTProviderOptions;

/// <summary>
/// Gets the TensorRT provider options as a native string
/// </summary>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="ptr">is a UTF-8 null terminated string allocated using 'allocator'</param>
/// <returns>Native OrtStatus*; callers pass it to NativeApiStatus.VerifySuccess</returns>
// NOTE(review): unlike the Update/Release delegates, this signature takes no
// OrtTensorRTProviderOptions instance - confirm it matches the native entry point.
public delegate IntPtr /* OrtStatus* */DOrtGetTensorRTProviderOptions(
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/ptr);
public static DOrtGetTensorRTProviderOptions OrtGetTensorRTProviderOptions;

/// <summary>
/// Releases native OrtTensorRTProviderOptions instance
/// </summary>
/// <param name="trtProviderOptionsInstance">native instance of OrtTensorRTProviderOptions to be released</param>
public delegate void DOrtReleaseTensorRTProviderOptions(IntPtr /*(OrtTensorRTProviderOptions*)*/ trtProviderOptionsInstance);
public static DOrtReleaseTensorRTProviderOptions OrtReleaseTensorRTProviderOptions;

#endregion

#region Status API
public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/status);
public static DOrtGetErrorCode OrtGetErrorCode;
Expand Down Expand Up @@ -640,6 +693,16 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca

//[DllImport(nativeLib, CharSet = charSet)]
//public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path);
//
/// <summary>
/// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance
/// </summary>
/// <param name="options">Native OrtSessionOptions instance</param>
/// <param name="trtProviderOptions">Native OrtTensorRTProviderOptions instance</param>
/// <returns>Native OrtStatus*; callers pass it to NativeApiStatus.VerifySuccess</returns>
public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_TensorRT(
IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const OrtTensorRTProviderOptions*)*/ trtProviderOptions);
// Bound in the NativeMethods static constructor from api_.SessionOptionsAppendExecutionProvider_TensorRT.
public static DSessionOptionsAppendExecutionProvider_TensorRT SessionOptionsAppendExecutionProvider_TensorRT;

/// <summary>
/// Free Dimension override (by denotation)
Expand Down
27 changes: 27 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.OnnxRuntime
{
Expand Down Expand Up @@ -77,6 +79,31 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8)
Marshal.Copy(nativeUtf8, buffer, 0, len);
return Encoding.UTF8.GetString(buffer, 0, buffer.Length);
}

/// <summary>
/// Converts each name to a zero terminated UTF-8 buffer, pins that buffer and returns the
/// pinned addresses so they can be handed to native code as an array of <c>const char*</c>.
/// </summary>
/// <typeparam name="T">type of the elements the names are extracted from</typeparam>
/// <param name="names">names to convert to zero terminated utf8 and pin</param>
/// <param name="extractor">delegate for string extraction from inputs</param>
/// <param name="cleanupList">list to add pinned memory to for later disposal; the returned
/// pointers are only usable while the entries added here have not been disposed</param>
/// <returns>array with one pinned pointer per element of <paramref name="names"/>, in enumeration order</returns>
internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, NameExtractor<T> extractor,
    DisposableList<IDisposable> cleanupList)
{
    var result = new IntPtr[names.Count];
    int index = 0;

    // Enumerate once: ElementAt(i) inside an index loop walks the collection from the start
    // on every call for collections that are not indexable, which is quadratic.
    foreach (var item in names)
    {
        var name = extractor(item);
        var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
        var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));

        // Register for cleanup before reading the pointer so the handle is always tracked.
        cleanupList.Add(pinnedHandle);
        result[index++] = pinnedHandle.Pointer;
    }

    return result;
}

// Delegate for string extraction from an arbitrary input/output object.
// ConvertNamesToUtf8 invokes it once per element to obtain the name that is converted to UTF-8.
internal delegate string NameExtractor<in TInput>(TInput input);
}

internal static class TensorElementTypeConverter
Expand Down
149 changes: 149 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Holds the options for configuring a TensorRT Execution Provider instance
/// </summary>
public class OrtTensorRTProviderOptions : SafeHandle
{
    // Key under which the device id is supplied in the provider options dictionary.
    private const string DeviceIdKey = "device_id";

    // Managed copy of the last device id passed to UpdateOptions; stays 0 until one is supplied.
    private int _deviceId = 0;

    /// <summary>
    /// Raw native OrtTensorRTProviderOptions pointer held by this SafeHandle
    /// </summary>
    internal IntPtr Handle
    {
        get
        {
            return handle;
        }
    }

    #region Constructor

    /// <summary>
    /// Constructs an empty OrtTensorRTProviderOptions instance
    /// </summary>
    public OrtTensorRTProviderOptions() : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorRTProviderOptions(out handle));
    }

    #endregion

    #region Public Methods

    /// <summary>
    /// Get TensorRT EP provider options
    /// </summary>
    /// <returns> return C# UTF-16 encoded string </returns>
    public string GetOptions()
    {
        var allocator = OrtAllocator.DefaultInstance;

        // NOTE(review): the native call receives only the allocator, not this instance's
        // handle - confirm that matches the native entry point.
        IntPtr providerOptions = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptions(allocator.Pointer, out providerOptions));

        // The native string was allocated with 'allocator'; wrapping it in an OrtMemoryAllocation
        // ties its lifetime to this using block (presumably Dispose frees it - verify).
        using (new OrtMemoryAllocation(allocator, providerOptions, 0))
        {
            return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions);
        }
    }

    /// <summary>
    /// Updates the configuration knobs of OrtTensorRTProviderOptions that will eventually be used to configure a TensorRT EP
    /// Please refer to the following on different key/value pairs to configure a TensorRT EP and their meaning:
    /// https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html
    /// </summary>
    /// <param name="providerOptions">key/value pairs used to configure a TensorRT Execution Provider</param>
    /// <exception cref="ArgumentNullException"><paramref name="providerOptions"/> is null</exception>
    /// <exception cref="FormatException">the "device_id" value is not a valid integer</exception>
    public void UpdateOptions(Dictionary<string, string> providerOptions)
    {
        if (providerOptions == null)
        {
            throw new ArgumentNullException("providerOptions");
        }

        using (var cleanupList = new DisposableList<IDisposable>())
        {
            // Keys and Values of a Dictionary enumerate in the same order, so index i of
            // both arrays belongs to the same pair.
            var keysArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Keys.ToArray(), n => n, cleanupList);
            var valuesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Values.ToArray(), n => n, cleanupList);

            NativeApiStatus.VerifySuccess(NativeMethods.OrtUpdateTensorRTProviderOptions(handle, keysArray, valuesArray, (UIntPtr)providerOptions.Count));

            // Remember the device id on the managed side so GetDeviceId() can report it.
            // Parsed with the invariant culture: the value is machine-readable configuration.
            string deviceIdValue;
            if (providerOptions.TryGetValue(DeviceIdKey, out deviceIdValue))
            {
                _deviceId = Int32.Parse(deviceIdValue, System.Globalization.CultureInfo.InvariantCulture);
            }
        }
    }

    /// <summary>
    /// Get device id of TensorRT EP.
    /// </summary>
    /// <returns> device id last supplied through UpdateOptions, or 0 if none was supplied </returns>
    public int GetDeviceId()
    {
        return _deviceId;
    }

    #endregion

    #region Public Properties

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

    #endregion

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtTensorRTProviderOptions
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseTensorRTProviderOptions(handle);
        handle = IntPtr.Zero;
        return true;
    }

    #endregion
}

/// <summary>
/// This helper class contains methods to handle values of provider options
/// </summary>
public class ProviderOptionsValueHelper
{
    /// <summary>
    /// Parse from string and save to dictionary.
    /// The string is expected to be ';' separated pairs, each pair being "key=value",
    /// e.g. key1=value1;key2=value2
    /// </summary>
    /// <param name="s">C# string</param>
    /// <param name="dict">Dictionary instance to store the parsing result of s</param>
    /// <exception cref="ArgumentNullException"><paramref name="s"/> or <paramref name="dict"/> is null</exception>
    /// <exception cref="ArgumentException">a ';' separated segment is not exactly one key and one value
    /// separated by '=', or a key is already present in <paramref name="dict"/></exception>
    public static void StringToDict(string s, Dictionary<string, string> dict)
    {
        if (s == null)
        {
            throw new ArgumentNullException("s");
        }
        if (dict == null)
        {
            throw new ArgumentNullException("dict");
        }

        string[] pairs = s.Split(';');

        foreach (var p in pairs)
        {
            string[] keyValue = p.Split('=');

            // The array after the split must contain exactly 2 elements; throw with a
            // meaningful message here, otherwise the caller would only see a generic
            // out-of-bounds failure from the indexing below.
            if (keyValue.Length != 2)
            {
                throw new ArgumentException("Make sure input string contains key-value pairs, e.g. key1=value1;key2=value2...", "s");
            }
            dict.Add(keyValue[0], keyValue[1]);
        }
    }
}

}
38 changes: 38 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntime
{
Expand Down Expand Up @@ -100,6 +101,33 @@ public static SessionOptions MakeSessionOptionWithTensorrtProvider(int deviceId
}
}

/// <summary>
/// A helper method to construct a SessionOptions object for TensorRT execution provider.
/// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider.
/// The TensorRT, CUDA and CPU execution providers are appended in that order.
/// </summary>
/// <param name="trtProviderOptions">TensorRT EP provider options</param>
/// <returns>A SessionOptions object configured for execution based on the given provider options</returns>
/// <exception cref="ArgumentNullException"><paramref name="trtProviderOptions"/> is null</exception>
public static SessionOptions MakeSessionOptionWithTensorrtProvider(OrtTensorRTProviderOptions trtProviderOptions)
{
    if (trtProviderOptions == null)
    {
        throw new ArgumentNullException("trtProviderOptions");
    }

    CheckTensorrtExecutionProviderDLLs();
    SessionOptions options = new SessionOptions();
    try
    {
        // Make sure that CUDA EP uses the same device id as TensorRT EP.
        int deviceId = trtProviderOptions.GetDeviceId();

        NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT(options.Handle, trtProviderOptions.Handle));
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options.Handle, deviceId));
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
        return options;
    }
    catch (Exception)
    {
        // The caller never receives 'options' on failure, so release it here.
        options.Dispose();
        // Rethrow with 'throw;' so the original stack trace is preserved.
        throw;
    }
}

/// <summary>
/// A helper method to construct a SessionOptions object for Nuphar execution.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
Expand Down Expand Up @@ -205,6 +233,16 @@ public void AppendExecutionProvider_Tensorrt(int deviceId)
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId));
}

/// <summary>
/// Append a TensorRT EP instance (based on specified configuration) to the SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="trtProviderOptions">TensorRT EP provider options</param>
/// <exception cref="ArgumentNullException"><paramref name="trtProviderOptions"/> is null</exception>
public void AppendExecutionProvider_Tensorrt(OrtTensorRTProviderOptions trtProviderOptions)
{
    // Fail with a descriptive argument error instead of a NullReferenceException on .Handle.
    if (trtProviderOptions == null)
    {
        throw new ArgumentNullException("trtProviderOptions");
    }

    NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT(handle, trtProviderOptions.Handle));
}

/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
Expand Down
Loading