Skip to content

Commit

Permalink
[java] Migrate OnnxTensors created from arrays over to a backing Java…
Browse files Browse the repository at this point in the history
… buffer (#18556)

### Description
Following from #16578 and #16835 this migrates over
`OnnxTensor.createTensor(<array>)` to first instantiate a
`java.nio.Buffer` and then copy the array into that buffer in Java
before creating the tensor. It also changes the `OnnxTensor.getValue()`
method which returns a multidimensional array so it does the array
construction and value copy in Java. This allows the removal of some
unpleasant recursive C code which repeatedly calls into the JVM to
traverse Java's arrays. The equivalent Java code is still unpleasant and
recursive, but it's easier to reason about and memory safe. As a bonus,
more `OnnxTensor`s are now backed by buffers which allow users to pin
memory and reduce allocations by reusing them for same sized inputs.

Some of the JNI code which parses Java arrays still exists as it's used
by `OnnxMap`, removing that will be the target of a future refactor.
Strings are still processed in JNI as it is easier to work with String
tensors and UTF-8 arrays in C.

### Motivation and Context
Minimizing the amount of JNI code makes it easier to maintain and using
buffers in preference to arrays allows for fewer allocations.
  • Loading branch information
Craigacp authored Sep 24, 2024
1 parent ae66d0e commit cfa45df
Show file tree
Hide file tree
Showing 8 changed files with 682 additions and 311 deletions.
214 changes: 184 additions & 30 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ public class OnnxTensor extends OnnxTensorLike {
* the state of this buffer without first getting the reference via {@link #getBufferRef()}.
*
* @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is
* a copy of a user buffer.)
* a copy of a user buffer or array.)
*/
public boolean ownsBuffer() {
return this.ownsBuffer;
}

/**
* Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not
* backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by
* ORT) this method returns an empty {@link Optional}.
* backed by a buffer (i.e., it is backed by memory allocated by ORT) this method returns an empty
* {@link Optional}.
*
* <p>Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be
* used to repeatedly update a single tensor for multiple different inferences without allocating
Expand All @@ -77,7 +77,116 @@ public boolean ownsBuffer() {
* @return A reference to the buffer.
*/
public Optional<Buffer> getBufferRef() {
return Optional.ofNullable(buffer);
return Optional.ofNullable(duplicate(buffer));
}

/**
* Duplicates the buffer to ensure concurrent reads don't disrupt the buffer position. Concurrent
* writes will modify the underlying memory in a racy way, don't do that.
*
* <p>Can be replaced to a call to buf.duplicate() in Java 9+.
*
* @param buf The buffer to duplicate.
* @return A copy of the buffer which refers to the same underlying memory, but has an independent
* position, limit and mark.
*/
private static Buffer duplicate(Buffer buf) {
if (buf instanceof ByteBuffer) {
return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder());
} else if (buf instanceof ShortBuffer) {
return ((ShortBuffer) buf).duplicate();
} else if (buf instanceof IntBuffer) {
return ((IntBuffer) buf).duplicate();
} else if (buf instanceof LongBuffer) {
return ((LongBuffer) buf).duplicate();
} else if (buf instanceof FloatBuffer) {
return ((FloatBuffer) buf).duplicate();
} else if (buf instanceof DoubleBuffer) {
return ((DoubleBuffer) buf).duplicate();
} else {
throw new IllegalStateException("Unknown buffer type " + buf.getClass());
}
}

/**
* Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link
* ByteBuffer} then convert it to the right type. If it's not convertible it throws {@link
* IllegalStateException}.
*
* <p>Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve
* compatibility with existing {@link #getValue} calls.
*
* @param buf The buffer to convert.
* @return The buffer with the expected type.
*/
private Buffer castBuffer(Buffer buf) {
switch (info.type) {
case FLOAT:
if (buf instanceof FloatBuffer) {
return buf;
} else if (buf instanceof ByteBuffer) {
return ((ByteBuffer) buf).asFloatBuffer();
}
break;
case DOUBLE:
if (buf instanceof DoubleBuffer) {
return buf;
} else if (buf instanceof ByteBuffer) {
return ((ByteBuffer) buf).asDoubleBuffer();
}
break;
case BOOL:
case INT8:
case UINT8:
if (buf instanceof ByteBuffer) {
return buf;
}
break;
case BFLOAT16:
if (buf instanceof ShortBuffer) {
ShortBuffer bf16Buf = (ShortBuffer) buf;
return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf);
} else if (buf instanceof ByteBuffer) {
ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer();
return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf);
}
break;
case FLOAT16:
if (buf instanceof ShortBuffer) {
ShortBuffer fp16Buf = (ShortBuffer) buf;
return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf);
} else if (buf instanceof ByteBuffer) {
ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer();
return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf);
}
break;
case INT16:
if (buf instanceof ShortBuffer) {
return buf;
} else if (buf instanceof ByteBuffer) {
return ((ByteBuffer) buf).asShortBuffer();
}
break;
case INT32:
if (buf instanceof IntBuffer) {
return buf;
} else if (buf instanceof ByteBuffer) {
return ((ByteBuffer) buf).asIntBuffer();
}
break;
case INT64:
if (buf instanceof LongBuffer) {
return buf;
} else if (buf instanceof ByteBuffer) {
return ((ByteBuffer) buf).asLongBuffer();
}
break;
}
throw new IllegalStateException(
"Invalid buffer type for cast operation, found "
+ buf.getClass()
+ " expected something convertible to "
+ info.type);
}

@Override
Expand Down Expand Up @@ -133,15 +242,26 @@ public Object getValue() throws OrtException {
Object carrier = info.makeCarrier();
if (info.getNumElements() > 0) {
// If the tensor has values copy them out
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
}
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
return OrtUtil.reshape((String[]) carrier, info.shape);
} else {
return carrier;
if (info.type == OnnxJavaType.STRING) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape if necessary.
getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier);
if (info.shape.length != 1) {
carrier = OrtUtil.reshape((String[]) carrier, info.shape);
}
} else {
// Wrap ORT owned memory in buffer, otherwise use our reference
Buffer buf;
if (buffer == null) {
buf = castBuffer(getBuffer());
} else {
buf = castBuffer(duplicate(buffer));
}
// Copy out buffer into arrays
OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier);
}
}
return carrier;
}
}

Expand Down Expand Up @@ -175,8 +295,8 @@ public synchronized void close() {
public ByteBuffer getByteBuffer() {
checkClosed();
if (info.type != OnnxJavaType.STRING) {
ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle);
ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
ByteBuffer buffer = getBuffer();
ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder());
output.put(buffer);
output.rewind();
return output;
Expand All @@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() {
output.rewind();
return output;
} else if (info.type == OnnxJavaType.FLOAT16) {
// if it's fp16 we need to copy it out by hand.
// if it's fp16 we need to convert it.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer);
} else if (info.type == OnnxJavaType.BFLOAT16) {
// if it's bf16 we need to copy it out by hand.
// if it's bf16 we need to convert it.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer);
Expand Down Expand Up @@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)

private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;

private native void getArray(long apiHandle, long nativeHandle, Object carrier)
private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier)
throws OrtException;

private native void close(long apiHandle, long nativeHandle);
Expand Down Expand Up @@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec
info);
}
} else {
Buffer buf;
if (info.shape.length == 0) {
data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data);
if (data == null) {
buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data);
if (buf == null) {
throw new OrtException(
"Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = "
+ info.type
+ ", object = "
+ data);
}
} else {
buf = OrtUtil.convertArrayToBuffer(info, data);
}
return new OnnxTensor(
createTensor(
OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value),
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
buf,
0,
info.type.size * info.numElements,
info.shape,
info.onnxType.value),
allocator.handle,
info);
info,
buf,
true);
}
} else {
throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator.");
Expand Down Expand Up @@ -627,7 +758,26 @@ static OnnxTensor createTensor(
*/
public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape)
throws OrtException {
return createTensor(env, env.defaultAllocator, data, shape);
return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16);
}

/**
* Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder.
*
* <p>If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime
* of the tensor. Uses the default allocator.
*
* @param env The current OrtEnvironment.
* @param data The tensor data.
* @param shape The shape of tensor.
* @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16},
* {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}.
* @return An OnnxTensor of the required shape.
* @throws OrtException Thrown if there is an onnx error or if the data and shape don't match.
*/
public static OnnxTensor createTensor(
OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException {
return createTensor(env, env.defaultAllocator, data, shape, type);
}

/**
Expand All @@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long
* @param allocator The allocator to use.
* @param data The tensor data.
* @param shape The shape of tensor.
* @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16},
* {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}.
* @return An OnnxTensor of the required shape.
* @throws OrtException Thrown if there is an onnx error or if the data and shape don't match.
*/
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT16;
return createTensor(type, allocator, data, shape);
if ((type == OnnxJavaType.BFLOAT16)
|| (type == OnnxJavaType.FLOAT16)
|| (type == OnnxJavaType.INT16)) {
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalArgumentException(
"Only int16, float16 or bfloat16 tensors can be created from ShortBuffer.");
}
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
Expand Down Expand Up @@ -768,10 +926,6 @@ private static OnnxTensor createTensor(
tuple.isCopy);
}

private static native long createTensor(
long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
throws OrtException;

private static native long createTensorFromBuffer(
long apiHandle,
long allocatorHandle,
Expand Down
Loading

0 comments on commit cfa45df

Please sign in to comment.