From 224f0651d0e7727ea6a3b8b61a597b75e0f72f73 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 2 Oct 2024 10:00:43 -0700 Subject: [PATCH] [C#] Expose Multi-Lora support in C# (#22281) ### Description ### Motivation and Context https://github.com/microsoft/onnxruntime/pull/22046 --- cmake/onnxruntime_unittests.cmake | 2 +- .../NativeMethods.shared.cs | 80 ++++++++++++++++- .../OrtLoraAdapter.shared.cs | 81 +++++++++++++++++ .../RunOptions.shared.cs | 12 +++ .../InferenceTest.cs | 1 + .../InferenceTest.netcore.cs | 88 ++++++++++++++++++- ...oft.ML.OnnxRuntime.Tests.NetCoreApp.csproj | 8 ++ 7 files changed, 267 insertions(+), 5 deletions(-) create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 15437c6037c95..c64f029ad9301 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1657,7 +1657,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) + list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) endif() AddTest(DYN TARGET onnxruntime_customopregistration_test diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index b2a7b75891a25..be157a0419fc0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -302,6 +302,29 @@ public struct OrtApi public IntPtr ReleaseROCMProviderOptions; public IntPtr CreateAndRegisterAllocatorV2; public IntPtr RunAsync; + public IntPtr UpdateTensorRTProviderOptionsWithValue; + public IntPtr GetTensorRTProviderOptionsByName; + public IntPtr UpdateCUDAProviderOptionsWithValue; + public IntPtr GetCUDAProviderOptionsByName; + public IntPtr KernelContext_GetResource; + public IntPtr SetUserLoggingFunction; + public IntPtr ShapeInferContext_GetInputCount; + public IntPtr ShapeInferContext_GetInputTypeShape; + public IntPtr ShapeInferContext_GetAttribute; + public IntPtr ShapeInferContext_SetOutputTypeShape; + public IntPtr SetSymbolicDimensions; + public IntPtr ReadOpAttr; + public IntPtr SetDeterministicCompute; + public IntPtr KernelContext_ParallelFor; + public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO_V2; + public IntPtr SessionOptionsAppendExecutionProvider_VitisAI; + public IntPtr KernelContext_GetScratchBuffer; + public IntPtr KernelInfoGetAllocator; + public IntPtr AddExternalInitializersFromFilesInMemory; + public IntPtr CreateLoraAdapter; + public IntPtr CreateLoraAdapterFromArray; + public IntPtr ReleaseLoraAdapter; + public IntPtr RunOptionsAddActiveLoraAdapter; } internal static class NativeMethods @@ -540,6 +563,13 @@ static NativeMethods() OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2)); OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync)); + CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, + typeof(DCreateLoraAdapter)); + CreateLoraAdapterFromArray = (DCreateLoraAdapterFromArray)Marshal.GetDelegateForFunctionPointer (api_.CreateLoraAdapterFromArray, typeof(DCreateLoraAdapterFromArray)); + ReleaseLoraAdapter = (DReleaseLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.ReleaseLoraAdapter, + typeof(DReleaseLoraAdapter)); + OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer( + api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter)); } internal class NativeLib @@ -1263,7 +1293,49 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca #endregion -#region RunOptions API +#region LoraAdapter API + /// + /// Memory maps the adapter file, wraps it into the adapter object + /// and returns it. + /// + /// absolute path to the adapter file + /// optional device allocator or null + /// New LoraAdapter object + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapter( + byte[] adapter_path, // This takes const ORTCHAR_T* use GetPlatformSerializedString + IntPtr /* OrtAllocator */ allocator, // optional + out IntPtr lora_adapter + ); + public static DCreateLoraAdapter CreateLoraAdapter; + + /// + /// Creates LoraAdapter instance from a byte array that must + /// represents a valid LoraAdapter formst. + /// + /// bytes + /// size in bytes + /// optional device allocator + /// resuling LoraAdapter instance + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapterFromArray( + byte[] bytes, + UIntPtr size, + IntPtr /* OrtAllocator */ allocator, // optional + out IntPtr lora_adapter + ); + public static DCreateLoraAdapterFromArray CreateLoraAdapterFromArray; + + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DReleaseLoraAdapter(IntPtr /* OrtLoraAdapter* */ lora_adapter); + public static DReleaseLoraAdapter ReleaseLoraAdapter; + +#endregion + + #region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); @@ -1308,6 +1380,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsUnsetTerminate(IntPtr /* OrtRunOptions* */ options); public static DOrtRunOptionsUnsetTerminate OrtRunOptionsUnsetTerminate; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsAddActiveLoraAdapter( + IntPtr /* OrtRunOptions* */ options, + IntPtr /* OrtLoraAdapter* */ lora_adapter); + public static DOrtRunOptionsAddActiveLoraAdapter OrtRunOptionsAddActiveLoraAdapter; + /// /// Add run config entry /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs new file mode 100644 index 0000000000000..e2249b4c47fec --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime +{ + /// + /// Represents Lora Adapter in memory + /// + public class OrtLoraAdapter : SafeHandle + { + /// + /// Creates an instance of OrtLoraAdapter from file. + /// The adapter file is memory mapped. If allocator parameter + /// is provided, then lora parameters are copied to the memory + /// allocated by the specified allocator. + /// + /// path to the adapter file + /// optional allocator, can be null, must be a device allocator + /// New instance of LoraAdapter + public static OrtLoraAdapter Create(string adapterPath, OrtAllocator ortAllocator) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(adapterPath); + var allocatorHandle = (ortAllocator != null) ? ortAllocator.Pointer : IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapter(platformPath, allocatorHandle, + out IntPtr adapterHandle)); + return new OrtLoraAdapter(adapterHandle); + } + + /// + /// Creates an instance of OrtLoraAdapter from an array of bytes. The API + /// makes a copy of the bytes internally. + /// + /// array of bytes containing valid LoraAdapter format + /// optional device allocator or null + /// new instance of LoraAdapter + public static OrtLoraAdapter Create(byte[] bytes, OrtAllocator ortAllocator) + { + var allocatorHandle = (ortAllocator != null) ? ortAllocator.Pointer : IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapterFromArray(bytes, + new UIntPtr((uint)bytes.Length), allocatorHandle, out IntPtr adapterHandle)); + return new OrtLoraAdapter(adapterHandle); + } + + internal OrtLoraAdapter(IntPtr adapter) + : base(adapter, true) + { + } + + internal IntPtr Handle + { + get + { + return handle; + } + } + + #region SafeHandle + + /// + /// Overrides SafeHandle.IsInvalid + /// + /// returns true if handle is equal to Zero + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Overrides SafeHandle.ReleaseHandle() to properly dispose of + /// the native instance of OrtLoraAdapter + /// + /// always returns true + protected override bool ReleaseHandle() + { + NativeMethods.ReleaseLoraAdapter(handle); + handle = IntPtr.Zero; + return true; + } + + #endregion + } +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.shared.cs index d01c0f6e6fe4d..20547d6757b79 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.shared.cs @@ -129,6 +129,18 @@ public void AddRunConfigEntry(string configKey, string configValue) NativeApiStatus.VerifySuccess(NativeMethods.OrtAddRunConfigEntry(handle, utf8Key, utf8Value)); } + /// + /// Appends the specified lora adapter to the list of active lora adapters + /// for this RunOptions instance. All run calls with this instant will + /// make use of the activated Lora Adapters. An adapter is considered active + /// if it is added to RunOptions that are used during Run() calls. + /// + /// Lora adapter instance + public void AddActiveLoraAdapter(OrtLoraAdapter loraAdapter) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsAddActiveLoraAdapter(handle, loraAdapter.Handle)); + } + #region SafeHandle /// /// Overrides SafeHandle.ReleaseHandle() to properly dispose of diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index ac7a84d69bbea..aa0e6ee62248a 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -1678,6 +1678,7 @@ private void TestInferenceSessionWithByteArray() } } + void TestCPUAllocatorInternal(InferenceSession session) { int device_id = 0; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index ad127c2579294..ff5fd2de54197 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -680,7 +680,7 @@ private void RunPretrainedModel(InferenceSession session, RunOptions runOptions, var orderedInputNames = new List(inputContainer.Count); var orderdedInputs = new List(inputContainer.Count); - foreach(var pair in inputContainer) + foreach (var pair in inputContainer) { orderedInputNames.Add(pair.Key); orderdedInputs.Add(pair.Value); @@ -772,7 +772,7 @@ private void TestPreTrainedModels(string opsetDir, string modelName, bool useOrt throw new Exception($"Opset {opset} Model {modelName}. Can't determine model file name. Found these :{modelNamesList}"); } - using(var runOptions = new RunOptions()) + using (var runOptions = new RunOptions()) using (var session = new InferenceSession(onnxModelFileName)) { string testDataDirNamePattern = "test_data*"; @@ -1077,7 +1077,7 @@ private static void VerifyContainerContent(IReadOnlyList results, Assert.Equal(result.GetStringTensorAsArray(), expectedValue.AsTensor().ToArray(), new ExactComparer()); break; default: - Assert.Fail($"VerifyTensorResults cannot handle ElementType: { resultTypeShape.ElementDataType}"); + Assert.Fail($"VerifyTensorResults cannot handle ElementType: {resultTypeShape.ElementDataType}"); break; } } @@ -1251,6 +1251,88 @@ private void TestModelSerialization() } } + private static OrtLoraAdapter CreateLoraAdapterFromFile() + { + var adapterPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx_adapter"); + return OrtLoraAdapter.Create(adapterPath, null); + } + + private static OrtLoraAdapter CreateLoraAdapterFromArray() + { + var adapterPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx_adapter"); + var adapterBytes = File.ReadAllBytes(adapterPath); + return OrtLoraAdapter.Create(adapterBytes, null); + } + + // See tests below for running with Lora Adapters + [Fact(DisplayName = "TestInferenceWithBaseLoraModel")] + private void TestInferenceWithBaseLoraModel() + { + var modelPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx"); + + var inputShape = new long[] { 4, 4 }; + var inputData = new float[16]; + Array.Fill(inputData, 1); + using var inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + + var expectedOutput = new float[] { + 28, 32, 36, 40, + 28, 32, 36, 40, + 28, 32, 36, 40, + 28, 32, 36, 40 }; + + using var session = new InferenceSession(modelPath); + using var runOptions = new RunOptions(); + + using var outputs = session.Run(runOptions, ["input_x"], [inputOrtValue], ["output"]); + Assert.Single(outputs); + var output = outputs[0].GetTensorDataAsSpan(); + Assert.Equal(expectedOutput.Length, output.Length); + Assert.Equal(expectedOutput, output.ToArray(), new FloatComparer()); + } + + + private static void TestInferenceWithLoraAdapter(OrtLoraAdapter ortLoraAdapter) + { + var modelPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx"); + var adapterPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx_adapter"); + + var inputShape = new long[] { 4, 4 }; + var inputData = new float[16]; + Array.Fill(inputData, 1); + using var inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + + var expectedOutput = new float[] { + 154, 176, 198, 220, + 154, 176, 198, 220, + 154, 176, 198, 220, + 154, 176, 198, 220 }; + + using var session = new InferenceSession(modelPath); + using var runOptions = new RunOptions(); + runOptions.AddActiveLoraAdapter(ortLoraAdapter); + + using var outputs = session.Run(runOptions, ["input_x"], [inputOrtValue], ["output"]); + Assert.Single(outputs); + var output = outputs[0].GetTensorDataAsSpan(); + Assert.Equal(expectedOutput.Length, output.Length); + Assert.Equal(expectedOutput, output.ToArray(), new FloatComparer()); + } + + [Fact(DisplayName = "TestInferenceWithLoraAdapterFromFile")] + private void TestInferenceWithLoraAdapterFromFile() + { + using var ortAdapter = CreateLoraAdapterFromFile(); + TestInferenceWithLoraAdapter(ortAdapter); + } + + [Fact(DisplayName = "TestInferenceWithLoraAdapterFromArray")] + private void TestInferenceWithLoraAdapterFromArray() + { + using var ortAdapter = CreateLoraAdapterFromArray(); + TestInferenceWithLoraAdapter(ortAdapter); + } + // TestGpu() will test // - the CUDA EP on CUDA enabled builds // - the DML EP on DML enabled builds diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index f877cc376ea90..b822c999e4d39 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -105,6 +105,14 @@ PreserveNewest false + + PreserveNewest + false + + + PreserveNewest + false +