Skip to content

Commit

Permalink
Fix bug: failing on non-float tensor output types (#58)
Browse files Browse the repository at this point in the history
* fixed the missing implementation of output tensor construction based on type checking

* cleanup

* cleanup
  • Loading branch information
shahasad authored and pranavsharma committed Nov 29, 2018
1 parent 8980cbb commit 7780fd6
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 4 deletions.
61 changes: 57 additions & 4 deletions csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Buffers;
using System.Collections;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Creates a <see cref="NamedOnnxValue"/> from a native OnnxValue handle by querying the
/// tensor's element type and wrapping the native buffer in a correspondingly typed tensor.
/// </summary>
/// <param name="name">Logical name of the output value.</param>
/// <param name="nativeOnnxValue">Native handle to the OnnxValue; assumed to be a tensor.</param>
/// <returns>A managed <see cref="NamedOnnxValue"/> backed by the native tensor memory.</returns>
/// <exception cref="NotSupportedException">Thrown when the tensor's element type has no managed mapping.</exception>
internal static NamedOnnxValue CreateFromOnnxValue(string name, IntPtr nativeOnnxValue)
{
    NamedOnnxValue result = null;

    /* Get Tensor element type */  //TODO: Assumed value is Tensor, need to support non-tensor types in future
    IntPtr typeAndShape = IntPtr.Zero;
    TensorElementType elemType = TensorElementType.DataTypeMax;
    try
    {
        NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorShapeAndType(nativeOnnxValue, out typeAndShape));
        elemType = NativeMethods.ONNXRuntimeGetTensorElementType(typeAndShape);
    }
    finally
    {
        // Always release the native type-and-shape object, even if the query above failed.
        if (typeAndShape != IntPtr.Zero)
        {
            NativeMethods.ONNXRuntimeReleaseObject(typeAndShape);
        }
    }

    // Map the native element type to the matching managed tensor wrapper.
    switch (elemType)
    {
        case TensorElementType.Float:
            result = NameOnnxValueFromNativeTensor<float>(name, nativeOnnxValue);
            break;
        case TensorElementType.Double:
            result = NameOnnxValueFromNativeTensor<double>(name, nativeOnnxValue);
            break;
        case TensorElementType.Int16:
            result = NameOnnxValueFromNativeTensor<short>(name, nativeOnnxValue);
            break;
        case TensorElementType.UInt16:
            result = NameOnnxValueFromNativeTensor<ushort>(name, nativeOnnxValue);
            break;
        case TensorElementType.Int32:
            result = NameOnnxValueFromNativeTensor<int>(name, nativeOnnxValue);
            break;
        case TensorElementType.UInt32:
            result = NameOnnxValueFromNativeTensor<uint>(name, nativeOnnxValue);
            break;
        case TensorElementType.Int64:
            result = NameOnnxValueFromNativeTensor<long>(name, nativeOnnxValue);
            break;
        case TensorElementType.UInt64:
            result = NameOnnxValueFromNativeTensor<ulong>(name, nativeOnnxValue);
            break;
        case TensorElementType.UInt8:
            result = NameOnnxValueFromNativeTensor<byte>(name, nativeOnnxValue);
            break;
        default:
            throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported");
    }

    return result;
}


/// <summary>
/// Wraps a native tensor handle as a <typeparamref name="T"/>-typed <see cref="NamedOnnxValue"/>,
/// sharing the native memory rather than copying it.
/// </summary>
/// <typeparam name="T">Managed element type matching the native tensor's element type.</typeparam>
/// <param name="name">Logical name for the resulting value.</param>
/// <param name="nativeOnnxValue">Native handle to the tensor OnnxValue.</param>
private static NamedOnnxValue NameOnnxValueFromNativeTensor<T>(string name, IntPtr nativeOnnxValue)
{
    var nativeMemory = new NativeOnnxTensorMemory<T>(nativeOnnxValue);
    var denseTensor = new DenseTensor<T>(nativeMemory.Memory, nativeMemory.Dimensions);
    return NamedOnnxValue.CreateFromTensor<T>(name, denseTensor);
}


private bool TryPinAsTensor<T>(
out MemoryHandle pinnedMemoryHandle,
out IntPtr dataBufferPointer,
Expand Down
22 changes: 22 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle)
Dispose(false);
}

/// <summary>
/// Releases the native tensor resources held by this instance and removes it from the
/// finalization queue, following the standard dispose pattern.
/// </summary>
public void Dispose()
{
    // Dispose first: if Dispose(bool) throws, the finalizer still runs as a backstop.
    // (The original suppressed finalization before disposing, which loses that safety net.)
    Dispose(true);
    GC.SuppressFinalize(this);
}

public bool IsDisposed => _disposed;

protected bool IsRetained => _referenceCount > 0;
Expand All @@ -99,6 +105,22 @@ public int Rank
}
}

/// <summary>Total number of elements in the native tensor.</summary>
public int Count => _elementCount;

/// <summary>Size in bytes of a single tensor element.</summary>
public int ElementWidth => _elementWidth;

public override Span<T> GetSpan()
{
if (IsDisposed)
Expand Down
Loading

0 comments on commit 7780fd6

Please sign in to comment.