-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow users to bind arbitrary memory using raw pointers #10428
Changes from all commits
04c5ae9
63f468f
27f54a1
4a78761
dd43824
22e7595
6ff8426
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
using Microsoft.ML.OnnxRuntime.Tensors; | ||
using System; | ||
using System.Runtime.InteropServices; | ||
using System.Text; | ||
|
@@ -61,7 +62,7 @@ internal IntPtr Pointer | |
} | ||
|
||
#region SafeHandle | ||
|
||
/// <summary> | ||
/// Overrides SafeHandle.IsInvalid | ||
/// </summary> | ||
|
@@ -257,7 +258,7 @@ public OrtAllocatorType GetAllocatorType() | |
public override bool Equals(object obj) | ||
{ | ||
var other = obj as OrtMemoryInfo; | ||
if(other == null) | ||
if (other == null) | ||
{ | ||
return false; | ||
} | ||
|
@@ -271,7 +272,7 @@ public override bool Equals(object obj) | |
/// <returns>true if instances are equal according to OrtCompareMemoryInfo.</returns> | ||
public bool Equals(OrtMemoryInfo other) | ||
{ | ||
if(this == other) | ||
if (this == other) | ||
{ | ||
return true; | ||
} | ||
|
@@ -310,6 +311,78 @@ protected override bool ReleaseHandle() | |
#endregion | ||
} | ||
|
||
/// <summary> | ||
/// This class represents an arbitrary buffer of memory | ||
/// allocated and owned by the user. It can be either a CPU, GPU or other device memory | ||
/// that can be suitably represented by IntPtr. | ||
/// This is just a composite of the buffer related information. | ||
/// The memory is assumed to be pinned if necessary and usable immediately | ||
/// in the native code. | ||
/// </summary> | ||
public class OrtExternalAllocation | ||
{ | ||
/// <summary> | ||
/// Constructor | ||
/// </summary> | ||
/// <param name="memInfo">use to accurately describe a piece of memory that this is wrapping</param> | ||
/// <param name="shape">shape of this buffer</param> | ||
/// <param name="elementType">element type</param> | ||
/// <param name="pointer">the actual pointer to memory</param> | ||
/// <param name="sizeInBytes">size of the allocation in bytes</param> | ||
public OrtExternalAllocation(OrtMemoryInfo memInfo, long[] shape, Tensors.TensorElementType elementType, IntPtr pointer, long sizeInBytes) | ||
{ | ||
Type type; | ||
int width; | ||
if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width)) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: can we update GetTypeAndWidth to return true/false if successful? if not, would checking type==null be a little more obvious than width < 1? #Closed |
||
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, | ||
"Unable to query type information for data type: " + elementType.ToString()); | ||
} | ||
|
||
if (elementType == TensorElementType.String) | ||
{ | ||
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, | ||
"Strings are not supported by this API"); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: include elementType in the error message #Closed |
||
|
||
var shapeSize = ArrayUtilities.GetSizeForShape(shape); | ||
var requiredBufferSize = shapeSize * width; | ||
if (requiredBufferSize > sizeInBytes) | ||
{ | ||
var message = String.Format("Shape of {0} elements requires a buffer of at least {1} bytes. Provided: {2} bytes", | ||
shapeSize, requiredBufferSize, sizeInBytes); | ||
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); | ||
} | ||
|
||
Info = memInfo; | ||
Shape = shape; | ||
ElementType = elementType; | ||
Pointer = pointer; | ||
Size = sizeInBytes; | ||
} | ||
|
||
/// <summary> | ||
/// OrtMemoryInfo | ||
/// </summary> | ||
public OrtMemoryInfo Info { get; private set; } | ||
/// <summary> | ||
/// Shape | ||
/// </summary> | ||
public long[] Shape { get; private set; } | ||
/// <summary> | ||
/// Data type | ||
/// </summary> | ||
public Tensors.TensorElementType ElementType { get; private set; } | ||
/// <summary> | ||
/// Actual memory ptr | ||
/// </summary> | ||
public IntPtr Pointer { get; private set; } | ||
/// <summary> | ||
/// Size of the allocation in bytes | ||
/// </summary> | ||
public long Size { get; private set; } | ||
} | ||
|
||
/// <summary> | ||
/// This class represents memory allocation made by a specific onnxruntime | ||
/// allocator. Use OrtAllocator.Allocate() to obtain an instance of this class. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this to be provided separately to the size that can be inferred from the shape? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going back and forth on this. Finally, I decided to supply this argument/property to validate the user knows what they are doing. It is a common mistake to make a wrong choice betwen then number of elements and the buffer size in bytes.