Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to bind arbitrary memory using raw pointers #10428

Merged
merged 7 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class InferenceSession : IDisposable
/// Dictionary that represents overridableInitializers metadata
/// </summary>
private Dictionary<string, NodeMetadata> _overridableInitializerMetadata;

private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;
private ModelMetadata _modelMetadata = null;
Expand Down Expand Up @@ -998,9 +998,15 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out el_type));
type = (TensorElementType)el_type;
}

Type dotnetType = null;
int width = 0;
TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width);
if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width))
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
"Unable to query type information for data type: " + type.ToString());
}

UIntPtr numDimensions;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ protected virtual void Dispose(bool disposing)
// dispose managed state (managed objects).
if (disposing)
{
if(_disposables != null)
if (_disposables != null)
{
_disposables.Dispose();
_disposables = null;
Expand Down Expand Up @@ -106,10 +106,19 @@ public NativeOnnxTensorMemory(OrtValue ortValue)
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type));
elemType = (TensorElementType)el_type;
}
TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width);

if (!TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width))
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
"Unable to query type information for data type: " + elemType.ToString());
}

if (typeof(T) != type)
throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>) + " does not support T = " + nameof(T));
{
var message = String.Format("The NativeOnnxTensorMemory<T> type being instantiated for T = : {0} while supplied OrtValue contains T = {1}",
typeof(T), type);
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message);
}

ElementType = elemType;
ElementWidth = width;
Expand All @@ -136,7 +145,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue)
Dimensions[i] = (int)shape[i];
}

if (typeof(T) != typeof(string))
if (elemType != TensorElementType.String)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(ortValue.Handle, out _dataBufferPointer));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,22 @@ internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, Nam

internal static class TensorElementTypeConverter
{
    /// <summary>
    /// Maps an ONNX tensor element type to its corresponding .NET <see cref="Type"/>
    /// and its element width in bytes.
    /// </summary>
    /// <param name="elemType">tensor element type to look up</param>
    /// <param name="type">resulting .NET type, or null when the lookup fails</param>
    /// <param name="width">element size in bytes, or 0 when the lookup fails</param>
    /// <returns>true when type information is available for elemType, false otherwise</returns>
    public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
    {
        var typeInfo = TensorBase.GetElementTypeInfo(elemType);
        if (typeInfo == null)
        {
            // Unknown/unsupported element type: signal failure via the return value.
            type = null;
            width = 0;
            return false;
        }

        type = typeInfo.TensorType;
        width = typeInfo.TypeSize;
        return true;
    }
}
}
79 changes: 76 additions & 3 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;
Expand Down Expand Up @@ -61,7 +62,7 @@ internal IntPtr Pointer
}

#region SafeHandle

/// <summary>
/// Overrides SafeHandle.IsInvalid
/// </summary>
Expand Down Expand Up @@ -257,7 +258,7 @@ public OrtAllocatorType GetAllocatorType()
public override bool Equals(object obj)
{
var other = obj as OrtMemoryInfo;
if(other == null)
if (other == null)
{
return false;
}
Expand All @@ -271,7 +272,7 @@ public override bool Equals(object obj)
/// <returns>true if instances are equal according to OrtCompareMemoryInfo.</returns>
public bool Equals(OrtMemoryInfo other)
{
if(this == other)
if (this == other)
{
return true;
}
Expand Down Expand Up @@ -310,6 +311,78 @@ protected override bool ReleaseHandle()
#endregion
}

/// <summary>
/// This class represents an arbitrary buffer of memory
/// allocated and owned by the user. It can be either a CPU, GPU or other device memory
/// that can be suitably represented by IntPtr.
/// This is just a composite of the buffer related information.
/// The memory is assumed to be pinned if necessary and usable immediately
/// in the native code.
/// </summary>
public class OrtExternalAllocation
{
    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="memInfo">use to accurately describe a piece of memory that this is wrapping</param>
    /// <param name="shape">shape of this buffer</param>
    /// <param name="elementType">element type</param>
    /// <param name="pointer">the actual pointer to memory</param>
    /// <param name="sizeInBytes">size of the allocation in bytes</param>
    /// <exception cref="OnnxRuntimeException">
    /// thrown when the element type is unknown, when it is String (unsupported here),
    /// or when sizeInBytes is too small for the given shape
    /// </exception>
    public OrtExternalAllocation(OrtMemoryInfo memInfo, long[] shape, Tensors.TensorElementType elementType, IntPtr pointer, long sizeInBytes)
    {
        // Resolve the element's .NET type and byte width; fail fast on unknown types.
        if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out var type, out var width))
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Unable to query type information for data type: " + elementType.ToString());
        }

        // Variable-length string tensors cannot be represented by a raw pointer + size.
        if (elementType == TensorElementType.String)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Strings are not supported by this API");
        }

        // The explicit byte size guards against the common element-count vs byte-count mixup:
        // the buffer must be large enough for shapeSize elements of the given width.
        var shapeSize = ArrayUtilities.GetSizeForShape(shape);
        var requiredBufferSize = shapeSize * width;
        if (requiredBufferSize > sizeInBytes)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                String.Format("Shape of {0} elements requires a buffer of at least {1} bytes. Provided: {2} bytes",
                    shapeSize, requiredBufferSize, sizeInBytes));
        }

        Info = memInfo;
        Shape = shape;
        ElementType = elementType;
        Pointer = pointer;
        Size = sizeInBytes;
    }

    /// <summary>
    /// OrtMemoryInfo
    /// </summary>
    public OrtMemoryInfo Info { get; }
    /// <summary>
    /// Shape
    /// </summary>
    public long[] Shape { get; }
    /// <summary>
    /// Data type
    /// </summary>
    public Tensors.TensorElementType ElementType { get; }
    /// <summary>
    /// Actual memory ptr
    /// </summary>
    public IntPtr Pointer { get; }
    /// <summary>
    /// Size of the allocation in bytes
    /// </summary>
    public long Size { get; }
}

/// <summary>
/// This class represents memory allocation made by a specific onnxruntime
/// allocator. Use OrtAllocator.Allocate() to obtain an instance of this class.
Expand Down
Loading