Support DisposableNamedOnnxValue inputs in C# Run() (#3175)
* Initial commit

* Update error message

* Update

* Updates to support holding onto onnxValue and pinnedmemoryBuffer

* Updates

* Minor updates

* Comment out a portion of the tests

* PR feedback

* Minor nit update

* Resolve comments

* PR feedback

* PR updates

* PR feedback
hariharans29 authored Mar 24, 2020
1 parent fb5ab85 commit ef7b98f
Showing 5 changed files with 163 additions and 13 deletions.
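For orientation, here is a minimal sketch of the round-trip usage this commit enables. The model path, input name, element type, and shape are assumptions for illustration; the tests added below exercise the same pattern against real test models.

using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

// Hypothetical echo-style model with a 1x5 float input named "input".
using (var session = new InferenceSession("model.onnx"))
{
    var tensorIn = new DenseTensor<float>(new float[] { 1f, 2f, 3f, 4f, 5f }, new int[] { 1, 5 });
    var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor("input", tensorIn) };

    // First Run(): each output is a DisposableNamedOnnxValue backed by a native OrtValue.
    using (var firstResults = session.Run(inputs))
    {
        // Rename the output so it matches the model's input name.
        firstResults.First().Name = "input";

        // Feed the output straight back in. The new ToNativeOnnxValue() override
        // hands Run() the existing OrtValue handle instead of copying the data.
        using (var secondResults = session.Run(firstResults))
        {
            var tensorOut = secondResults.First().AsTensor<float>();
        }
    } // Disposing firstResults releases the native OrtValues it owns.
}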
40 changes: 38 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Runtime.InteropServices;
@@ -60,13 +61,48 @@ public void Dispose()

public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable
{
- protected IDisposable _nativeMemoryManager;
- protected DisposableNamedOnnxValue(string name, Object value, IDisposable nativeMemoryManager)
+ private NativeMemoryHandler _nativeMemoryManager;
+ private DisposableNamedOnnxValue(string name, Object value, NativeMemoryHandler nativeMemoryManager)
: base(name, value)
{
_nativeMemoryManager = nativeMemoryManager;
}

/// <summary>
/// Overrides the base class method. If this instance hasn't been disposed, it already
/// holds the underlying OrtValue handle, so it simply assigns that handle to the output
/// onnxValue. pinnedMemoryHandle requires no work here: this class maintains no managed
/// buffer that would need pinning, because the data already lives in the native OrtValue.
/// </summary>
/// <param name="onnxValue"></param>
/// <param name="pinnedMemoryHandle"></param>
/// <param name="disposeOnnxValueAfterUse"></param>
internal override void ToNativeOnnxValue(out IntPtr onnxValue,
out MemoryHandle pinnedMemoryHandle)
{
// Make sure that this instance hasn't been disposed yet
if (disposedValue)
{
throw new ObjectDisposedException(nameof(DisposableNamedOnnxValue),
"This instance of DisposableNamedOnnxValue has already been disposed");
}

// If this instance hasn't been disposed, _nativeMemoryManager can be null
// only for Maps and SequenceTensors
if (_nativeMemoryManager == null)
{
throw new NotSupportedException("Use of Maps and SequenceTensors is not yet supported");
}

// Assign the onnxValue by querying this instance's NativeOnnxTensorMemory instance
onnxValue = _nativeMemoryManager.Handle;

// pinnedMemoryHandle is left at its default value, as DisposableNamedOnnxValue
// doesn't hold any managed buffer that would need pinning
pinnedMemoryHandle = default;
}

internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, IntPtr nativeOnnxValue)
{
DisposableNamedOnnxValue result = null;
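One caveat in the new override: only tensor-backed values can be round-tripped. If a previous Run() produced a map or sequence output, its _nativeMemoryManager is null and re-feeding it throws. A sketch, where session and the map-producing inputs are assumed to exist:

// Hypothetical model whose output is a map rather than a tensor.
using (var results = session.Run(inputsToMapOutputModel))
{
    results.First().Name = "input";
    session.Run(results); // throws NotSupportedException:
                          // "Use of Maps and SequenceTensors is not yet supported"
}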
24 changes: 16 additions & 8 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
@@ -143,7 +143,8 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyColl
inputNames[inputIndex] = input.Name;

// create a native tensor from the input if feasible, else throw a NotSupportedException for now
- input.ToNativeOnnxValue(out inputTensors[inputIndex], out pinnedBufferHandles[inputIndex]);
+ input.ToNativeOnnxValue(out inputTensors[inputIndex],
+     out pinnedBufferHandles[inputIndex]);

inputIndex++;
}
@@ -187,12 +188,19 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyColl
}
finally
{
- // always unpin the input buffers, and delete the native Onnx value objects
- for (int i = 0; i < inputs.Count; i++)
+ inputIndex = 0;
+ foreach (var input in inputs)
  {
-     NativeMethods.OrtReleaseValue(inputTensors[i]); // For elementary type Tensors, this should not release the buffer, but should delete the native tensor object.
-                                                     // For string tensors, this releases the native memory allocated for the tensor, including the buffer
-     pinnedBufferHandles[i].Dispose();
+     // For NamedOnnxValue, always unpin the input buffers and delete the native Onnx value objects.
+     // For DisposableNamedOnnxValue, the user needs to do this by invoking Dispose.
+     if (input.GetType() == typeof(NamedOnnxValue))
+     {
+         NativeMethods.OrtReleaseValue(inputTensors[inputIndex]); // For elementary type Tensors, this should not release the buffer, but should delete the native tensor object.
+                                                                  // For string tensors, this releases the native memory allocated for the tensor, including the buffer.
+         pinnedBufferHandles[inputIndex].Dispose();
+     }
+
+     inputIndex++;
  }
}

@@ -429,7 +437,7 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
}
if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
{
- return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue));
+ return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue));
}

IntPtr tensorInfo;
@@ -467,7 +475,7 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
{
symbolicDimensions[i] = Marshal.PtrToStringAnsi(dimensionNamePtrs[i]); //assumes charset = ANSI
}

return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType);
}

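The revised finally block encodes an ownership rule that is easy to miss; the sketch below restates it from the caller's side (variable names are illustrative):

var plainInput = NamedOnnxValue.CreateFromTensor("input", someTensor);
using (var outputs = session.Run(new List<NamedOnnxValue> { plainInput }))
{
    // plainInput: Run() created a temporary native OrtValue for it and released
    // it in the finally block; the caller has nothing further to clean up.
    //
    // outputs (and any DisposableNamedOnnxValue reused as an input later):
    // each owns its native OrtValue, which Run() deliberately does not release;
    // disposal remains the caller's job, here via the using block.
}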
6 changes: 4 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs
@@ -28,7 +28,7 @@ public static NamedOnnxValue CreateFromTensor<T>(string name, Tensor<T> value)
return new NamedOnnxValue(name, value);
}

- public string Name { get { return _name; } }
+ public string Name { get { return _name; } set { _name = value; } }

/// <summary>
/// Try-get value as a Tensor&lt;T&gt;.
@@ -71,7 +71,9 @@ public IDictionary<K, V> AsDictionary<K, V>()
/// </summary>
/// <param name="onnxValue"></param>
/// <param name="pinnedMemoryHandle"></param>
- internal void ToNativeOnnxValue(out IntPtr onnxValue, out MemoryHandle pinnedMemoryHandle)
+ internal virtual void ToNativeOnnxValue(out IntPtr onnxValue,
+     out MemoryHandle pinnedMemoryHandle)
{
//try to cast _value to Tensor<T>
TensorElementType nativeElementType = TensorElementType.DataTypeMax; //invalid
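Marking ToNativeOnnxValue as virtual is the hinge of the whole change: Run() keeps calling it through the NamedOnnxValue static type, and dynamic dispatch picks the subclass override for reused outputs. An illustrative fragment (ToNativeOnnxValue is internal, so this only compiles inside the assembly):

NamedOnnxValue value = previousResults.First(); // runtime type: DisposableNamedOnnxValue
value.ToNativeOnnxValue(out IntPtr ortValue, out MemoryHandle pin);
// Dispatches to the DisposableNamedOnnxValue override, which returns the
// already-existing OrtValue handle instead of allocating a new native tensor.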
12 changes: 11 additions & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs
@@ -12,7 +12,15 @@

namespace Microsoft.ML.OnnxRuntime
{
- internal class NativeOnnxTensorMemory<T> : MemoryManager<T>
+ /// <summary>
+ /// A non-public interface detailing the contract to be honored by NativeOnnxTensorMemory
+ /// </summary>
+ internal interface NativeMemoryHandler : IDisposable
+ {
+     IntPtr Handle { get; }
+ }
+
+ internal class NativeOnnxTensorMemory<T> : MemoryManager<T>, NativeMemoryHandler
{
private bool _disposed;
private int _referenceCount;
@@ -122,6 +130,8 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle)
}
}

public IntPtr Handle { get { return _onnxValueHandle; } }

~NativeOnnxTensorMemory()
{
Dispose(false);
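A note on this design: DisposableNamedOnnxValue never knows the tensor's element type T, so it can't declare a field of type NativeOnnxTensorMemory<T>; the non-generic NativeMemoryHandler interface erases T while still exposing the handle and disposal. A sketch of the consuming side, assuming a valid native OrtValue handle:

// Any NativeOnnxTensorMemory<T> is usable without naming T:
NativeMemoryHandler handler = new NativeOnnxTensorMemory<float>(onnxValueHandle);
IntPtr ortValue = handler.Handle; // the wrapped native OrtValue handle
handler.Dispose();                // releases the native tensor when no longer needed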
94 changes: 94 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -731,6 +731,100 @@ private void TestModelInputBOOL()
}
}

[Fact]
private void TestReusingRunOutputNonStringType()
{
// model takes 1x5 input of fixed type, echoes back
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_BOOL.pb");
using (var session = new InferenceSession(modelPath))
{
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<bool>(new bool[] { true, false, true, false, true }, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
var res1 = session.Run(container);

// Change the name of the DisposableNamedOnnxValue so it matches the model's input name
res1.First().Name = "input";

// Run inference twice using the output of the first Run()
for (int i = 0; i < 2; ++i)
{
using (var res2 = session.Run(res1))
{
var tensorOut = res2.First().AsTensor<bool>();
Assert.True(tensorOut.SequenceEqual(tensorIn));
}
}
}
}

[Fact]
private void TestReusingRunOutputStringType()
{
// model takes 1x5 input of fixed type, echoes back
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_STRING.pb");
using (var session = new InferenceSession(modelPath))
{
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<string>(new string[] { "a", "b", "c", "d", "e" }, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
var res1 = session.Run(container);

// Change the name of the DisposableNamedOnnxValue so it matches the model's input name
res1.First().Name = "input";

// Run inference twice using the output of the first Run()
for (int i = 0; i < 2; ++i)
{
using (var res2 = session.Run(res1))
{
var tensorOut = res2.First().AsTensor<string>();
Assert.True(tensorOut.SequenceEqual(tensorIn));
}
}
}
}

[Fact]
private void TestReusingDisposedRunOutput()
{
// model takes 1x5 input of fixed type, echoes back
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_types_BOOL.pb");
using (var session = new InferenceSession(modelPath))
{
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<bool>(new bool[] { true, false, true, false, true }, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
var res1 = session.Run(container);

// Dispose the first result value
res1.First().Dispose();

bool succeeded = false;

// Now try using the disposed output as input to another Run()
try
{
// Run() should fail with a user-friendly error message.
session.Run(res1);
}
catch (ObjectDisposedException e)
{
var errorString = "This instance of DisposableNamedOnnxValue has already been disposed";

Assert.True(e.Message.Contains(errorString));

succeeded = true;
}

Assert.True(succeeded);
}
}

[Fact]
private void TestModelInputINT32()
{
Expand Down
