From 91b8ad5ee73df834782b19093da7038f6e0ab6e1 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 1 Feb 2022 18:09:24 -0800 Subject: [PATCH] Allow users to bind arbitrary memory using raw pointers (#10428) Add binding external allocation Add negative tests Add missing return status check --- .../InferenceSession.shared.cs | 10 +- .../NativeOnnxTensorMemory.shared.cs | 17 ++- .../NativeOnnxValueHelper.shared.cs | 13 ++- .../OrtAllocator.shared.cs | 79 ++++++++++++- .../OrtIoBinding.shared.cs | 108 +++++++++++++----- .../OrtValue.shared.cs | 19 ++- .../OrtIoBindingAllocationTest.cs | 74 +++++++++++- .../TestDataLoader.cs | 3 +- 8 files changed, 274 insertions(+), 49 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index 693f6ea2bd632..027cbfdc788c7 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -40,7 +40,7 @@ public class InferenceSession : IDisposable /// Dictionary that represents overridableInitializers metadata /// private Dictionary _overridableInitializerMetadata; - + private SessionOptions _builtInSessionOptions = null; private RunOptions _builtInRunOptions = null; private ModelMetadata _modelMetadata = null; @@ -998,9 +998,15 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out el_type)); type = (TensorElementType)el_type; } + Type dotnetType = null; int width = 0; - TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width); + if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + "Unable to query type information for data type: " + type.ToString()); + } + UIntPtr numDimensions; 
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions)); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs index 61ac3324b6b06..b2439c32b0708 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs @@ -54,7 +54,7 @@ protected virtual void Dispose(bool disposing) // dispose managed state (managed objects). if (disposing) { - if(_disposables != null) + if (_disposables != null) { _disposables.Dispose(); _disposables = null; @@ -106,10 +106,19 @@ public NativeOnnxTensorMemory(OrtValue ortValue) NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); elemType = (TensorElementType)el_type; } - TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width); + + if (!TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + "Unable to query type information for data type: " + elemType.ToString()); + } if (typeof(T) != type) - throw new NotSupportedException(nameof(NativeOnnxTensorMemory) + " does not support T = " + nameof(T)); + { + var message = String.Format("The NativeOnnxTensorMemory type being instantiated for T = : {0} while supplied OrtValue contains T = {1}", + typeof(T), type); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); + } ElementType = elemType; ElementWidth = width; @@ -136,7 +145,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue) Dimensions[i] = (int)shape[i]; } - if (typeof(T) != typeof(string)) + if (elemType != TensorElementType.String) { NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(ortValue.Handle, out _dataBufferPointer)); } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs 
b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index ee57a02b8120b..67781b82f5f1f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -108,19 +108,22 @@ internal static IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, Nam internal static class TensorElementTypeConverter { - public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width) + public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, out int width) { - TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType); - if(result != null) + bool result = true; + TensorElementTypeInfo typeInfo = TensorBase.GetElementTypeInfo(elemType); + if(typeInfo != null) { - type = result.TensorType; - width = result.TypeSize; + type = typeInfo.TensorType; + width = typeInfo.TypeSize; } else { type = null; width = 0; + result = false; } + return result; } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs index a396df41ec580..c420b706e28f2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Runtime.InteropServices; using System.Text; @@ -61,7 +62,7 @@ internal IntPtr Pointer } #region SafeHandle - + /// /// Overrides SafeHandle.IsInvalid /// @@ -257,7 +258,7 @@ public OrtAllocatorType GetAllocatorType() public override bool Equals(object obj) { var other = obj as OrtMemoryInfo; - if(other == null) + if (other == null) { return false; } @@ -271,7 +272,7 @@ public override bool Equals(object obj) /// true if instances are equal according to OrtCompareMemoryInfo. 
public bool Equals(OrtMemoryInfo other) { - if(this == other) + if (this == other) { return true; } @@ -310,6 +311,78 @@ protected override bool ReleaseHandle() #endregion } + /// + /// This class represents an arbitrary buffer of memory + /// allocated and owned by the user. It can be either a CPU, GPU or other device memory + /// that can be suitably represented by IntPtr. + /// This is just a composite of the buffer related information. + /// The memory is assumed to be pinned if necessary and usable immediately + /// in the native code. + /// + public class OrtExternalAllocation + { + /// + /// Constructor + /// + /// use to accurately describe a piece of memory that this is wrapping + /// shape of this buffer + /// element type + /// the actual pointer to memory + /// size of the allocation in bytes + public OrtExternalAllocation(OrtMemoryInfo memInfo, long[] shape, Tensors.TensorElementType elementType, IntPtr pointer, long sizeInBytes) + { + Type type; + int width; + if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + "Unable to query type information for data type: " + elementType.ToString()); + } + + if (elementType == TensorElementType.String) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + "Strings are not supported by this API"); + } + + var shapeSize = ArrayUtilities.GetSizeForShape(shape); + var requiredBufferSize = shapeSize * width; + if (requiredBufferSize > sizeInBytes) + { + var message = String.Format("Shape of {0} elements requires a buffer of at least {1} bytes. 
Provided: {2} bytes", + shapeSize, requiredBufferSize, sizeInBytes); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); + } + + Info = memInfo; + Shape = shape; + ElementType = elementType; + Pointer = pointer; + Size = sizeInBytes; + } + + /// + /// OrtMemoryInfo + /// + public OrtMemoryInfo Info { get; private set; } + /// + /// Shape + /// + public long[] Shape { get; private set; } + /// + /// Data type + /// + public Tensors.TensorElementType ElementType { get; private set; } + /// + /// Actual memory ptr + /// + public IntPtr Pointer { get; private set; } + /// + /// Size of the allocation in bytes + /// + public long Size { get; private set; } + } + /// /// This class represents memory allocation made by a specific onnxruntime /// allocator. Use OrtAllocator.Allocate() to obtain an instance of this class. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtIoBinding.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtIoBinding.shared.cs index 40549a684856c..382ffe7929de7 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtIoBinding.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtIoBinding.shared.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Runtime.InteropServices; using System.Text; @@ -55,9 +56,7 @@ internal IntPtr Handle /// Bind a piece of pre-allocated native memory as a OrtValue Tensor with a given shape /// to an input with a given name. The model will read the specified input from that memory /// possibly avoiding the need to copy between devices. OrtMemoryAllocation continues to own - /// the chunk of native memory and should be alive until the end of execution. - /// The size of the allocation can not be less than required. - /// by the Tensor of the given size. + /// the chunk of native memory, and the allocation should be alive until the end of execution. 
/// /// of the input /// Tensor element type @@ -65,11 +64,20 @@ internal IntPtr Handle /// native memory allocation public void BindInput(string name, Tensors.TensorElementType elementType, long[] shape, OrtMemoryAllocation allocation) { - using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info, - elementType, - shape, - allocation.Pointer, allocation.Size)) - BindInputOrOutput(name, ortValue.Handle, true); + BindOrtAllocation(name, elementType, shape, allocation, true); + } + + /// + /// Bind externally (not from OrtAllocator) allocated memory as input. + /// The model will read the specified input from that memory + /// possibly avoiding the need to copy between devices. The user code continues to own + /// the chunk of externally allocated memory, and the allocation should be alive until the end of execution. + /// + /// name + /// non ort allocated memory + public void BindInput(string name, OrtExternalAllocation allocation) + { + BindExternalAllocation(name, allocation, true); } /// @@ -80,7 +88,7 @@ public void BindInput(string name, Tensors.TensorElementType elementType, long[] /// public void BindInput(string name, FixedBufferOnnxValue fixedValue) { - if(fixedValue.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (fixedValue.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Binding works only with Tensors"); } @@ -93,13 +101,12 @@ public void BindInput(string name, FixedBufferOnnxValue fixedValue) /// public void SynchronizeBoundInputs() { - NativeMethods.OrtSynchronizeBoundInputs(handle); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSynchronizeBoundInputs(handle)); } /// /// Bind model output to an OrtValue as Tensor with a given type and shape. An instance of OrtMemoryAllocaiton - /// owns the memory and should be alive for the time of execution.The size of the allocation can not be less than required - /// by the Tensor of the given size. 
+ /// owns the memory and should be alive for the time of execution. /// /// of the output /// tensor element type @@ -107,11 +114,20 @@ public void SynchronizeBoundInputs() /// allocated memory public void BindOutput(string name, Tensors.TensorElementType elementType, long[] shape, OrtMemoryAllocation allocation) { - using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info, - elementType, - shape, - allocation.Pointer, allocation.Size)) - BindInputOrOutput(name, ortValue.Handle, false); + BindOrtAllocation(name, elementType, shape, allocation, false); + } + + /// + /// Bind externally (not from OrtAllocator) allocated memory as output. + /// The model will write the specified output to that memory + /// possibly avoiding the need to copy between devices. The user code continues to own + /// the chunk of externally allocated memory, and the allocation should be alive until the end of execution. + /// + /// name + /// non ort allocated memory + public void BindOutput(string name, OrtExternalAllocation allocation) + { + BindExternalAllocation(name, allocation, false); } /// @@ -139,7 +155,7 @@ public void BindOutputToDevice(string name, OrtMemoryInfo memInfo) { var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned); using (var pinnedName = new PinnedGCHandle(utf8NamePinned)) - NativeApiStatus.VerifySuccess(NativeMethods.OrtBindOutputToDevice(handle, pinnedName.Pointer, memInfo.Pointer)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtBindOutputToDevice(handle, pinnedName.Pointer, memInfo.Pointer)); } /// @@ -148,9 +164,46 @@ public void BindOutputToDevice(string name, OrtMemoryInfo memInfo) /// public void SynchronizeBoundOutputs() { - NativeMethods.OrtSynchronizeBoundOutputs(handle); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSynchronizeBoundOutputs(handle)); } + /// + /// Bind allocation obtained from an Ort allocator + /// + /// name + /// data type + /// tensor shape + /// ort
allocation + /// whether this is input or output + private void BindOrtAllocation(string name, Tensors.TensorElementType elementType, long[] shape, + OrtMemoryAllocation allocation, bool isInput) + { + using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info, + elementType, + shape, + allocation.Pointer, allocation.Size)) + BindInputOrOutput(name, ortValue.Handle, isInput); + } + + + /// + /// Bind external allocation as input or output. + /// The allocation is owned by the user code. + /// + /// name + /// non ort allocated memory + /// whether this is an input or output + private void BindExternalAllocation(string name, OrtExternalAllocation allocation, bool isInput) + { + using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info, + allocation.ElementType, + allocation.Shape, + allocation.Pointer, + allocation.Size)) + BindInputOrOutput(name, ortValue.Handle, isInput); + } + + /// /// Internal helper /// @@ -185,7 +238,7 @@ public string[] GetOutputNames() var allocator = OrtAllocator.DefaultInstance; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetBoundOutputNames(handle, allocator.Pointer, out buffer, out lengths, out count)); - if(count.Equals(UIntPtr.Zero)) + if (count.Equals(UIntPtr.Zero)) { return new string[0]; } @@ -196,9 +249,9 @@ public string[] GetOutputNames() int outputCount = (int)count; var lens = new int[outputCount]; int totalLength = 0; - for(int i = 0; i < outputCount; ++i) + for (int i = 0; i < outputCount; ++i) { - var len =(int)Marshal.ReadIntPtr(lengths, IntPtr.Size * i); + var len = (int)Marshal.ReadIntPtr(lengths, IntPtr.Size * i); lens[i] = len; totalLength += len; } @@ -208,7 +261,7 @@ public string[] GetOutputNames() string[] result = new string[outputCount]; int readOffset = 0; - for(int i = 0; i < outputCount; ++i) + for (int i = 0; i < outputCount; ++i) { var strLen = lens[i]; result[i] = Encoding.UTF8.GetString(stringData, readOffset, strLen); @@ -229,23 +282,24 @@ public 
IDisposableReadOnlyCollection GetOutputValues() var allocator = OrtAllocator.DefaultInstance; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetBoundOutputValues(handle, allocator.Pointer, out ortValues, out count)); - if(count.Equals(UIntPtr.Zero)) + if (count.Equals(UIntPtr.Zero)) { return new DisposableList(); } - using(var ortValuesAllocation = new OrtMemoryAllocation(allocator, ortValues, 0)) + using (var ortValuesAllocation = new OrtMemoryAllocation(allocator, ortValues, 0)) { int outputCount = (int)count; var ortList = new DisposableList(outputCount); try { - for(int i = 0; i < outputCount; ++i) + for (int i = 0; i < outputCount; ++i) { IntPtr ortValue = Marshal.ReadIntPtr(ortValues, IntPtr.Size * i); ortList.Add(new OrtValue(ortValue)); } - } catch(Exception) + } + catch (Exception) { ortList.Dispose(); throw; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 49f9cb33f0686..08609bb4826a6 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -94,16 +94,25 @@ public static OrtValue CreateTensorValueWithData(OrtMemoryInfo memInfo, TensorEl { Type type; int width; - TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width); - if(width < 1) + if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width)) { - throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unsupported data type (such as string)"); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + "Unable to query type information for data type: " + elementType.ToString()); + } + + if (elementType == TensorElementType.String) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + "Cannot map managed strings buffer to native OrtValue"); } var shapeSize = ArrayUtilities.GetSizeForShape(shape); - if((shapeSize * width) > bufferLength) + var requiredBufferSize = shapeSize * width; + if 
(requiredBufferSize > bufferLength) { - throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Can not bind the shape to smaller buffer"); + var message = String.Format("Shape of: {0} elements requires a buffer of at least {1} bytes. Provided: {2} bytes", + shapeSize, requiredBufferSize, bufferLength); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); } IntPtr ortValueHandle = IntPtr.Zero; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtIoBindingAllocationTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtIoBindingAllocationTest.cs index 7c9fcfe34819c..c6312b65f751b 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtIoBindingAllocationTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtIoBindingAllocationTest.cs @@ -2,8 +2,10 @@ // Licensed under the MIT License. using Microsoft.ML.OnnxRuntime.Tensors; +using Microsoft.Win32.SafeHandles; using System; using System.Linq; +using System.Runtime.InteropServices; using Xunit; using static Microsoft.ML.OnnxRuntime.Tests.InferenceTest; @@ -23,15 +25,36 @@ private static void PopulateNativeBufferFloat(OrtMemoryAllocation buffer, float[ Assert.True(false); } + PopulateNativeBuffer(buffer.Pointer, elements); + } + + private static void PopulateNativeBuffer(IntPtr buffer, float[] elements) + { unsafe { - float* p = (float*)buffer.Pointer; + float* p = (float*)buffer; for (int i = 0; i < elements.Length; ++i) { *p++ = elements[i]; } } } + /// + /// Use to free globally allocated memory + /// + class OrtSafeMemoryHandle : SafeHandle + { + public OrtSafeMemoryHandle(IntPtr allocPtr) : base(allocPtr, true) { } + + public override bool IsInvalid => handle == IntPtr.Zero; + + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + handle = IntPtr.Zero; + return true; + } + } [Fact(DisplayName = "TestIOBindingWithOrtAllocation")] public void TestIOBindingWithOrtAllocation() @@ -61,8 +84,17 @@ public void 
TestIOBindingWithOrtAllocation() var ortAllocationInput = allocator.Allocate((uint)inputData.Length * sizeof(float)); dispList.Add(ortAllocationInput); var inputShape = Array.ConvertAll(inputMeta[inputName].Dimensions, d => d); + var shapeSize = ArrayUtilities.GetSizeForShape(inputShape); + Assert.Equal(shapeSize, inputData.Length); PopulateNativeBufferFloat(ortAllocationInput, inputData); + // Create an external allocation for testing OrtExternalAllocation + var cpuMemInfo = OrtMemoryInfo.DefaultInstance; + var sizeInBytes = shapeSize * sizeof(float); + IntPtr allocPtr = Marshal.AllocHGlobal((int)sizeInBytes); + dispList.Add(new OrtSafeMemoryHandle(allocPtr)); + PopulateNativeBuffer(allocPtr, inputData); + var ortAllocationOutput = allocator.Allocate((uint)outputData.Length * sizeof(float)); dispList.Add(ortAllocationOutput); @@ -102,6 +134,46 @@ public void TestIOBindingWithOrtAllocation() Assert.Equal(outputData, tensor.ToArray(), new FloatComparer()); } } + // 3. Test external allocation + { + var externalInputAllocation = new OrtExternalAllocation(cpuMemInfo, inputShape, + Tensors.TensorElementType.Float, allocPtr, sizeInBytes); + + ioBinding.BindInput(inputName, externalInputAllocation); + ioBinding.BindOutput(outputName, Tensors.TensorElementType.Float, outputShape, ortAllocationOutput); + ioBinding.SynchronizeBoundInputs(); + using (var outputs = session.RunWithBindingAndNames(runOptions, ioBinding)) + { + ioBinding.SynchronizeBoundOutputs(); + Assert.Equal(1, outputs.Count); + var output = outputs.ElementAt(0); + Assert.Equal(outputName, output.Name); + var tensor = output.AsTensor(); + Assert.True(tensor.IsFixedSize); + Assert.Equal(outputData, tensor.ToArray(), new FloatComparer()); + } + } + // 4. 
Some negative tests for external allocation + { + // Small buffer size + Action smallBuffer = delegate () + { + new OrtExternalAllocation(cpuMemInfo, inputShape, + Tensors.TensorElementType.Float, allocPtr, sizeInBytes - 10); + }; + + Assert.Throws(smallBuffer); + + Action stringType = delegate () + { + new OrtExternalAllocation(cpuMemInfo, inputShape, + Tensors.TensorElementType.String, allocPtr, sizeInBytes); + }; + + Assert.Throws(stringType); + + } + } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs index e4ff2bf9c71da..8b556e68c1af5 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs @@ -56,8 +56,7 @@ internal static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Typ } else { - type = null; - width = 0; + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unable to get information for type: " + elemType.ToString()); } }