diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 4128524b30483..8a8426a0b3054 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -362,6 +362,7 @@ static NativeMethods()
OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern));
OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena));
OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena));
+ OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads));
OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId));
OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel));
OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel));
@@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
public static DOrtDisableCpuMemArena OrtDisableCpuMemArena;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options);
+ public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads;
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId);
public static DOrtSetSessionLogId OrtSetSessionLogId;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 7a68246c9b67a..30d005b3c4236 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -696,6 +696,15 @@ public bool EnableCpuMemArena
}
private bool _enableCpuMemArena = true;
+ ///
+ /// Disables the per session threads. Default is true.
+ /// This makes all sessions in the process use a global TP.
+ ///
+ public void DisablePerSessionThreads()
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle));
+ }
+
///
/// Log Id to be used for the session. Default is empty string.
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
index fd8feda359f90..d6a6b9627f418 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
@@ -55,6 +55,9 @@ public void TestSessionOptions()
Assert.Equal(0, opt.InterOpNumThreads);
Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel);
+ // No get, so no verify
+ opt.DisablePerSessionThreads();
+
// try setting options
opt.ExecutionMode = ExecutionMode.ORT_PARALLEL;
Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode);
@@ -98,7 +101,7 @@ public void TestSessionOptions()
Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message);
// SessionOptions.RegisterOrtExtensions can be manually tested by referencing the
- // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
+ // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
ex = Assert.Throws(() => { opt.RegisterOrtExtensions(); });
Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message);