getBufferRef() {
- return Optional.ofNullable(buffer);
+ return Optional.ofNullable(duplicate(buffer));
+ }
+
+ /**
+ * Duplicates the buffer to ensure concurrent reads don't disrupt the buffer position. Concurrent
+ * writes will modify the underlying memory in a racy way; don't do that.
+ *
+ * Can be replaced by a call to buf.duplicate() in Java 9+.
+ *
+ * @param buf The buffer to duplicate.
+ * @return A copy of the buffer which refers to the same underlying memory, but has an independent
+ * position, limit and mark.
+ */
+ private static Buffer duplicate(Buffer buf) {
+ if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder());
+ } else if (buf instanceof ShortBuffer) {
+ return ((ShortBuffer) buf).duplicate();
+ } else if (buf instanceof IntBuffer) {
+ return ((IntBuffer) buf).duplicate();
+ } else if (buf instanceof LongBuffer) {
+ return ((LongBuffer) buf).duplicate();
+ } else if (buf instanceof FloatBuffer) {
+ return ((FloatBuffer) buf).duplicate();
+ } else if (buf instanceof DoubleBuffer) {
+ return ((DoubleBuffer) buf).duplicate();
+ } else {
+ throw new IllegalStateException("Unknown buffer type " + buf.getClass());
+ }
+ }
+
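A minimal sketch (plain java.nio, nothing ORT-specific) of the property the comment above relies on: a duplicate shares the underlying memory but carries its own position, limit and mark, so one reader advancing its view cannot disturb another.

    FloatBuffer original =
        ByteBuffer.allocateDirect(4 * Float.BYTES).order(ByteOrder.nativeOrder()).asFloatBuffer();
    original.put(new float[] {0f, 1f, 2f, 3f});
    original.rewind();

    FloatBuffer view = original.duplicate();   // same memory, independent indices
    view.position(2);                          // moves only the view's position
    assert original.position() == 0;           // the original cursor is untouched
    assert view.get() == 2f;                   // reads through the shared memory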
+ /**
+ * Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link
+ * ByteBuffer} then converts it to the right type. If it's not convertible it throws {@link
+ * IllegalStateException}.
+ *
+ * <p>Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve
+ * compatibility with existing {@link #getValue} calls.
+ *
+ * @param buf The buffer to convert.
+ * @return The buffer with the expected type.
+ */
+ private Buffer castBuffer(Buffer buf) {
+ switch (info.type) {
+ case FLOAT:
+ if (buf instanceof FloatBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asFloatBuffer();
+ }
+ break;
+ case DOUBLE:
+ if (buf instanceof DoubleBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asDoubleBuffer();
+ }
+ break;
+ case BOOL:
+ case INT8:
+ case UINT8:
+ if (buf instanceof ByteBuffer) {
+ return buf;
+ }
+ break;
+ case BFLOAT16:
+ if (buf instanceof ShortBuffer) {
+ ShortBuffer bf16Buf = (ShortBuffer) buf;
+ return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf);
+ } else if (buf instanceof ByteBuffer) {
+ ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer();
+ return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf);
+ }
+ break;
+ case FLOAT16:
+ if (buf instanceof ShortBuffer) {
+ ShortBuffer fp16Buf = (ShortBuffer) buf;
+ return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf);
+ } else if (buf instanceof ByteBuffer) {
+ ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer();
+ return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf);
+ }
+ break;
+ case INT16:
+ if (buf instanceof ShortBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asShortBuffer();
+ }
+ break;
+ case INT32:
+ if (buf instanceof IntBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asIntBuffer();
+ }
+ break;
+ case INT64:
+ if (buf instanceof LongBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asLongBuffer();
+ }
+ break;
+ }
+ throw new IllegalStateException(
+ "Invalid buffer type for cast operation, found "
+ + buf.getClass()
+ + " expected something convertible to "
+ + info.type);
}
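A small sketch of the FLOAT16 branch above, using only the Fp16Conversions helpers this patch already relies on; 1.5f and -0.25f are chosen because they are exactly representable in fp16, so the widening round-trips exactly.

    ShortBuffer fp16 = ShortBuffer.allocate(2);
    fp16.put(Fp16Conversions.floatToFp16(1.5f));
    fp16.put(Fp16Conversions.floatToFp16(-0.25f));
    fp16.rewind();
    // The same widening castBuffer applies before the data reaches getValue().
    FloatBuffer widened = Fp16Conversions.convertFp16BufferToFloatBuffer(fp16);
    assert widened.get(0) == 1.5f && widened.get(1) == -0.25f;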
@Override
@@ -133,15 +242,26 @@ public Object getValue() throws OrtException {
Object carrier = info.makeCarrier();
if (info.getNumElements() > 0) {
// If the tensor has values copy them out
- getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
- }
- if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
- // We read the strings out from native code in a flat array and then reshape
- // to the desired output shape.
- return OrtUtil.reshape((String[]) carrier, info.shape);
- } else {
- return carrier;
+ if (info.type == OnnxJavaType.STRING) {
+ // We read the strings out from native code in a flat array and then reshape
+ // to the desired output shape if necessary.
+ getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier);
+ if (info.shape.length != 1) {
+ carrier = OrtUtil.reshape((String[]) carrier, info.shape);
+ }
+ } else {
+ // Wrap ORT owned memory in buffer, otherwise use our reference
+ Buffer buf;
+ if (buffer == null) {
+ buf = castBuffer(getBuffer());
+ } else {
+ buf = castBuffer(duplicate(buffer));
+ }
+ // Copy out buffer into arrays
+ OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier);
+ }
}
+ return carrier;
}
}
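A condensed sketch of what the new code path means for callers, mirroring the assertions added to OnnxTensorTest below; the OrtEnvironment named env is assumed to exist. Array-created tensors are now buffer-backed, so a write through the duplicated reference from getBufferRef() is visible to a later getValue().

    // env is an existing OrtEnvironment (assumed).
    float[] values = new float[] {0, 1, 2, 3, 4};
    try (OnnxTensor t = OnnxTensor.createTensor(env, values)) {
      FloatBuffer ref = (FloatBuffer) t.getBufferRef().get();
      ref.put(0, 25f);                             // writes into the tensor's backing memory
      float[] roundTrip = (float[]) t.getValue();  // {25, 1, 2, 3, 4}
    }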
@@ -175,8 +295,8 @@ public synchronized void close() {
public ByteBuffer getByteBuffer() {
checkClosed();
if (info.type != OnnxJavaType.STRING) {
- ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle);
- ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
+ ByteBuffer buffer = getBuffer();
+ ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder());
output.put(buffer);
output.rewind();
return output;
@@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() {
output.rewind();
return output;
} else if (info.type == OnnxJavaType.FLOAT16) {
- // if it's fp16 we need to copy it out by hand.
+ // if it's fp16 we need to convert it.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer);
} else if (info.type == OnnxJavaType.BFLOAT16) {
- // if it's bf16 we need to copy it out by hand.
+ // if it's bf16 we need to convert it.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer);
@@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)
private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;
- private native void getArray(long apiHandle, long nativeHandle, Object carrier)
+ private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier)
throws OrtException;
private native void close(long apiHandle, long nativeHandle);
@@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec
info);
}
} else {
+ Buffer buf;
if (info.shape.length == 0) {
- data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data);
- if (data == null) {
+ buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data);
+ if (buf == null) {
throw new OrtException(
"Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = "
+ info.type
+ ", object = "
+ data);
}
+ } else {
+ buf = OrtUtil.convertArrayToBuffer(info, data);
}
return new OnnxTensor(
- createTensor(
- OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value),
+ createTensorFromBuffer(
+ OnnxRuntime.ortApiHandle,
+ allocator.handle,
+ buf,
+ 0,
+ info.type.size * info.numElements,
+ info.shape,
+ info.onnxType.value),
allocator.handle,
- info);
+ info,
+ buf,
+ true);
}
} else {
throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator.");
@@ -627,7 +758,26 @@ static OnnxTensor createTensor(
*/
public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape)
throws OrtException {
- return createTensor(env, env.defaultAllocator, data, shape);
+ return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16);
+ }
+
+ /**
+ * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder.
+ *
+ * <p>If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime
+ * of the tensor. Uses the default allocator.
+ *
+ * @param env The current OrtEnvironment.
+ * @param data The tensor data.
+ * @param shape The shape of tensor.
+ * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16},
+ * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}.
+ * @return An OnnxTensor of the required shape.
+ * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match.
+ */
+ public static OnnxTensor createTensor(
+ OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException {
+ return createTensor(env, env.defaultAllocator, data, shape, type);
}
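A usage sketch of the new overload (an OrtEnvironment named env is assumed): the ShortBuffer payload is interpreted according to the type argument, and FLOAT16 data reads back widened to fp32.

    // env is an existing OrtEnvironment (assumed).
    ShortBuffer halfData = ShortBuffer.allocate(4);
    for (float f : new float[] {0f, 0.5f, 1f, 1.5f}) {
      halfData.put(Fp16Conversions.floatToFp16(f));
    }
    halfData.rewind();
    try (OnnxTensor fp16 =
        OnnxTensor.createTensor(env, halfData, new long[] {2, 2}, OnnxJavaType.FLOAT16)) {
      FloatBuffer widened = fp16.getFloatBuffer();      // 0.0, 0.5, 1.0, 1.5 as fp32
      float[][] asArray = (float[][]) fp16.getValue();  // fp16 carriers are float arrays
    }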
/**
@@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long
* @param allocator The allocator to use.
* @param data The tensor data.
* @param shape The shape of tensor.
+ * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16},
+ * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}.
* @return An OnnxTensor of the required shape.
* @throws OrtException Thrown if there is an onnx error or if the data and shape don't match.
*/
static OnnxTensor createTensor(
- OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
+ OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
if (!allocator.isClosed()) {
- OnnxJavaType type = OnnxJavaType.INT16;
- return createTensor(type, allocator, data, shape);
+ if ((type == OnnxJavaType.BFLOAT16)
+ || (type == OnnxJavaType.FLOAT16)
+ || (type == OnnxJavaType.INT16)) {
+ return createTensor(type, allocator, data, shape);
+ } else {
+ throw new IllegalArgumentException(
+ "Only int16, float16 or bfloat16 tensors can be created from ShortBuffer.");
+ }
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@@ -768,10 +926,6 @@ private static OnnxTensor createTensor(
tuple.isCopy);
}
- private static native long createTensor(
- long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
- throws OrtException;
-
private static native long createTensorFromBuffer(
long apiHandle,
long allocatorHandle,
diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java
index 4f3dee3c00b91..2f44236e4ef67 100644
--- a/java/src/main/java/ai/onnxruntime/OrtUtil.java
+++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java
@@ -26,10 +26,10 @@ public final class OrtUtil {
private OrtUtil() {}
/**
- * Converts an long shape into a int shape.
+ * Converts a long shape into an int shape.
*
- * <p>Validates that the shape has more than 1 elements, less than 9 elements, each element is
- * less than {@link Integer#MAX_VALUE} and that each entry is non-negative.
+ * <p>Validates that the shape has more than 1 element, less than 9 elements, each element is less
+ * than {@link Integer#MAX_VALUE} and that each entry is non-negative.
*
* @param shape The long shape.
* @return The int shape.
@@ -460,6 +460,308 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) {
}
}
+ /**
+ * Stores a boxed primitive in a single element buffer of the unboxed type.
+ *
+ * <p>If it's not a boxed primitive then it returns null.
+ *
+ * @param javaType The type of the boxed primitive.
+ * @param data The boxed primitive.
+ * @return The primitive in a direct buffer.
+ */
+ static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) {
+ switch (javaType) {
+ case FLOAT:
+ {
+ FloatBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer();
+ buf.put(0, (Float) data);
+ return buf;
+ }
+ case DOUBLE:
+ {
+ DoubleBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asDoubleBuffer();
+ buf.put(0, (Double) data);
+ return buf;
+ }
+ case BOOL:
+ {
+ ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder());
+ buf.put(0, ((boolean) data) ? (byte) 1 : (byte) 0);
+ return buf;
+ }
+ case UINT8:
+ case INT8:
+ {
+ ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder());
+ buf.put(0, (Byte) data);
+ return buf;
+ }
+ case FLOAT16:
+ case BFLOAT16:
+ case INT16:
+ {
+ ShortBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asShortBuffer();
+ buf.put(0, (Short) data);
+ return buf;
+ }
+ case INT32:
+ {
+ IntBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()).asIntBuffer();
+ buf.put(0, (Integer) data);
+ return buf;
+ }
+ case INT64:
+ {
+ LongBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asLongBuffer();
+ buf.put(0, (Long) data);
+ return buf;
+ }
+ case STRING:
+ case UNKNOWN:
+ default:
+ return null;
+ }
+ }
+
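For context, a sketch of where this is used; it assumes the public scalar path OnnxTensor.createTensor(env, Object) and an existing OrtEnvironment named env. A boxed Float ends up in a one-element direct FloatBuffer, which then backs a rank-0 tensor.

    // env is an existing OrtEnvironment (assumed).
    try (OnnxTensor scalar = OnnxTensor.createTensor(env, 3.5f)) {  // autoboxed to Float
      assert scalar.getInfo().getShape().length == 0;               // rank-0 tensor
      float value = (Float) scalar.getValue();                      // 3.5f
    }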
+ /**
+ * Copies a Java (possibly multidimensional) array into a direct {@link Buffer}.
+ *
+ * <p>Throws {@link IllegalArgumentException} if the array is not an array of Java primitives or
+ * if the array is ragged.
+ *
+ * @param info The tensor info object containing the types and shape of the array.
+ * @param array The array object.
+ * @return A direct buffer containing all the elements.
+ */
+ static Buffer convertArrayToBuffer(TensorInfo info, Object array) {
+ ByteBuffer byteBuffer =
+ ByteBuffer.allocateDirect((int) info.numElements * info.type.size)
+ .order(ByteOrder.nativeOrder());
+
+ Buffer buffer;
+ switch (info.type) {
+ case FLOAT:
+ buffer = byteBuffer.asFloatBuffer();
+ break;
+ case DOUBLE:
+ buffer = byteBuffer.asDoubleBuffer();
+ break;
+ case BOOL:
+ case INT8:
+ case UINT8:
+ // no-op, it's already a bytebuffer
+ buffer = byteBuffer;
+ break;
+ case BFLOAT16:
+ case FLOAT16:
+ case INT16:
+ buffer = byteBuffer.asShortBuffer();
+ break;
+ case INT32:
+ buffer = byteBuffer.asIntBuffer();
+ break;
+ case INT64:
+ buffer = byteBuffer.asLongBuffer();
+ break;
+ case STRING:
+ case UNKNOWN:
+ default:
+ throw new IllegalArgumentException(
+ "Unexpected type, expected Java primitive found " + info.type);
+ }
+
+ fillBufferFromArray(info, array, 0, buffer);
+
+ if (buffer.remaining() != 0) {
+ throw new IllegalArgumentException(
+ "Failed to copy all elements into the buffer, expected to copy "
+ + info.numElements
+ + " into a buffer of capacity "
+ + buffer.capacity()
+ + " but had "
+ + buffer.remaining()
+ + " values left over.");
+ }
+ buffer.rewind();
+
+ return buffer;
+ }
+
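A plain java.nio sketch of the same row-major flattening the utility above performs reflectively: the innermost dimension is bulk-copied a row at a time into a single direct buffer sized for every element.

    float[][][] values = new float[4][2][3];        // 4 * 2 * 3 = 24 elements
    FloatBuffer flat =
        ByteBuffer.allocateDirect(24 * Float.BYTES).order(ByteOrder.nativeOrder()).asFloatBuffer();
    for (float[][] plane : values) {
      for (float[] row : plane) {
        flat.put(row);                              // bulk-copy the innermost dimension
      }
    }
    flat.rewind();                                  // position 0, ready for native code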
+ /**
+ * Fills the provided buffer with the values from the array, recursing through the array
+ * structure.
+ *
+ * @param info The tensor info containing the type and shape of the array.
+ * @param array The array object to read from.
+ * @param curDim The current dimension we're processing.
+ * @param buffer The buffer to write to.
+ */
+ private static void fillBufferFromArray(
+ TensorInfo info, Object array, int curDim, Buffer buffer) {
+ if (curDim == info.shape.length - 1) {
+ // Reached primitive values, copy into buffer
+ switch (info.type) {
+ case FLOAT:
+ float[] fArr = (float[]) array;
+ FloatBuffer fBuf = (FloatBuffer) buffer;
+ fBuf.put(fArr);
+ break;
+ case DOUBLE:
+ double[] dArr = (double[]) array;
+ DoubleBuffer dBuf = (DoubleBuffer) buffer;
+ dBuf.put(dArr);
+ break;
+ case INT8:
+ case UINT8:
+ byte[] bArr = (byte[]) array;
+ ByteBuffer bBuf = (ByteBuffer) buffer;
+ bBuf.put(bArr);
+ break;
+ case FLOAT16:
+ case BFLOAT16:
+ case INT16:
+ short[] sArr = (short[]) array;
+ ShortBuffer sBuf = (ShortBuffer) buffer;
+ sBuf.put(sArr);
+ break;
+ case INT32:
+ int[] iArr = (int[]) array;
+ IntBuffer iBuf = (IntBuffer) buffer;
+ iBuf.put(iArr);
+ break;
+ case INT64:
+ long[] lArr = (long[]) array;
+ LongBuffer lBuf = (LongBuffer) buffer;
+ lBuf.put(lArr);
+ break;
+ case BOOL:
+ boolean[] boolArr = (boolean[]) array;
+ ByteBuffer boolBuf = (ByteBuffer) buffer;
+ for (int i = 0; i < boolArr.length; i++) {
+ boolBuf.put(boolArr[i] ? (byte) 1 : (byte) 0);
+ }
+ break;
+ case STRING:
+ case UNKNOWN:
+ throw new IllegalArgumentException(
+ "Unexpected type, expected Java primitive found " + info.type);
+ }
+ } else {
+ // Recurse through array
+ long expectedSize = info.shape[curDim];
+ long actualSize = Array.getLength(array);
+ if (expectedSize != actualSize) {
+ throw new IllegalArgumentException(
+ "Mismatch in array sizes, expected "
+ + expectedSize
+ + " at dim "
+ + curDim
+ + " from shape "
+ + Arrays.toString(info.shape)
+ + ", found "
+ + actualSize);
+ } else {
+ for (int i = 0; i < actualSize; i++) {
+ fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer);
+ }
+ }
+ }
+ }
+
+ /**
+ * Fills the provided array with the values from the buffer, recursing through the array
+ * structure.
+ *
+ * @param info The tensor info containing the type and shape of the array.
+ * @param buffer The buffer to read from.
+ * @param curDim The current dimension we're processing.
+ * @param array The array object to write to.
+ */
+ static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) {
+ if (curDim == info.shape.length - 1) {
+ // Reached primitive values, copy into buffer
+ switch (info.type) {
+ case FLOAT16:
+ case BFLOAT16:
+ case FLOAT:
+ float[] fArr = (float[]) array;
+ FloatBuffer fBuf = (FloatBuffer) buffer;
+ fBuf.get(fArr);
+ break;
+ case DOUBLE:
+ double[] dArr = (double[]) array;
+ DoubleBuffer dBuf = (DoubleBuffer) buffer;
+ dBuf.get(dArr);
+ break;
+ case INT8:
+ case UINT8:
+ byte[] bArr = (byte[]) array;
+ ByteBuffer bBuf = (ByteBuffer) buffer;
+ bBuf.get(bArr);
+ break;
+ case INT16:
+ short[] sArr = (short[]) array;
+ ShortBuffer sBuf = (ShortBuffer) buffer;
+ sBuf.get(sArr);
+ break;
+ case INT32:
+ int[] iArr = (int[]) array;
+ IntBuffer iBuf = (IntBuffer) buffer;
+ iBuf.get(iArr);
+ break;
+ case INT64:
+ long[] lArr = (long[]) array;
+ LongBuffer lBuf = (LongBuffer) buffer;
+ lBuf.get(lArr);
+ break;
+ case BOOL:
+ boolean[] boolArr = (boolean[]) array;
+ ByteBuffer boolBuf = (ByteBuffer) buffer;
+ for (int i = 0; i < boolArr.length; i++) {
+ // Test to see if the byte is non-zero, non-zero bytes are true, zero bytes are false.
+ boolArr[i] = boolBuf.get() != 0;
+ }
+ break;
+ case STRING:
+ case UNKNOWN:
+ throw new IllegalArgumentException(
+ "Unexpected type, expected Java primitive found " + info.type);
+ }
+ } else {
+ // Recurse through array
+ long expectedSize = info.shape[curDim];
+ long actualSize = Array.getLength(array);
+ if (expectedSize != actualSize) {
+ throw new IllegalArgumentException(
+ "Mismatch in array sizes, expected "
+ + expectedSize
+ + " at dim "
+ + curDim
+ + " from shape "
+ + Arrays.toString(info.shape)
+ + ", found "
+ + actualSize);
+ } else {
+ for (int i = 0; i < actualSize; i++) {
+ fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i));
+ }
+ }
+ }
+ }
+
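A standalone sketch of the BOOL branch above: booleans travel as one byte per element, and any non-zero byte reads back as true.

    ByteBuffer boolBytes = ByteBuffer.allocateDirect(3).order(ByteOrder.nativeOrder());
    boolBytes.put(new byte[] {1, 0, 2});
    boolBytes.rewind();
    boolean[] flags = new boolean[3];
    for (int i = 0; i < flags.length; i++) {
      flags[i] = boolBytes.get() != 0;              // {true, false, true}
    }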
/**
* Returns expected JDK map capacity for a given size, this factors in the default JDK load factor
*
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 1c21387b50455..f3e9f21ef408d 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -323,6 +323,9 @@ public long getNumElements() {
* all elements as that's the expected format of the native code. It can be reshaped to the
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
*
+ * <p>For fp16 and bf16 tensors the output carrier type is float, and so this method produces
+ * multidimensional float arrays.
+ *
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is
* greater than an int).
@@ -335,6 +338,8 @@ public Object makeCarrier() throws OrtException {
+ Arrays.toString(shape));
}
switch (type) {
+ case BFLOAT16:
+ case FLOAT16:
case FLOAT:
return OrtUtil.newFloatArray(shape);
case DOUBLE:
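A sketch of what this change means at the call site (the session, its inputs and the output index are assumed, as in the fp16 tests below): a FLOAT16 or BFLOAT16 result can now be read through getValue() as widened float arrays.

    // session and inputs are assumed to exist; the model's first output is FLOAT16.
    try (OrtSession.Result result = session.run(inputs)) {
      OnnxTensor fp16Output = (OnnxTensor) result.get(0);
      // For a [10, 5] FLOAT16 output, makeCarrier() allocates float[10][5].
      float[][] widened = (float[][]) fp16Output.getValue();
    }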
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 7b26291581395..6a3c279073860 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -502,104 +502,6 @@ jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSeque
return sequenceInfo;
}
-int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) {
- int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray);
- int64_t consumedSize = inputLength * onnxTypeSize(onnxType);
- switch (onnxType) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t
- jbyteArray typedArr = (jbyteArray)inputArray;
- (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t
- jshortArray typedArr = (jshortArray)inputArray;
- (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t
- jintArray typedArr = (jintArray)inputArray;
- (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t
- jlongArray typedArr = (jlongArray)inputArray;
- (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
- throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported.");
- return -1;
- /*
- float *floatArr = malloc(sizeof(float) * inputLength);
- uint16_t *halfArr = (uint16_t *) outputTensor;
- for (uint32_t i = 0; i < inputLength; i++) {
- floatArr[i] = convertHalfToFloat(halfArr[i]);
- }
- jfloatArray typedArr = (jfloatArray) inputArray;
- (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr);
- free(floatArr);
- return consumedSize;
- */
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float
- jfloatArray typedArr = (jfloatArray)inputArray;
- (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double
- jdoubleArray typedArr = (jdoubleArray)inputArray;
- (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string
- throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported.");
- return -1;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
- jbooleanArray typedArr = (jbooleanArray)inputArray;
- (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
- default: {
- throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type.");
- return -1;
- }
- }
-}
-
-int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) {
- if (dimensionsRemaining == 1) {
- // write out 1d array of the respective primitive type
- return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor);
- } else {
- // recurse through the dimensions
- // Java arrays are objects until the final dimension
- jobjectArray inputObjArr = (jobjectArray)inputArray;
- int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr);
- int64_t sizeConsumed = 0;
- for (int32_t i = 0; i < dimLength; i++) {
- jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i);
- int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed);
- sizeConsumed += consumed;
- // Cleanup reference to childArr so it doesn't prevent GC.
- (*jniEnv)->DeleteLocalRef(jniEnv, childArr);
- // If we failed to copy an array then break and return.
- if (consumed == -1) {
- return -1;
- }
- }
- return sizeConsumed;
- }
-}
-
int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) {
int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray);
if (outputLength == 0) return 0;
@@ -697,65 +599,6 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT
}
}
-int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize,
- size_t dimensionsRemaining, jarray outputArray) {
- if (dimensionsRemaining == 1) {
- // write out 1d array of the respective primitive type
- return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray);
- } else {
- // recurse through the dimensions
- // Java arrays are objects until the final dimension
- jobjectArray outputObjArr = (jobjectArray)outputArray;
- int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr);
- int64_t sizeConsumed = 0;
- for (int32_t i = 0; i < dimLength; i++) {
- jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i);
- int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr);
- sizeConsumed += consumed;
- // Cleanup reference to childArr so it doesn't prevent GC.
- (*jniEnv)->DeleteLocalRef(jniEnv, childArr);
- // If we failed to copy an array then break and return.
- if (consumed == -1) {
- return -1;
- }
- }
- return sizeConsumed;
- }
-}
-
-jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
- jobject tempString = NULL;
- // Get the buffer size needed
- size_t totalStringLength = 0;
- OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength));
- if (code != ORT_OK) {
- return NULL;
- }
-
- // Create the character and offset buffers, character is one larger to allow zero termination.
- char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1));
- if (characterBuffer == NULL) {
- throwOrtException(jniEnv, 1, "OOM error");
- } else {
- size_t * offsets = malloc(sizeof(size_t));
- if (offsets != NULL) {
- // Get a view on the String data
- code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1));
-
- if (code == ORT_OK) {
- size_t curSize = (offsets[0]) + 1;
- characterBuffer[curSize-1] = '\0';
- tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer);
- }
-
- free((void*)characterBuffer);
- free((void*)offsets);
- }
- }
-
- return tempString;
-}
-
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) {
size_t bufferSize = 16;
char * tempBuffer = malloc(bufferSize);
diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h
index 023bc0c739583..7f41e06371f2a 100644
--- a/java/src/main/native/OrtJniUtil.h
+++ b/java/src/main/native/OrtJniUtil.h
@@ -54,16 +54,8 @@ jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInf
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
-int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor);
-
-int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor);
-
int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray);
-int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray);
-
-jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
-
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray);
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
index b694f57357bb5..d757bd6281499 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
@@ -8,72 +8,6 @@
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OnnxTensor.h"
-/*
- * Class: ai_onnxruntime_OnnxTensor
- * Method: createTensor
- * Signature: (JJLjava/lang/Object;[JI)J
- */
-JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
- (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj,
- jlongArray shape, jint onnxTypeJava) {
- (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
- const OrtApi* api = (const OrtApi*) apiHandle;
- OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
- // Convert type to ONNX C enum
- ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
-
- // Extract the shape information
- jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
- jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
-
- // Create the OrtValue
- OrtValue* ortValue = NULL;
- OrtErrorCode code = checkOrtStatus(jniEnv, api,
- api->CreateTensorAsOrtValue(
- allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue
- )
- );
- (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
-
- int failed = 0;
- if (code == ORT_OK) {
- // Get a reference to the OrtValue's data
- uint8_t* tensorData = NULL;
- code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData));
- if (code == ORT_OK) {
- // Check if we're copying a scalar or not
- if (shapeLen == 0) {
- // Scalars are passed in as a single element array
- int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData);
- failed = copied == -1 ? 1 : failed;
- } else {
- // Extract the tensor shape information
- JavaTensorTypeShape typeShape;
- code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
-
- if (code == ORT_OK) {
- // Copy the java array into the tensor
- int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount,
- typeShape.dimensions, dataObj, tensorData);
- failed = copied == -1 ? 1 : failed;
- } else {
- failed = 1;
- }
- }
- } else {
- failed = 1;
- }
- }
-
- if (failed) {
- api->ReleaseValue(ortValue);
- ortValue = NULL;
- }
-
- // Return the pointer to the OrtValue
- return (jlong) ortValue;
-}
-
/*
* Class: ai_onnxruntime_OnnxTensor
* Method: createTensorFromBuffer
@@ -227,7 +161,7 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
size_t sizeBytes = typeShape.elementCount * typeSize;
uint8_t* arr = NULL;
- code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
+ code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr));
if (code == ORT_OK) {
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
@@ -401,11 +335,11 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
/*
* Class: ai_onnxruntime_OnnxTensor
- * Method: getArray
- * Signature: (JJLjava/lang/Object;)V
+ * Method: getStringArray
+ * Signature: (JJ[Ljava/lang/String;)V
*/
-JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
- (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) {
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getStringArray
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobjectArray carrier) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtValue* value = (OrtValue*) handle;
@@ -415,12 +349,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier);
} else {
- uint8_t* arr = NULL;
- code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));
- if (code == ORT_OK) {
- copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount,
- typeShape.dimensions, (jarray)carrier);
- }
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Non-string types are not supported by this codepath, please raise a Github issue as it should not reach here.");
}
}
}
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 11141a3a65a3e..7cb6305923279 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -495,12 +495,12 @@ public void throwWrongInputName() throws OrtException {
container.put("wrong_name", OnnxTensor.createTensor(env, tensor));
try {
session.run(container);
- OnnxValue.close(container.values());
fail("Should throw exception for incorrect name.");
} catch (OrtException e) {
- OnnxValue.close(container.values());
String msg = e.getMessage();
assertTrue(msg.contains("Unknown input name"));
+ } finally {
+ OnnxValue.close(container.values());
}
}
}
@@ -522,12 +522,57 @@ public void throwWrongInputType() throws OrtException {
container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor));
try {
session.run(container);
- OnnxValue.close(container.values());
fail("Should throw exception for incorrect type.");
} catch (OrtException e) {
- OnnxValue.close(container.values());
String msg = e.getMessage();
assertTrue(msg.contains("Unexpected input data type"));
+ } finally {
+ OnnxValue.close(container.values());
+ }
+ }
+ }
+
+ @Test
+ public void throwWrongSizeInput() throws OrtException {
+ SqueezeNetTuple tuple = openSessionSqueezeNet();
+ try (OrtSession session = tuple.session) {
+
+ float[] inputData = tuple.inputData;
+ NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
+ Map<String, OnnxTensor> container = new HashMap<>();
+ float[] wrongSizeData = Arrays.copyOf(inputData, 2 * 224 * 224);
+ Object tensor = OrtUtil.reshape(wrongSizeData, new long[] {1, 2, 224, 224});
+ container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor));
+ try {
+ session.run(container);
+ fail("Should throw exception for incorrect size.");
+ } catch (OrtException e) {
+ String msg = e.getMessage();
+ assertTrue(msg.contains("Got invalid dimensions for input"));
+ } finally {
+ OnnxValue.close(container.values());
+ }
+ }
+ }
+
+ @Test
+ public void throwWrongRankInput() throws OrtException {
+ SqueezeNetTuple tuple = openSessionSqueezeNet();
+ try (OrtSession session = tuple.session) {
+
+ float[] inputData = tuple.inputData;
+ NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
+ Map<String, OnnxTensor> container = new HashMap<>();
+ Object tensor = OrtUtil.reshape(inputData, new long[] {1, 1, 3, 224, 224});
+ container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor));
+ try {
+ session.run(container);
+ fail("Should throw exception for incorrect rank.");
+ } catch (OrtException e) {
+ String msg = e.getMessage();
+ assertTrue(msg.contains("Invalid rank for input"));
+ } finally {
+ OnnxValue.close(container.values());
}
}
}
@@ -550,12 +595,12 @@ public void throwExtraInputs() throws OrtException {
container.put("extra", OnnxTensor.createTensor(env, tensor));
try {
session.run(container);
- OnnxValue.close(container.values());
fail("Should throw exception for too many inputs.");
} catch (OrtException e) {
- OnnxValue.close(container.values());
String msg = e.getMessage();
assertTrue(msg.contains("Unexpected number of inputs"));
+ } finally {
+ OnnxValue.close(container.values());
}
}
}
@@ -565,12 +610,11 @@ public void testMultiThreads() throws OrtException, InterruptedException {
int numThreads = 10;
int loop = 10;
SqueezeNetTuple tuple = openSessionSqueezeNet();
+ Map<String, OnnxTensor> container = new HashMap<>();
try (OrtSession session = tuple.session) {
-
float[] inputData = tuple.inputData;
float[] expectedOutput = tuple.outputData;
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
- Map<String, OnnxTensor> container = new HashMap<>();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensor = OrtUtil.reshape(inputData, inputShape);
container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor));
@@ -592,8 +636,9 @@ public void testMultiThreads() throws OrtException, InterruptedException {
}
executor.shutdown();
executor.awaitTermination(1, TimeUnit.MINUTES);
- OnnxValue.close(container.values());
assertTrue(executor.isTerminated());
+ } finally {
+ OnnxValue.close(container.values());
}
}
diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java
index ea210d96c1507..064f14f3b51ff 100644
--- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java
+++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java
@@ -12,8 +12,11 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
import java.nio.ShortBuffer;
+import java.util.ArrayList;
import java.util.Collections;
+import java.util.List;
import java.util.SplittableRandom;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -93,30 +96,108 @@ public void testScalarCreation() throws OrtException {
}
@Test
- public void testBufferCreation() throws OrtException {
+ public void testArrayCreation() throws OrtException {
OrtEnvironment env = OrtEnvironment.getEnvironment();
- // Test creating a value from an array
- // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer
+ // Test creating a value from a single dimensional array
float[] arrValues = new float[] {0, 1, 2, 3, 4};
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
- // array creation isn't backed by buffers
- assertFalse(t.ownsBuffer());
- assertFalse(t.getBufferRef().isPresent());
- FloatBuffer buf = t.getFloatBuffer();
+ Assertions.assertTrue(t.ownsBuffer());
+ Assertions.assertTrue(t.getBufferRef().isPresent());
+ FloatBuffer buf = (FloatBuffer) t.getBufferRef().get();
float[] output = new float[arrValues.length];
buf.get(output);
Assertions.assertArrayEquals(arrValues, output);
- // Can't modify the tensor through this buffer.
+ // Can modify the tensor through this buffer.
buf.put(0, 25);
- Assertions.assertArrayEquals(arrValues, output);
+ Assertions.assertArrayEquals(new float[] {25, 1, 2, 3, 4}, (float[]) t.getValue());
}
+ // Test creating a value from a multidimensional float array
+ float[][][] arr3dValues =
+ new float[][][] {
+ {{0, 1, 2}, {3, 4, 5}},
+ {{6, 7, 8}, {9, 10, 11}},
+ {{12, 13, 14}, {15, 16, 17}},
+ {{18, 19, 20}, {21, 22, 23}}
+ };
+ try (OnnxTensor t = OnnxTensor.createTensor(env, arr3dValues)) {
+ Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape());
+ Assertions.assertTrue(t.ownsBuffer());
+ Assertions.assertTrue(t.getBufferRef().isPresent());
+ float[][][] output = (float[][][]) t.getValue();
+ Assertions.assertArrayEquals(arr3dValues, output);
+
+ // Can modify the tensor through the buffer.
+ FloatBuffer buf = (FloatBuffer) t.getBufferRef().get();
+ buf.put(0, 25);
+ buf.put(12, 32);
+ buf.put(13, 33);
+ buf.put(23, 35);
+ arr3dValues[0][0][0] = 25;
+ arr3dValues[2][0][0] = 32;
+ arr3dValues[2][0][1] = 33;
+ arr3dValues[3][1][2] = 35;
+ output = (float[][][]) t.getValue();
+ Assertions.assertArrayEquals(arr3dValues, output);
+ }
+
+ // Test creating a value from a multidimensional int array
+ int[][][] iArr3dValues =
+ new int[][][] {
+ {{0, 1, 2}, {3, 4, 5}},
+ {{6, 7, 8}, {9, 10, 11}},
+ {{12, 13, 14}, {15, 16, 17}},
+ {{18, 19, 20}, {21, 22, 23}}
+ };
+ try (OnnxTensor t = OnnxTensor.createTensor(env, iArr3dValues)) {
+ Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape());
+ Assertions.assertTrue(t.ownsBuffer());
+ Assertions.assertTrue(t.getBufferRef().isPresent());
+ int[][][] output = (int[][][]) t.getValue();
+ Assertions.assertArrayEquals(iArr3dValues, output);
+
+ // Can modify the tensor through the buffer.
+ IntBuffer buf = (IntBuffer) t.getBufferRef().get();
+ buf.put(0, 25);
+ iArr3dValues[0][0][0] = 25;
+ output = (int[][][]) t.getValue();
+ Assertions.assertArrayEquals(iArr3dValues, output);
+ }
+
+ // Test creating a value from a ragged array throws
+ int[][][] ragged =
+ new int[][][] {
+ {{0, 1, 2}, {3, 4, 5}},
+ {{6, 7, 8}},
+ {{12, 13}, {15, 16, 17}},
+ {{18, 19, 20}, {21, 22, 23}}
+ };
+ try (OnnxTensor t = OnnxTensor.createTensor(env, ragged)) {
+ Assertions.fail("Can't create tensors from ragged arrays");
+ } catch (OrtException e) {
+ Assertions.assertTrue(e.getMessage().contains("ragged"));
+ }
+
+ // Test creating a value from a non-array, non-primitive type throws.
+ List<Integer> list = new ArrayList<>(5);
+ list.add(5);
+ try (OnnxTensor t = OnnxTensor.createTensor(env, list)) {
+ Assertions.fail("Can't create tensors from lists");
+ } catch (OrtException e) {
+ Assertions.assertTrue(e.getMessage().contains("Cannot convert"));
+ }
+ }
+
+ @Test
+ public void testBufferCreation() throws OrtException {
+ OrtEnvironment env = OrtEnvironment.getEnvironment();
+
// Test creating a value from a non-direct byte buffer
// Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap
- // direct byte buffers
- // which can be directly passed to ORT
+ // direct byte buffers which can be directly passed to ORT
+ float[] arrValues = new float[] {0, 1, 2, 3, 4};
FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5);
nonDirectBuffer.put(arrValues);
nonDirectBuffer.rewind();
@@ -335,10 +416,12 @@ public void testFp32ToFp16() throws OrtException {
String modelPath = TestHelpers.getResourcePath("/java-fp32-to-fp16.onnx").toString();
SplittableRandom rng = new SplittableRandom(1);
- float[][] input = new float[10][5];
+ int dim1 = 10, dim2 = 5;
+ float[][] input = new float[dim1][dim2];
+ float[][] expectedOutput = new float[dim1][dim2];
FloatBuffer floatBuf =
- ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer();
- ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5);
+ ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer();
+ ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2);
// Generate data
for (int i = 0; i < input.length; i++) {
@@ -347,6 +430,8 @@ public void testFp32ToFp16() throws OrtException {
input[i][j] = Float.intBitsToFloat(bits);
floatBuf.put(input[i][j]);
shortBuf.put(Fp16Conversions.floatToFp16(input[i][j]));
+ expectedOutput[i][j] =
+ Fp16Conversions.fp16ToFloat(Fp16Conversions.floatToFp16(input[i][j]));
}
}
floatBuf.rewind();
@@ -354,25 +439,31 @@ public void testFp32ToFp16() throws OrtException {
try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
OrtSession session = env.createSession(modelPath, opts);
- OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5});
+ OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2});
OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) {
OnnxTensor output = (OnnxTensor) result.get(0);
// Check outbound Java side cast to fp32 works
FloatBuffer castOutput = output.getFloatBuffer();
- float[] expectedFloatArr = new float[10 * 5];
+ float[] expectedFloatArr = new float[dim1 * dim2];
Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr);
- float[] actualFloatArr = new float[10 * 5];
+ float[] actualFloatArr = new float[dim1 * dim2];
castOutput.get(actualFloatArr);
Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr);
// Check bits are correct
ShortBuffer outputBuf = output.getShortBuffer();
- short[] expectedShortArr = new short[10 * 5];
+ short[] expectedShortArr = new short[dim1 * dim2];
shortBuf.get(expectedShortArr);
- short[] actualShortArr = new short[10 * 5];
+ short[] actualShortArr = new short[dim1 * dim2];
outputBuf.get(actualShortArr);
Assertions.assertArrayEquals(expectedShortArr, actualShortArr);
+
+ // Check outbound fp16 -> float[] conversion
+ float[][] floats = (float[][]) output.getValue();
+ for (int i = 0; i < dim1; i++) {
+ Assertions.assertArrayEquals(expectedOutput[i], floats[i]);
+ }
}
}
@@ -382,10 +473,12 @@ public void testFp32ToBf16() throws OrtException {
String modelPath = TestHelpers.getResourcePath("/java-fp32-to-bf16.onnx").toString();
SplittableRandom rng = new SplittableRandom(1);
- float[][] input = new float[10][5];
+ int dim1 = 10, dim2 = 5;
+ float[][] input = new float[dim1][dim2];
+ float[][] expectedOutput = new float[dim1][dim2];
FloatBuffer floatBuf =
- ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer();
- ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5);
+ ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer();
+ ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2);
// Generate data
for (int i = 0; i < input.length; i++) {
@@ -394,6 +487,8 @@ public void testFp32ToBf16() throws OrtException {
input[i][j] = Float.intBitsToFloat(bits);
floatBuf.put(input[i][j]);
shortBuf.put(Fp16Conversions.floatToBf16(input[i][j]));
+ expectedOutput[i][j] =
+ Fp16Conversions.bf16ToFloat(Fp16Conversions.floatToBf16(input[i][j]));
}
}
floatBuf.rewind();
@@ -401,25 +496,31 @@ public void testFp32ToBf16() throws OrtException {
try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
OrtSession session = env.createSession(modelPath, opts);
- OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5});
+ OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2});
OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) {
OnnxTensor output = (OnnxTensor) result.get(0);
// Check outbound Java side cast to fp32 works
FloatBuffer castOutput = output.getFloatBuffer();
- float[] expectedFloatArr = new float[10 * 5];
+ float[] expectedFloatArr = new float[dim1 * dim2];
Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr);
- float[] actualFloatArr = new float[10 * 5];
+ float[] actualFloatArr = new float[dim1 * dim2];
castOutput.get(actualFloatArr);
Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr);
// Check bits are correct
ShortBuffer outputBuf = output.getShortBuffer();
- short[] expectedShortArr = new short[10 * 5];
+ short[] expectedShortArr = new short[dim1 * dim2];
shortBuf.get(expectedShortArr);
- short[] actualShortArr = new short[10 * 5];
+ short[] actualShortArr = new short[dim1 * dim2];
outputBuf.get(actualShortArr);
Assertions.assertArrayEquals(expectedShortArr, actualShortArr);
+
+ // Check outbound bf16 -> float[] conversion
+ float[][] floats = (float[][]) output.getValue();
+ for (int i = 0; i < dim1; i++) {
+ Assertions.assertArrayEquals(expectedOutput[i], floats[i]);
+ }
}
}
diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts
index 52e028a9fcd31..cbc0270091818 100644
--- a/js/common/lib/tensor-factory-impl.ts
+++ b/js/common/lib/tensor-factory-impl.ts
@@ -11,6 +11,7 @@ import {
TensorFromImageBitmapOptions,
TensorFromImageDataOptions,
TensorFromImageElementOptions,
+ TensorFromMLTensorOptions,
TensorFromTextureOptions,
TensorFromUrlOptions,
} from './tensor-factory.js';
@@ -152,7 +153,7 @@ export const tensorFromImage = async (
}
};
const createCanvasContext = (canvas: HTMLCanvasElement | OffscreenCanvas) => {
- if (canvas instanceof HTMLCanvasElement) {
+ if (typeof HTMLCanvasElement !== 'undefined' && canvas instanceof HTMLCanvasElement) {
return canvas.getContext('2d');
} else if (canvas instanceof OffscreenCanvas) {
return canvas.getContext('2d') as OffscreenCanvasRenderingContext2D;
@@ -310,6 +311,17 @@ export const tensorFromGpuBuffer = (
+export const tensorFromMLTensor = <T extends TensorInterface.MLTensorDataTypes>(
+ mlTensor: TensorInterface.MLTensorType,
+ options: TensorFromMLTensorOptions<T>,
+): Tensor => {
+ const { dataType, dims, download, dispose } = options;
+ return new Tensor({ location: 'ml-tensor', type: dataType ?? 'float32', mlTensor, dims, download, dispose });
+};
+
/**
* implementation of Tensor.fromPinnedBuffer().
*/
diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts
index 7938b4a4eb927..f66684112623e 100644
--- a/js/common/lib/tensor-factory.ts
+++ b/js/common/lib/tensor-factory.ts
@@ -86,6 +86,20 @@ export interface GpuBufferConstructorParameters
+export interface MLTensorConstructorParameters<T extends Tensor.MLTensorDataTypes = Tensor.MLTensorDataTypes>
+ extends CommonConstructorParameters<T>,
+ GpuResourceConstructorParameters<T> {
+ /**
+ * Specify the location of the data to be 'ml-tensor'.
+ */
+ readonly location: 'ml-tensor';
+
+ /**
+ * Specify the WebNN MLTensor that holds the tensor data.
+ */
+ readonly mlTensor: Tensor.MLTensorType;
+}
+
// #endregion
// the following region contains type definitions of each individual options.
@@ -219,6 +233,15 @@ export interface TensorFromGpuBufferOptions
dataType?: T;
}
+export interface TensorFromMLTensorOptions<T extends Tensor.MLTensorDataTypes>
+ extends Pick<Tensor, 'dims'>,
+ GpuResourceConstructorParameters<T> {
+ /**
+ * Describes the data type of the tensor.
+ */
+ dataType?: T;
+}
+
// #endregion
/**
@@ -336,6 +359,29 @@ export interface TensorFactory {
options: TensorFromGpuBufferOptions<T>,
): TypedTensor<T>;
+ /**
+ * create a tensor from a WebNN MLTensor
+ *
+ * @param tensor - the MLTensor object to create tensor from
+ * @param options - An optional object representing options for creating tensor from a WebNN MLTensor.
+ *
+ * The options include following properties:
+ * - `dataType`: the data type of the tensor. If omitted, assume 'float32'.
+ * - `dims`: the dimension of the tensor. Required.
+ * - `download`: an optional function to download the tensor data from the MLTensor to CPU. If omitted, the MLTensor
+ * data will not be able to download. Usually, this is provided by the WebNN backend for the inference outputs.
+ * Users don't need to provide this function.
+ * - `dispose`: an optional function to dispose the tensor data on the WebNN MLTensor. If omitted, the MLTensor will
+ * not be disposed. Usually, this is provided by the WebNN backend for the inference outputs. Users don't need to
+ * provide this function.
+ *
+ * @returns a tensor object
+ */
+ fromMLTensor<T extends Tensor.MLTensorDataTypes>(
+ tensor: Tensor.MLTensorType,
+ options: TensorFromMLTensorOptions<T>,
+ ): TypedTensor<T>;
+
/**
* create a tensor from a pre-allocated buffer. The buffer will be used as a pinned buffer.
*
diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index 342f5e3a467eb..c0e1582c17de5 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -6,16 +6,19 @@ import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conve
import {
tensorFromGpuBuffer,
tensorFromImage,
+ tensorFromMLTensor,
tensorFromPinnedBuffer,
tensorFromTexture,
} from './tensor-factory-impl.js';
import {
CpuPinnedConstructorParameters,
GpuBufferConstructorParameters,
+ MLTensorConstructorParameters,
TensorFromGpuBufferOptions,
TensorFromImageBitmapOptions,
TensorFromImageDataOptions,
TensorFromImageElementOptions,
+ TensorFromMLTensorOptions,
TensorFromTextureOptions,
TensorFromUrlOptions,
TextureConstructorParameters,
@@ -37,6 +40,7 @@ type TensorDataType = TensorInterface.DataType;
type TensorDataLocation = TensorInterface.DataLocation;
type TensorTextureType = TensorInterface.TextureType;
type TensorGpuBufferType = TensorInterface.GpuBufferType;
+type TensorMLTensorType = TensorInterface.MLTensorType;
/**
* the implementation of Tensor interface.
@@ -86,6 +90,15 @@ export class Tensor implements TensorInterface {
*/
constructor(params: GpuBufferConstructorParameters);
+ /**
+ * Construct a new tensor object from the WebNN MLTensor with the given type and dims.
+ *
+ * Tensor's location will be set to 'ml-tensor'.
+ *
+ * @param params - Specify the parameters to construct the tensor.
+ */
+ constructor(params: MLTensorConstructorParameters);
+
/**
* implementation.
*/
@@ -98,7 +111,8 @@ export class Tensor implements TensorInterface {
| readonly boolean[]
| CpuPinnedConstructorParameters
| TextureConstructorParameters
- | GpuBufferConstructorParameters,
+ | GpuBufferConstructorParameters
+ | MLTensorConstructorParameters,
arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[],
arg2?: readonly number[],
) {
@@ -155,6 +169,25 @@ export class Tensor implements TensorInterface {
this.disposer = arg0.dispose;
break;
}
+ case 'ml-tensor': {
+ if (
+ type !== 'float32' &&
+ type !== 'float16' &&
+ type !== 'int32' &&
+ type !== 'int64' &&
+ type !== 'uint32' &&
+ type !== 'uint64' &&
+ type !== 'int8' &&
+ type !== 'uint8' &&
+ type !== 'bool'
+ ) {
+ throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`);
+ }
+ this.mlTensorData = arg0.mlTensor;
+ this.downloader = arg0.download;
+ this.disposer = arg0.dispose;
+ break;
+ }
default:
throw new Error(`Tensor constructor: unsupported location '${this.dataLocation}'`);
}
@@ -325,6 +358,13 @@ export class Tensor implements TensorInterface {
return tensorFromGpuBuffer(gpuBuffer, options);
}
+ static fromMLTensor<T extends TensorInterface.MLTensorDataTypes>(
+ mlTensor: TensorMLTensorType,
+ options: TensorFromMLTensorOptions<T>,
+ ): TensorInterface {
+ return tensorFromMLTensor(mlTensor, options);
+ }
+
static fromPinnedBuffer(
type: T,
buffer: TensorInterface.DataTypeMap[T],
@@ -373,6 +413,11 @@ export class Tensor implements TensorInterface {
*/
private gpuBufferData?: TensorGpuBufferType;
+ /**
+ * stores the underlying WebNN MLTensor when location is 'ml-tensor'. otherwise empty.
+ */
+ private mlTensorData?: TensorMLTensorType;
+
/**
* stores an optional downloader function to download data from GPU to CPU.
*/
@@ -420,6 +465,14 @@ export class Tensor implements TensorInterface {
}
return this.gpuBufferData;
}
+
+ get mlTensor(): TensorMLTensorType {
+ this.ensureValid();
+ if (!this.mlTensorData) {
+ throw new Error('The data is not stored as a WebNN MLTensor.');
+ }
+ return this.mlTensorData;
+ }
// #endregion
// #region methods
@@ -431,7 +484,8 @@ export class Tensor implements TensorInterface {
case 'cpu-pinned':
return this.data;
case 'texture':
- case 'gpu-buffer': {
+ case 'gpu-buffer':
+ case 'ml-tensor': {
if (!this.downloader) {
throw new Error('The current tensor is not created with a specified data downloader.');
}
@@ -472,6 +526,7 @@ export class Tensor implements TensorInterface {
this.cpuData = undefined;
this.gpuTextureData = undefined;
this.gpuBufferData = undefined;
+ this.mlTensorData = undefined;
this.downloader = undefined;
this.isDownloading = undefined;
diff --git a/js/common/lib/tensor-utils-impl.ts b/js/common/lib/tensor-utils-impl.ts
index 9c633cd95fac3..97b1735e6eac5 100644
--- a/js/common/lib/tensor-utils-impl.ts
+++ b/js/common/lib/tensor-utils-impl.ts
@@ -4,6 +4,7 @@
import {
CpuPinnedConstructorParameters,
GpuBufferConstructorParameters,
+ MLTensorConstructorParameters,
TextureConstructorParameters,
} from './tensor-factory.js';
import { Tensor } from './tensor-impl.js';
@@ -56,6 +57,13 @@ export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor =
type: tensor.type as GpuBufferConstructorParameters['type'],
dims,
});
+ case 'ml-tensor':
+ return new Tensor({
+ location: 'ml-tensor',
+ mlTensor: tensor.mlTensor,
+ type: tensor.type as MLTensorConstructorParameters['type'],
+ dims,
+ });
default:
throw new Error(`tensorReshape: tensor location ${tensor.location} is not supported`);
}
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index 8a1197994393b..17e2f4d37c91f 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -42,6 +42,13 @@ interface TypedTensorBase {
*/
readonly gpuBuffer: Tensor.GpuBufferType;
+ /**
+ * Get the WebNN MLTensor that holds the tensor data.
+ *
+ * If the data is not in a WebNN MLTensor, throw error.
+ */
+ readonly mlTensor: Tensor.MLTensorType;
+
/**
* Get the buffer data of the tensor.
*
@@ -136,15 +143,36 @@ export declare namespace Tensor {
*/
export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
+ /**
+ * type alias for WebNN MLTensor
+ *
+ * The specification for WebNN's MLTensor is currently in flux.
+ */
+ export type MLTensorType = unknown;
+
/**
* supported data types for constructing a tensor from a WebGPU buffer
*/
export type GpuBufferDataTypes = 'float32' | 'float16' | 'int32' | 'int64' | 'uint32' | 'uint8' | 'bool';
+ /**
+ * supported data types for constructing a tensor from a WebNN MLTensor
+ */
+ export type MLTensorDataTypes =
+ | 'float32'
+ | 'float16'
+ | 'int8'
+ | 'uint8'
+ | 'int32'
+ | 'uint32'
+ | 'int64'
+ | 'uint64'
+ | 'bool';
+
/**
* represent where the tensor data is stored
*/
- export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer';
+ export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer' | 'ml-tensor';
/**
* represent the data type of a tensor
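For context, a minimal usage sketch of the new factory and data types above. This is not part of the diff; the MLContext and MLTensor values, and the helper name, are assumptions, with the WebNN types coming from the webnn.d.ts declarations later in this patch.

import { Tensor } from 'onnxruntime-common';

// Hypothetical helper: wrap a pre-allocated WebNN MLTensor as an ONNX Runtime tensor.
// `context` and `mlTensor` are assumed to have been created elsewhere via
// navigator.ml.createContext() and context.createTensor() respectively.
const wrapMLTensor = (context: MLContext, mlTensor: MLTensor, dims: readonly number[]) =>
  Tensor.fromMLTensor(mlTensor, {
    dataType: 'float32',
    dims,
    // Called when tensor.getData() needs to bring the data back to the CPU.
    download: async () => new Float32Array(await context.readTensor(mlTensor)),
    // Called when tensor.dispose() is invoked; the wrapper does not otherwise own the MLTensor.
    dispose: () => mlTensor.destroy(),
  });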
diff --git a/js/react_native/e2e/.detoxrc.js b/js/react_native/e2e/.detoxrc.js
index 0792c3d528585..e886a363d378b 100644
--- a/js/react_native/e2e/.detoxrc.js
+++ b/js/react_native/e2e/.detoxrc.js
@@ -6,7 +6,7 @@ module.exports = {
config: 'test/jest.config.js',
},
jest: {
- setupTimeout: 120000,
+ setupTimeout: 240000,
},
},
apps: {
diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md
index 6fd4f9af20432..6c50f3752737b 100644
--- a/js/web/docs/webnn-operators.md
+++ b/js/web/docs/webnn-operators.md
@@ -53,6 +53,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | |
| Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | |
| LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
+| LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
| Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |
| MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output |
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 39f8c2a6d0db3..bfb74355b0d70 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -785,15 +785,20 @@ export class WebGpuBackend {
this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping);
}
+ // The buffer may be user created or come from a previous run. Either way, the GPU data manager does not manage
+ // the lifecycle of these buffers; we register them here as external buffers.
+ //
+ // The map `sessionInputOutputMapping` stores the data ID and buffer for each input/output. Once a specific
+ // input/output is registered, its data ID will not change.
const previousBuffer = sessionInputOutputMapping.get(index);
- const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]);
+ const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer);
sessionInputOutputMapping.set(index, [id, buffer]);
return id;
}
unregisterBuffers(sessionId: number): void {
const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId);
if (sessionInputOutputMapping) {
- sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1]));
+ sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[0]));
this.sessionExternalDataMapping.delete(sessionId);
}
}
diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
new file mode 100644
index 0000000000000..685f3dc019461
--- /dev/null
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -0,0 +1,169 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+///
+
+import { Env, Tensor } from 'onnxruntime-common';
+
+import { DataType } from '../wasm-common';
+import { getInstance } from '../wasm-factory';
+
+import { createView } from './tensor-view';
+import { TensorId, createTensorManager } from './webnn/tensor-manager';
+import { configureLogger, LOG_DEBUG } from './log';
+
+/*
+ * TensorProto::data_type to WebNN OperandType mapping.
+ */
+const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
+ [DataType.float, 'float32'],
+ [DataType.float16, 'float16'],
+ [DataType.int32, 'int32'],
+ [DataType.uint32, 'uint32'],
+ [DataType.int64, 'int64'],
+ [DataType.uint64, 'uint64'],
+ [DataType.int8, 'int8'],
+ [DataType.uint8, 'uint8'],
+ [DataType.bool, 'uint8'],
+]);
+
+/**
+ * WebNN backend implementation. This class keeps track of the MLTensors created by the backend and of the current
+ * MLContext being used by the sessions.
+ */
+export class WebNNBackend {
+ /**
+ * Tensor manager that tracks the MLTensors for all sessions handled by this backend.
+ */
+ private tensorManager = createTensorManager(this);
+ /**
+ * Maps from session id to MLContexts.
+ */
+ private mlContextBySessionId = new Map<number, MLContext>();
+ /**
+ * Maps from MLContext to session ids.
+ */
+ private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
+ /**
+ * Current session id.
+ */
+ private activeSessionId?: number;
+
+ constructor(env: Env) {
+ configureLogger(env.logLevel!, !!env.debug);
+ }
+
+ public get currentSessionId(): number {
+ if (this.activeSessionId === undefined) {
+ throw new Error('No active session');
+ }
+ return this.activeSessionId;
+ }
+
+ public onRunStart(sessionId: number): void {
+ this.activeSessionId = sessionId;
+ }
+
+ public get currentContext(): MLContext {
+ const mlContext = this.getMLContext(this.currentSessionId);
+ if (!mlContext) {
+ throw new Error(`No MLContext found for session ${this.currentSessionId}`);
+ }
+ return mlContext;
+ }
+
+ public registerMLContext(sessionId: number, mlContext: MLContext): void {
+ this.mlContextBySessionId.set(sessionId, mlContext);
+ let sessionIds = this.sessionIdsByMLContext.get(mlContext);
+ if (!sessionIds) {
+ sessionIds = new Set();
+ this.sessionIdsByMLContext.set(mlContext, sessionIds);
+ }
+ sessionIds.add(sessionId);
+ }
+
+ public onReleaseSession(sessionId: number): void {
+ const mlContext = this.mlContextBySessionId.get(sessionId)!;
+ if (!mlContext) {
+ // Current session is not a WebNN session.
+ return;
+ }
+ this.mlContextBySessionId.delete(sessionId);
+ const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
+ sessionIds.delete(sessionId);
+ if (sessionIds.size === 0) {
+ this.sessionIdsByMLContext.delete(mlContext);
+ this.tensorManager.releaseTensorsForContext(mlContext);
+ }
+ }
+
+ public getMLContext(sessionId: number): MLContext | undefined {
+ return this.mlContextBySessionId.get(sessionId);
+ }
+
+ public reserveTensorId(): TensorId {
+ return this.tensorManager.reserveTensorId();
+ }
+
+ public releaseTensorId(tensorId: TensorId): void {
+ LOG_DEBUG('verbose', () => `[WebNN] releaseTensorId {tensorId: ${tensorId}}`);
+ this.tensorManager.releaseTensorId(tensorId);
+ }
+
+ public async ensureTensor(
+ tensorId: TensorId,
+ onnxDataType: DataType,
+ dimensions: number[],
+ copyOld: boolean,
+ ): Promise<MLTensor> {
+ const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
+ if (!webnnDataType) {
+ throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
+ }
+ return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld);
+ }
+
+ public uploadTensor(tensorId: TensorId, data: Uint8Array): void {
+ const wasm = getInstance();
+ if (!wasm.shouldTransferToMLTensor) {
+ throw new Error('Trying to upload to an MLTensor while shouldTransferToMLTensor is false');
+ }
+ LOG_DEBUG('verbose', () => `[WebNN] uploadTensor {tensorId: ${tensorId}, data: ${data.byteLength}}`);
+ this.tensorManager.upload(tensorId, data);
+ }
+
+ public async downloadTensor(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined> {
+ return this.tensorManager.download(tensorId, dstBuffer);
+ }
+
+ public createMLTensorDownloader(tensorId: TensorId, type: Tensor.MLTensorDataTypes): () => Promise<Tensor.DataTypeMap[Tensor.MLTensorDataTypes]> {
+ return async () => {
+ const data = await this.tensorManager.download(tensorId);
+ return createView(data, type);
+ };
+ }
+
+ public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
+ const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
+ if (!webnnDataType) {
+ throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
+ }
+
+ const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions);
+ LOG_DEBUG(
+ 'verbose',
+ () =>
+ `[WebNN] registerMLTensor {tensor: ${tensor}, dataType: ${webnnDataType}, dimensions: ${
+ dimensions
+ }} -> {tensorId: ${id}}`,
+ );
+ return id;
+ }
+
+ public flush(): void {
+ // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
+ }
+}
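A rough sketch of the tensor lifecycle these methods implement, as seen from the glue code that drives the backend. This is assumed usage, not part of the diff; the shape, element type and session id are illustrative, and the import paths assume the sketch sits next to backend-webnn.ts in js/web/lib/wasm/jsep/.

import { DataType } from '../wasm-common'; // numeric ONNX element types (DataType.float === 1)
import { WebNNBackend } from './backend-webnn';

// Sketch only: `backend` is the WebNNBackend above and `mlContext` an MLContext created by the session code.
async function roundTrip(backend: WebNNBackend, mlContext: MLContext, sessionId: number, data: Uint8Array) {
  backend.registerMLContext(sessionId, mlContext); // associate the session with its MLContext
  backend.onRunStart(sessionId); // make it the active session so currentContext resolves

  const tensorId = backend.reserveTensorId(); // the ID exists before the MLTensor does
  // The MLTensor is created lazily once the data type and shape are known; float32, so 4 bytes per element.
  await backend.ensureTensor(tensorId, DataType.float, [1, data.length / 4], /* copyOld */ false);
  backend.uploadTensor(tensorId, data); // requires Module.shouldTransferToMLTensor to be true

  const download = backend.createMLTensorDownloader(tensorId, 'float32');
  const values = await download(); // reads the MLTensor back as a Float32Array
  backend.releaseTensorId(tensorId); // destroys the underlying MLTensor once no longer referenced
  return values;
}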
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 2f0e5da2b3f27..7bce5ff9390e8 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -12,6 +12,7 @@ import { LOG_DEBUG } from './log';
import { TensorView } from './tensor-view';
import { ShapeUtil } from './util';
import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types';
+import { WebNNBackend } from './backend-webnn';
/* eslint-disable no-bitwise */
@@ -266,6 +267,22 @@ export const init = async (
() => backend.replay(),
]);
} else {
- jsepInit('webnn');
+ const backend = new WebNNBackend(env);
+ jsepInit('webnn', [
+ backend,
+ // jsepReserveTensorId
+ () => backend.reserveTensorId(),
+ // jsepReleaseTensorId,
+ (tensorId: number) => backend.releaseTensorId(tensorId),
+ // jsepEnsureTensor
+ async (tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
+ backend.ensureTensor(tensorId, onnxDataType, shape, copyOld),
+ // jsepUploadTensor
+ (tensorId: number, data: Uint8Array) => {
+ backend.uploadTensor(tensorId, data);
+ },
+ // jsepDownloadTensor
+ async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer),
+ ]);
}
};
diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
index 8e18a28acc364..33e8c95c141ee 100644
--- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
@@ -52,12 +52,12 @@ export interface GpuDataManager {
* GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of
* the external buffer.
*/
- registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number;
+ registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number;
/**
* unregister an external buffer for IO Binding.
*/
- unregisterExternalBuffer(buffer: GPUBuffer): void;
+ unregisterExternalBuffer(id: GpuDataId): void;
/**
* destroy all gpu buffers.
@@ -196,9 +196,6 @@ class GpuDataManagerImpl implements GpuDataManager {
// The reusable uniform buffers
private freeUniformBuffers: Map;
- // The external buffers registered users for IO Binding.
- private externalBuffers: Map;
-
// The pendingBuffers for capture graph.
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map;
@@ -209,7 +206,6 @@ class GpuDataManagerImpl implements GpuDataManager {
this.freeUniformBuffers = new Map();
this.buffersForUploadingPending = [];
this.buffersPending = [];
- this.externalBuffers = new Map();
this.capturedPendingBuffers = new Map();
for (const [key] of bucketFreelist) {
@@ -284,14 +280,11 @@ class GpuDataManagerImpl implements GpuDataManager {
);
}
- registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number {
+ registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {
let id: number | undefined;
- if (previousBuffer) {
- id = this.externalBuffers.get(previousBuffer);
- if (id === undefined) {
- throw new Error('previous buffer is not registered');
- }
- if (buffer === previousBuffer) {
+ if (previous) {
+ id = previous[0];
+ if (buffer === previous[1]) {
LOG_DEBUG(
'verbose',
() =>
@@ -304,13 +297,11 @@ class GpuDataManagerImpl implements GpuDataManager {
throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
Please use the previous external buffer!`);
}
- this.externalBuffers.delete(previousBuffer);
} else {
id = createNewGpuDataId();
}
this.storageCache.set(id, { gpuData: { id, type: GpuDataType.default, buffer }, originalSize });
- this.externalBuffers.set(buffer, id);
LOG_DEBUG(
'verbose',
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`,
@@ -318,11 +309,9 @@ class GpuDataManagerImpl implements GpuDataManager {
return id;
}
- unregisterExternalBuffer(buffer: GPUBuffer): void {
- const id = this.externalBuffers.get(buffer);
+ unregisterExternalBuffer(id: GpuDataId): void {
if (id !== undefined) {
this.storageCache.delete(id);
- this.externalBuffers.delete(buffer);
LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`);
}
}
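To illustrate the reworked bookkeeping, a sketch under assumptions (not part of the diff): the caller now carries the [id, buffer] pair forward, and the data ID stays stable for a given input/output slot as long as the same GPUBuffer is passed again.

// Sketch only: `manager` is the GpuDataManager and `buffer` a user-owned GPUBuffer of `size` bytes.
const id = manager.registerExternalBuffer(buffer, size); // first registration: a new data ID is created
const sameId = manager.registerExternalBuffer(buffer, size, [id, buffer]); // same buffer again: ID is reused
// sameId === id. Passing a *different* buffer together with [id, previousBuffer] keeps the ID but rebinds it
// to the new buffer (and is rejected under graph capture mode).
manager.unregisterExternalBuffer(id); // removes the cache entry only; the GPUBuffer itself is not destroyed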
diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
new file mode 100644
index 0000000000000..9475de019ed1d
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -0,0 +1,303 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import { WebNNBackend } from '../backend-webnn';
+import { LOG_DEBUG } from '../log';
+
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+///
+
+export type TensorId = number;
+
+/**
+ * Manages TensorId to MLTensor mapping.
+ */
+export interface TensorManager {
+ /**
+ * Reserve a new TensorId.
+ */
+ reserveTensorId(): TensorId;
+ /**
+ * Release a TensorId.
+ */
+ releaseTensorId(tensorId: TensorId): void;
+ /**
+ * Ensure a MLTensor is created for the TensorId.
+ */
+ ensureTensor(
+ tensorId: TensorId,
+ dataType: MLOperandDataType,
+ shape: readonly number[],
+ copyOld: boolean,
+ ): Promise<MLTensor>;
+ /**
+ * Upload data to a MLTensor.
+ */
+ upload(tensorId: TensorId, data: Uint8Array): void;
+ /**
+ * Download data from a MLTensor.
+ */
+ download(tensorId: TensorId): Promise<ArrayBuffer>;
+ download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise<undefined>;
+ /**
+ * Release all tensors for a MLContext.
+ */
+ releaseTensorsForContext(mlContext: MLContext): void;
+ /**
+ * Register an externally created MLTensor with a given MLContext and return a TensorId.
+ */
+ registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
+}
+
+let tensorGuid = 1;
+const createNewTensorId = (): TensorId => tensorGuid++;
+
+export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]];
+
+/**
+ * TensorTracker tracks the MLTensor and pending upload data.
+ *
+ * We need to track the MLTensor and any pending upload data because creation of the MLTensor is deferred until
+ * the data type and shape are known. This is because future implementations of WebNN will only support creating
+ * MLTensors with a known data type and shape.
+ */
+class TensorTracker {
+ private tensorEntry?: MLTensorEntry;
+ private activeUpload?: Uint8Array;
+ private tensorCache: MLTensorEntry[];
+
+ constructor(
+ private mlContext?: MLContext,
+ tensorEntry?: MLTensorEntry,
+ ) {
+ this.tensorEntry = tensorEntry;
+ this.tensorCache = tensorEntry ? [tensorEntry] : [];
+ }
+
+ public get tensor(): MLTensor | undefined {
+ return this.tensorEntry?.[0];
+ }
+
+ public get context(): MLContext {
+ if (!this.mlContext) {
+ throw new Error('MLContext has not been set.');
+ }
+ return this.mlContext;
+ }
+
+ public set context(mlContext: MLContext) {
+ if (this.mlContext && this.mlContext !== mlContext) {
+ throw new Error('MLTensor in use in a different MLContext.');
+ }
+ this.mlContext = mlContext;
+ }
+
+ public destroy(): void {
+ for (const [mlTensor] of this.tensorCache) {
+ mlTensor.destroy();
+ }
+ this.tensorCache = [];
+ this.tensorEntry = undefined;
+ }
+
+ public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean {
+ for (const [mlTensor, dataType, shape] of this.tensorCache) {
+ if (tryMLTensor === mlTensor) {
+ if (this.context !== context) {
+ throw new Error('MLTensor cannot be registered with a different MLContext.');
+ }
+ this.tensorEntry = [mlTensor, dataType, shape];
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public async ensureTensor(
+ dataType: MLOperandDataType,
+ shape: readonly number[],
+ copyOld: boolean,
+ ): Promise<MLTensor> {
+ if (this.tensorEntry) {
+ const [mlTensor, existingDataType, existingShape] = this.tensorEntry;
+ if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
+ return mlTensor;
+ }
+ }
+
+ for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) {
+ if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
+ if (copyOld && this.tensorEntry) {
+ // WebNN does not support copyTensorToTensor, so we need to read and write the tensors.
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`,
+ );
+ const data = await this.context.readTensor(this.tensorEntry[0]);
+ this.context.writeTensor(mlTensor, data);
+ }
+ this.tensorEntry = [mlTensor, existingDataType, existingShape];
+ return mlTensor;
+ }
+ }
+ LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
+ // eslint-disable-next-line no-bitwise
+ const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
+ const tensor = await this.context.createTensor({
+ dataType,
+ shape,
+ // Assign both shape and dimensions while transitioning to new API.
+ dimensions: shape,
+ usage,
+ });
+ this.tensorEntry = [tensor, dataType, shape];
+ this.tensorCache.push(this.tensorEntry);
+
+ if (this.activeUpload) {
+ this.mlContext?.writeTensor(tensor, this.activeUpload);
+ this.activeUpload = undefined;
+ }
+
+ return tensor;
+ }
+
+ public upload(data: Uint8Array): void {
+ if (!this.tensorEntry) {
+ this.activeUpload = new Uint8Array(data);
+ return;
+ }
+ this.mlContext?.writeTensor(this.tensorEntry[0], data);
+ }
+
+ public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
+ if (this.activeUpload) {
+ if (dstBuffer) {
+ if (dstBuffer instanceof ArrayBuffer) {
+ new Uint8Array(dstBuffer).set(this.activeUpload);
+ } else {
+ new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload);
+ }
+
+ return;
+ } else {
+ return this.activeUpload.buffer;
+ }
+ }
+ if (!this.tensorEntry) {
+ throw new Error('Tensor has not been created.');
+ }
+ if (dstBuffer) {
+ return this.context.readTensor(this.tensorEntry[0], dstBuffer);
+ }
+ return this.context.readTensor(this.tensorEntry[0]);
+ }
+}
+
+class TensorManagerImpl implements TensorManager {
+ private tensorsById = new Map<TensorId, TensorTracker>();
+ private tensorIdsByContext = new Map<MLContext, Set<TensorId>>();
+
+ constructor(private backend: WebNNBackend) {}
+
+ public reserveTensorId(): TensorId {
+ const tensorId = createNewTensorId();
+ this.tensorsById.set(tensorId, new TensorTracker());
+ return tensorId;
+ }
+
+ public releaseTensorId(tensorId: TensorId): void {
+ const tensorTracker = this.tensorsById.get(tensorId);
+ if (!tensorTracker) {
+ return;
+ }
+ tensorTracker.destroy();
+ this.tensorsById.delete(tensorId);
+ for (const [mlContext, tensors] of this.tensorIdsByContext) {
+ if (tensors.has(tensorId)) {
+ tensors.delete(tensorId);
+ if (tensors.size === 0) {
+ this.tensorIdsByContext.delete(mlContext);
+ }
+ break;
+ }
+ }
+ }
+
+ public async ensureTensor(
+ tensorId: TensorId,
+ dataType: MLOperandDataType,
+ shape: number[],
+ copyOld: boolean,
+ ): Promise<MLTensor> {
+ LOG_DEBUG(
+ 'verbose',
+ () =>
+ `[WebNN] TensorManager.ensureTensor {tensorId: ${tensorId}, dataType: ${
+ dataType
+ }, shape: ${shape}, copyOld: ${copyOld}}`,
+ );
+ const tensor = this.tensorsById.get(tensorId);
+ if (!tensor) {
+ throw new Error('Tensor not found.');
+ }
+ tensor.context = this.backend.currentContext;
+ if (!this.tensorIdsByContext.has(this.backend.currentContext)) {
+ this.tensorIdsByContext.set(this.backend.currentContext, new Set());
+ }
+ this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId);
+ return tensor.ensureTensor(dataType, shape, copyOld);
+ }
+
+ public upload(tensorId: TensorId, data: Uint8Array): void {
+ this.tensorsById.get(tensorId)!.upload(data);
+ }
+
+ public async download(tensorId: TensorId): Promise<ArrayBuffer>;
+ public async download(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined>;
+ async download(tensorId: TensorId, dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`,
+ );
+ return this.tensorsById.get(tensorId)!.download(dstBuffer);
+ }
+
+ public releaseTensorsForContext(mlContext: MLContext): void {
+ const tensors = this.tensorIdsByContext.get(mlContext);
+ if (!tensors) {
+ return;
+ }
+ for (const tensorId of tensors) {
+ this.tensorsById.get(tensorId)!.destroy();
+ this.tensorsById.delete(tensorId);
+ }
+ this.tensorIdsByContext.delete(mlContext);
+ }
+
+ public registerTensor(
+ mlContext: MLContext,
+ mlTensor: MLTensor,
+ dataType: MLOperandDataType,
+ shape: readonly number[],
+ ): TensorId {
+ for (const [tensorId, tensorTracker] of this.tensorsById) {
+ if (tensorTracker.trySelectTensor(mlContext, mlTensor)) {
+ return tensorId;
+ }
+ }
+ const tensorId = createNewTensorId();
+ this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape]));
+ let tensors = this.tensorIdsByContext.get(mlContext);
+ if (!tensors) {
+ tensors = new Set();
+ this.tensorIdsByContext.set(mlContext, tensors);
+ }
+ tensors.add(tensorId);
+ return tensorId;
+ }
+}
+
+export const createTensorManager = (...args: ConstructorParameters<typeof TensorManagerImpl>): TensorManager =>
+ new TensorManagerImpl(...args);
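A small sketch of the registration path for user-provided MLTensors (assumed usage, not part of the diff): registering the same MLTensor twice under the same MLContext resolves to the same TensorId through TensorTracker.trySelectTensor.

// Sketch only: `backend` is a WebNNBackend, `context` an MLContext and `userTensor` an MLTensor owned by the app.
const manager = createTensorManager(backend);
const idA = manager.registerTensor(context, userTensor, 'float32', [1, 4]);
const idB = manager.registerTensor(context, userTensor, 'float32', [1, 4]);
// idA === idB: the tracker that already holds `userTensor` is selected instead of creating a new entry.
manager.releaseTensorId(idA); // note: this destroys `userTensor` along with any cached MLTensors for the ID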
diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
index f8a1e1966fd4c..5cb0f4e74c3df 100644
--- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts
+++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+/* eslint-disable @typescript-eslint/naming-convention */
+
interface NavigatorML {
readonly ml: ML;
}
@@ -30,7 +32,9 @@ type MLInputOperandLayout = 'nchw'|'nhwc';
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';
interface MLOperandDescriptor {
dataType: MLOperandDataType;
- dimensions?: number[];
+ shape?: readonly number[];
+ /** @deprecated Use shape instead of dimensions */
+ dimensions?: readonly number[];
}
interface MLOperand {
dataType(): MLOperandDataType;
@@ -379,23 +383,32 @@ interface MLGraphBuilder {
where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand;
}
-// Experimental MLBuffer interface
+// Experimental MLTensor interface
-type MLSize64Out = number;
-interface MLBuffer {
- readonly size: MLSize64Out;
+interface MLTensor {
destroy(): void;
}
-type MLSize64 = number;
-interface MLBufferDescriptor {
- size: MLSize64;
+
+type MLNamedTensor = Record<string, MLTensor>;
+
+type MLTensorUsageFlags = number;
+
+declare const MLTensorUsage: {
+ readonly WEBGPU_INTEROP: MLTensorUsageFlags;
+ readonly READ: MLTensorUsageFlags;
+ readonly WRITE: MLTensorUsageFlags;
+};
+
+interface MLTensorDescriptor extends MLOperandDescriptor {
+ usage: MLTensorUsageFlags;
}
-type MLNamedBuffers = Record;
+
interface MLContext {
- createBuffer(descriptor: MLBufferDescriptor): MLBuffer;
- writeBuffer(
- dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64,
- srcElementSize?: MLSize64): void;
- readBuffer(srcBuffer: MLBuffer): Promise;
- dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void;
+ createTensor(descriptor: MLTensorDescriptor): Promise<MLTensor>;
+ writeTensor(
+ destinationTensor: MLTensor, sourceData: ArrayBufferView|ArrayBuffer, sourceElementOffset?: number,
+ sourceElementSize?: number): void;
+ readTensor(sourceTensor: MLTensor): Promise<ArrayBuffer>;
+ readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise<undefined>;
+ dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void;
}
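As a usage sketch of the MLTensor surface declared above (illustrative only; the WebNN API is still in flux, and the context option deviceType follows the test code later in this patch):

const context = await navigator.ml.createContext({ deviceType: 'gpu' });
const tensor = await context.createTensor({
  dataType: 'float32',
  shape: [2, 2],
  dimensions: [2, 2], // kept alongside `shape` while transitioning to the new API
  usage: MLTensorUsage.READ | MLTensorUsage.WRITE,
});
context.writeTensor(tensor, new Float32Array([1, 2, 3, 4]));
const bytes = await context.readTensor(tensor); // resolves to an ArrayBuffer with the written data
console.log(new Float32Array(bytes)); // [1, 2, 3, 4]
tensor.destroy();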
diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts
index 8f3acdd582445..559f319a10f66 100644
--- a/js/web/lib/wasm/proxy-messages.ts
+++ b/js/web/lib/wasm/proxy-messages.ts
@@ -19,11 +19,18 @@ export type GpuBufferMetadata = {
dispose?: () => void;
};
+export type MLTensorMetadata = {
+ mlTensor: Tensor.MLTensorType;
+ download?: () => Promise<Tensor.DataTypeMap[Tensor.MLTensorDataTypes]>;
+ dispose?: () => void;
+};
+
/**
- * Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable.
+ * Tensors on location 'cpu-pinned', 'gpu-buffer', and 'ml-tensor' are not serializable.
*/
export type UnserializableTensorMetadata =
| [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']
+ | [dataType: Tensor.Type, dims: readonly number[], data: MLTensorMetadata, location: 'ml-tensor']
| [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];
/**
@@ -34,6 +41,7 @@ export type UnserializableTensorMetadata =
* - cpu: Uint8Array
* - cpu-pinned: Uint8Array
* - gpu-buffer: GpuBufferMetadata
+ * - ml-tensor: MLTensorMetadata
* - location: tensor data location
*/
export type TensorMetadata = SerializableTensorMetadata | UnserializableTensorMetadata;
diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts
index eff3e91389c98..c19043cc3637f 100644
--- a/js/web/lib/wasm/session-handler-inference.ts
+++ b/js/web/lib/wasm/session-handler-inference.ts
@@ -12,7 +12,7 @@ import {
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper';
-import { isGpuBufferSupportedType } from './wasm-common';
+import { isGpuBufferSupportedType, isMLTensorSupportedType } from './wasm-common';
import { isNode } from './wasm-utils-env';
import { loadFile } from './wasm-utils-load-file';
@@ -22,6 +22,8 @@ export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): Ten
return [tensor.type, tensor.dims, tensor.data, 'cpu'];
case 'gpu-buffer':
return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer'];
+ case 'ml-tensor':
+ return [tensor.type, tensor.dims, { mlTensor: tensor.mlTensor }, 'ml-tensor'];
default:
throw new Error(`invalid data location: ${tensor.location} for ${getName()}`);
}
@@ -39,6 +41,14 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
const { gpuBuffer, download, dispose } = tensor[2];
return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose });
}
+ case 'ml-tensor': {
+ const dataType = tensor[0];
+ if (!isMLTensorSupportedType(dataType)) {
+ throw new Error(`not supported data type: ${dataType} for deserializing MLTensor tensor`);
+ }
+ const { mlTensor, download, dispose } = tensor[2];
+ return Tensor.fromMLTensor(mlTensor, { dataType, dims: tensor[1], download, dispose });
+ }
default:
throw new Error(`invalid data location: ${tensor[3]}`);
}
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts
index 78ff14540d8cb..ad2ff62587252 100644
--- a/js/web/lib/wasm/wasm-common.ts
+++ b/js/web/lib/wasm/wasm-common.ts
@@ -240,6 +240,20 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB
type === 'uint4' ||
type === 'int4';
+/**
+ * Check whether the given tensor type is supported by WebNN MLTensor
+ */
+export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTensorDataTypes =>
+ type === 'float32' ||
+ type === 'float16' ||
+ type === 'int32' ||
+ type === 'int64' ||
+ type === 'uint32' ||
+ type === 'uint64' ||
+ type === 'int8' ||
+ type === 'uint8' ||
+ type === 'bool';
+
/**
* Map string data location to integer value
*/
@@ -255,6 +269,8 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number
return 3;
case 'gpu-buffer':
return 4;
+ case 'ml-tensor':
+ return 5;
default:
throw new Error(`unsupported data location: ${location}`);
}
@@ -264,4 +280,4 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number
* Map integer data location to string value
*/
export const dataLocationEnumToString = (location: number): Tensor.DataLocation | undefined =>
- (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location];
+ (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer', 'ml-tensor'] as const)[location];
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index ed001cfa90f59..0668ac1931988 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -20,6 +20,7 @@ import {
calculateTensorSizeInBytes,
dataLocationStringToEnum,
isGpuBufferSupportedType,
+ isMLTensorSupportedType,
logLevelStringToEnum,
tensorDataTypeEnumToString,
tensorDataTypeStringToEnum,
@@ -162,7 +163,7 @@ export const initEp = async (env: Env, epName: string): Promise => {
/**
* valid data locations for input/output tensors.
*/
-type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer';
+type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor';
type IOBindingState = {
/**
@@ -173,7 +174,7 @@ type IOBindingState = {
/**
* the preferred location for each output tensor.
*
- * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'.
+ * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-tensor'.
*/
readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];
@@ -287,6 +288,7 @@ export const createSession = async (
for (const provider of options?.executionProviders ?? []) {
const providerName = typeof provider === 'string' ? provider : provider.name;
if (providerName === 'webnn') {
+ wasm.shouldTransferToMLTensor = false;
if (wasm.currentContext) {
throw new Error('WebNN execution provider is already set.');
}
@@ -318,7 +320,9 @@ export const createSession = async (
// clear current MLContext after session creation
if (wasm.currentContext) {
+ wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
wasm.currentContext = undefined;
+ wasm.shouldTransferToMLTensor = true;
}
const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
@@ -354,7 +358,7 @@ export const createSession = async (
typeof options?.preferredOutputLocation === 'string'
? options.preferredOutputLocation
: (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
- if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
+ if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
if (enableGraphCapture && location !== 'gpu-buffer') {
@@ -366,9 +370,9 @@ export const createSession = async (
}
}
- // use IO binding only when at least one output is preffered to be on GPU.
+ // use IO binding only when at least one output is preferred to be on GPU.
let bindingState: IOBindingState | null = null;
- if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer')) {
+ if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor')) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError("Can't create IO binding.");
@@ -459,7 +463,7 @@ export const prepareInputOutputTensor = (
let rawData: number;
let dataByteLength: number;
- if (dataType === 'string' && location === 'gpu-buffer') {
+ if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-tensor')) {
throw new Error('String tensor is not supported on GPU.');
}
@@ -478,6 +482,15 @@ export const prepareInputOutputTensor = (
throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
}
rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
+ } else if (location === 'ml-tensor') {
+ const mlTensor = tensor[2].mlTensor as MLTensor;
+ dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
+
+ const registerMLTensor = wasm.jsepRegisterMLTensor;
+ if (!registerMLTensor) {
+ throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
+ }
+ rawData = registerMLTensor(mlTensor, tensorDataTypeStringToEnum(dataType), dims);
} else {
const data = tensor[2];
@@ -563,6 +576,9 @@ export const run = async (
const outputNamesOffset = wasm.stackAlloc(outputCount * 4);
try {
+ // WebNN backend needs the active session to check MLTensors with the current context.
+ wasm.jsepOnRunStart?.(sessionHandle);
+
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
// create input tensors
@@ -654,7 +670,6 @@ export const run = async (
]);
}
- wasm.jsepOnRunStart?.(sessionHandle);
let errorCode: number;
if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
@@ -726,7 +741,7 @@ export const run = async (
const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];
if (type === 'string') {
- if (preferredLocation === 'gpu-buffer') {
+ if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-tensor') {
throw new Error('String tensor is not supported on GPU.');
}
const stringData: string[] = [];
@@ -766,6 +781,37 @@ export const run = async (
},
'gpu-buffer',
]);
+ } else if (preferredLocation === 'ml-tensor' && size > 0) {
+ const ensureTensor = wasm.jsepEnsureTensor;
+ if (!ensureTensor) {
+ throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.');
+ }
+ const tensorSize = calculateTensorSizeInBytes(dataType, size);
+ if (tensorSize === undefined || !isMLTensorSupportedType(type)) {
+ throw new Error(`Unsupported data type: ${type}`);
+ }
+
+ // If the graph has been partitioned, the output tensor may not have been created yet. For this reason, we use
+ // ensureTensor to get or create the MLTensor. In that case there is no need to copy data into a newly created
+ // tensor, so copyOld is false.
+ const mlTensor = await ensureTensor(dataOffset, dataType, dims, false);
+
+ // do not release the tensor right now. it will be released when user calls tensor.dispose().
+ keepOutputTensor = true;
+
+ output.push([
+ type,
+ dims,
+ {
+ mlTensor,
+ download: wasm.jsepCreateMLTensorDownloader!(dataOffset, type),
+ dispose: () => {
+ wasm.jsepReleaseTensorId!(dataOffset);
+ wasm._OrtReleaseTensor(tensor);
+ },
+ },
+ 'ml-tensor',
+ ]);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
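End to end, the code paths above allow outputs to stay on the WebNN device. A minimal user-facing sketch, not part of the diff; the model path, input/output names and shapes are placeholders, and it assumes an onnxruntime-web build with the WebNN EP enabled.

import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webnn'],
  preferredOutputLocation: 'ml-tensor', // outputs are left in WebNN MLTensors instead of being copied to CPU
});

const inputData = new Float32Array(1 * 3 * 224 * 224);
const results = await session.run({ input: new ort.Tensor('float32', inputData, [1, 3, 224, 224]) });

const output = results.output; // output.location === 'ml-tensor'
const cpuData = await output.getData(); // downloads through the downloader registered by run() above
console.log(cpuData.length);
output.dispose(); // releases the MLTensor ID and the native OrtValue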
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index 828cd3cfd94fa..3e08fe97f559d 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -7,6 +7,7 @@
///
import type { Tensor } from 'onnxruntime-common';
+import { DataType } from './wasm-common';
/* eslint-disable @typescript-eslint/naming-convention */
@@ -27,6 +28,16 @@ export declare namespace JSEP {
type CaptureBeginFunction = () => void;
type CaptureEndFunction = () => void;
type ReplayFunction = () => void;
+ type ReserveTensorIdFunction = () => number;
+ type ReleaseTensorIdFunction = (tensorId: number) => void;
+ type EnsureTensorFunction = (
+ tensorId: number,
+ dataType: DataType,
+ shape: readonly number[],
+ copyOld: boolean,
+ ) => Promise<MLTensor>;
+ type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void;
+ type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise<undefined>;
export interface Module extends WebGpuModule, WebNnModule {
/**
@@ -62,7 +73,17 @@ export declare namespace JSEP {
replay: ReplayFunction,
],
): void;
- jsepInit(name: 'webnn', initParams?: never): void;
+ jsepInit(
+ name: 'webnn',
+ initParams: [
+ backend: BackendType,
+ reserveTensorId: ReserveTensorIdFunction,
+ releaseTensorId: ReleaseTensorIdFunction,
+ ensureTensor: EnsureTensorFunction,
+ uploadTensor: UploadTensorFunction,
+ downloadTensor: DownloadTensorFunction,
+ ],
+ ): void;
}
export interface WebGpuModule {
@@ -134,6 +155,70 @@ export declare namespace JSEP {
* Active MLContext used to create WebNN EP.
*/
currentContext: MLContext;
+
+ /**
+ * Indicates whether tensor data should be transferred to MLTensors. Set to false to avoid creating MLTensors for graph initializers.
+ */
+ shouldTransferToMLTensor: boolean;
+
+ /**
+ * [exported from pre-jsep.js] Register MLContext for a session.
+ * @param sessionId - specify the session ID.
+ * @param context - specify the MLContext.
+ * @returns
+ */
+ jsepRegisterMLContext: (sessionId: number, context: MLContext) => void;
+ /**
+ * [exported from pre-jsep.js] Reserve a MLTensor ID attached to the current session.
+ * @returns the MLTensor ID.
+ */
+ jsepReserveTensorId: () => number;
+ /**
+ * [exported from pre-jsep.js] Release an MLTensor ID from use and destroys underlying MLTensor if no longer in use.
+ * @param tensorId - specify the MLTensor ID.
+ * @returns
+ */
+ jsepReleaseTensorId: (tensorId: number) => void;
+ /**
+ * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID.
+ * @param tensorId - specify the MLTensor ID.
+ * @param onnxDataType - specify the data type.
+ * @param shape - specify the dimensions (WebNN shape) of the tensor.
+ * @param copyOld - specify whether to copy the old tensor if a new tensor was created.
+ * @returns the MLTensor associated with the tensor ID.
+ */
+ jsepEnsureTensor: (tensorId: number, dataType: DataType, shape: number[], copyOld: boolean) => Promise<MLTensor>;
+ /**
+ * [exported from pre-jsep.js] Upload data to an MLTensor.
+ * @param tensorId - specify the MLTensor ID.
+ * @param data - specify the data to upload, as a Uint8Array over the tensor's raw bytes.
+ * @returns
+ */
+ jsepUploadTensor: (tensorId: number, data: Uint8Array) => void;
+ /**
+ * [exported from pre-jsep.js] Download data from an MLTensor.
+ * @param tensorId - specify the MLTensor ID.
+ * @param dstBuffer - specify the destination buffer that receives the downloaded data.
+ * @returns a promise that resolves when the download has completed.
+ */
+ jsepDownloadTensor: (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise<undefined>;
+ /**
+ * [exported from pre-jsep.js] Creates a downloader function to download data from an MLTensor.
+ * @param tensorId - specify the MLTensor ID.
+ * @param type - specify the data type.
+ * @returns the downloader function.
+ */
+ jsepCreateMLTensorDownloader: (
+ tensorId: number,
+ type: Tensor.MLTensorDataTypes,
+ ) => () => Promise<Tensor.DataTypeMap[Tensor.MLTensorDataTypes]>;
+ /**
+ * [exported from pre-jsep.js] Registers an external MLTensor to a session.
+ * @param tensor - specify the MLTensor.
+ * @param onnxDataType - specify the ONNX data type.
+ * @param dimensions - specify the dimensions.
+ * @returns the MLTensor ID for the external MLTensor.
+ */
+ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number;
}
}
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index d237293dbb192..e94e11d0ace56 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -62,6 +62,8 @@ Options:
none (default)
gpu-tensor use pre-allocated GPU tensors for inputs and outputs
gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer'
+ ml-tensor use pre-allocated MLTensor tensors for inputs and outputs
+ ml-location use pre-allocated MLTensor tensors for inputs and set preferredOutputLocation to 'ml-tensor'
*** Logging Options ***
@@ -133,7 +135,7 @@ export declare namespace TestRunnerCliArgs {
type Backend = 'cpu' | 'webgl' | 'webgpu' | 'wasm' | 'onnxruntime' | 'webnn';
type Environment = 'chrome' | 'chromecanary' | 'edge' | 'firefox' | 'electron' | 'safari' | 'node' | 'bs';
type BundleMode = 'dev' | 'perf';
- type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location';
+ type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location';
}
export interface TestRunnerCliArgs {
@@ -455,7 +457,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
// Option: -i=<...>, --io-binding=<...>
const ioBindingArg = args['io-binding'] || args.i;
const ioBindingMode = typeof ioBindingArg !== 'string' ? 'none' : ioBindingArg;
- if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) {
+ if (['none', 'gpu-tensor', 'gpu-location', 'ml-tensor', 'ml-location'].indexOf(ioBindingMode) === -1) {
throw new Error(`not supported io binding mode ${ioBindingMode}`);
}
diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index a9fcd7b876b2f..68ee58dab7094 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -380,7 +380,7 @@ async function main() {
}
let ioBinding: Test.IOBindingMode;
- if (backend !== 'webgpu' && args.ioBindingMode !== 'none') {
+ if (!['webgpu', 'webnn'].includes(backend) && args.ioBindingMode !== 'none') {
npmlog.warn(
'TestRunnerCli.Init.Model',
`Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`,
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 5c1e2e27a6eff..ae708467be8a2 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -1912,9 +1912,9 @@
// "test_lrn_default",
// "test_lrn",
// // "test_lstm_batchwise",
- // // "test_lstm_defaults",
- // // "test_lstm_with_initial_bias",
- // // "test_lstm_with_peepholes",
+ "test_lstm_defaults",
+ "test_lstm_with_initial_bias",
+ "test_lstm_with_peepholes",
"test_matmul_2d",
"test_matmul_3d",
"test_matmul_4d",
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index aa9555c191501..2176a776a0192 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -1,6 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+///
+
import { Float16Array as Float16ArrayPolyfill } from '@petamoriken/float16';
import { expect } from 'chai';
import * as ort from 'onnxruntime-common';
@@ -19,6 +24,7 @@ import { createView } from '../lib/wasm/jsep/tensor-view';
import {
calculateTensorSizeInBytes,
isGpuBufferSupportedType,
+ isMLTensorSupportedType,
tensorDataTypeStringToEnum,
} from '../lib/wasm/wasm-common';
@@ -170,13 +176,20 @@ async function initializeSession(
}`,
);
+ let preferredOutputLocation: ort.Tensor.DataLocation | undefined;
+ if (ioBindingMode === 'gpu-location') {
+ preferredOutputLocation = 'gpu-buffer';
+ } else if (ioBindingMode === 'ml-location') {
+ preferredOutputLocation = 'ml-tensor';
+ }
+
const profilerConfig = profile ? { maxNumberEvents: 65536 } : undefined;
const sessionConfig = {
...sessionOptions,
executionProviders: [backendHint],
profiler: profilerConfig,
enableProfiling: profile,
- preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined,
+ preferredOutputLocation,
externalData,
};
@@ -219,6 +232,7 @@ export class ModelTestContext {
readonly perfData: ModelTestContext.ModelTestPerfData,
readonly ioBinding: Test.IOBindingMode,
private readonly profile: boolean,
+ public readonly mlContext?: MLContext,
) {}
/**
@@ -272,7 +286,24 @@ export class ModelTestContext {
const initStart = now();
const executionProviderConfig =
- modelTest.backend === 'webnn' ? testOptions?.webnnOptions || 'webnn' : modelTest.backend!;
+ modelTest.backend === 'webnn' ? testOptions?.webnnOptions || { name: 'webnn' } : modelTest.backend!;
+ let mlContext: MLContext | undefined;
+ if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
+ const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption;
+ const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType;
+ const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads;
+ const powerPreference = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.powerPreference;
+
+ mlContext = await navigator.ml.createContext({
+ deviceType,
+ numThreads,
+ powerPreference,
+ });
+ (executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption).context = mlContext;
+ if (deviceType) {
+ (executionProviderConfig as ort.InferenceSession.WebNNContextOptions).deviceType = deviceType;
+ }
+ }
const session = await initializeSession(
modelTest.modelUrl,
executionProviderConfig,
@@ -295,6 +326,7 @@ export class ModelTestContext {
{ init: initEnd - initStart, firstRun: -1, runs: [], count: 0 },
modelTest.ioBinding,
profile,
+ mlContext,
);
} finally {
this.initializing = false;
@@ -622,30 +654,82 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]
});
}
+async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Type, dims: readonly number[]) {
+ if (!isMLTensorSupportedType(type)) {
+ throw new Error(`createMLTensorForOutput can not work with ${type} tensor`);
+ }
+
+ const dataType = type === 'bool' ? 'uint8' : type;
+
+ const mlTensor = await mlContext.createTensor({
+ dataType,
+ shape: dims as number[],
+ // Assign both shape and dimensions while transitioning to new API.
+ dimensions: dims as number[],
+ usage: MLTensorUsage.READ,
+ });
+
+ return ort.Tensor.fromMLTensor(mlTensor, {
+ dataType: type,
+ dims,
+ dispose: () => mlTensor.destroy(),
+ download: async () => {
+ const arrayBuffer = await mlContext.readTensor(mlTensor);
+ return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.MLTensorDataTypes];
+ },
+ });
+}
+
+async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tensor): Promise<ort.Tensor> {
+ if (!isMLTensorSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) {
+ throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`);
+ }
+ const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type;
+ const mlTensor = await mlContext.createTensor({
+ dataType,
+ shape: cpuTensor.dims as number[],
+ // Assign both shape and dimensions while transitioning to new API.
+ dimensions: cpuTensor.dims as number[],
+ usage: MLTensorUsage.WRITE,
+ });
+ mlContext.writeTensor(mlTensor, cpuTensor.data);
+ return ort.Tensor.fromMLTensor(mlTensor, {
+ dataType: cpuTensor.type,
+ dims: cpuTensor.dims,
+ dispose: () => mlTensor.destroy(),
+ });
+}
+
export async function sessionRun(options: {
session: ort.InferenceSession;
 feeds: Record<string, ort.Tensor>;
 outputsMetaInfo: Record<string, Pick<ort.Tensor, 'dims' | 'type'>>;
ioBinding: Test.IOBindingMode;
+ mlContext?: MLContext;
}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> {
const session = options.session;
const feeds = options.feeds;
const fetches: Record = {};
- // currently we only support IO Binding for WebGPU
+ // currently we only support IO Binding for WebGPU and WebNN
//
- // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode.
- // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only.
+ // For inputs, we create tensors on 'gpu-tensor', 'gpu-location', 'ml-tensor', and 'ml-location' binding testing
+ // modes.
+ // For outputs, we create tensors on 'gpu-tensor' and 'ml-tensor' binding testing modes.
 // in 'gpu-location' and 'ml-location' binding modes, outputs are not pre-allocated.
- const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location';
- const shouldUploadOutput = options.ioBinding === 'gpu-tensor';
+ const shouldUploadInput = ['gpu-tensor', 'gpu-location', 'ml-location', 'ml-tensor'].includes(options.ioBinding);
+ const shouldUploadOutput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'ml-tensor';
try {
if (shouldUploadInput) {
// replace the CPU tensors in feeds into GPU tensors
for (const name in feeds) {
if (Object.hasOwnProperty.call(feeds, name)) {
if (feeds[name].size > 0) {
- feeds[name] = createGpuTensorForInput(feeds[name]);
+ if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') {
+ feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]);
+ } else {
+ feeds[name] = createGpuTensorForInput(feeds[name]);
+ }
}
}
}
@@ -658,7 +742,11 @@ export async function sessionRun(options: {
if (dims.some((d) => d === 0)) {
fetches[name] = new ort.Tensor(type, [], dims);
} else {
- fetches[name] = createGpuTensorForOutput(type, dims);
+ if (options.ioBinding === 'ml-tensor') {
+ fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims);
+ } else {
+ fetches[name] = createGpuTensorForOutput(type, dims);
+ }
}
}
}
@@ -714,6 +802,7 @@ export async function runModelTestSet(
feeds,
outputsMetaInfo,
ioBinding: context.ioBinding,
+ mlContext: context.mlContext,
});
if (context.perfData.count === 0) {
context.perfData.firstRun = end - start;
diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts
index be1e56485ec5a..29a11f969ffea 100644
--- a/js/web/test/test-types.ts
+++ b/js/web/test/test-types.ts
@@ -52,8 +52,12 @@ export declare namespace Test {
* `preferredOutputLocation` will be set to `gpu-buffer`.
* - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation`
* will not be set.
+ * - ml-location: inputs will be pre-allocated as ML tensors; no output will be pre-allocated;
+ * `preferredOutputLocation` will be set to `ml-tensor`.
+ * - ml-tensor: inputs and outputs will all be pre-allocated as MLTensor tensors. `preferredOutputLocation`
+ * will not be set.
*/
- export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location';
+ export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location';
export interface ModelTestCase {
name: string;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
index 7b84971585f9f..c8fe9c77d8ff8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
+ per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
}}; // For element-wise add
// Allocate space for output of Q(BS, D) + bias(D)
@@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
+ per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
}}; // For element-wise add
// Get Q's bias from combined bias
@@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+
template
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);
+template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, OrtValue& out);
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc
index 570f4108c3f62..72adfa025da57 100644
--- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc
@@ -86,6 +86,11 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const {
std::atomic_bool failed{false};
int n = batch_size * sequence_length;
+
+ // Put epsilon into local variable here to avoid the need to capture 'this' in the TryBatchParallelFor() lambda.
+ // Using the copy capture default (=) to implicitly capture 'this' is deprecated.
+ const float epsilon_value = epsilon();
+
concurrency::ThreadPool::TryBatchParallelFor(
context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) {
int word_col_index = input_ids_data[index];
@@ -136,7 +141,7 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const {
y[i] = a;
sum += a * a;
}
- T e = sqrt(sum / hidden_size + static_cast<T>(epsilon()));
+ T e = sqrt(sum / hidden_size + static_cast<T>(epsilon_value));
for (int i = 0; i < hidden_size; i++) {
y[i] = y[i] / e * gamma_data[i] + beta_data[i];
}
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index bfec9aef56727..ccaeb6654e286 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -75,7 +75,7 @@ class GQAAttentionBase {
int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]);
// Compute the attention score.
- size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
+ size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -87,16 +87,17 @@ class GQAAttentionBase {
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
- ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size,
+ ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
- present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp);
+ present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
- ComputeVxAttentionScore<T>(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+ ComputeVxAttentionScore<T>(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
+ seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
- is_prompt, tp);
+ is_prompt, tp, allocator);
return Status::OK();
}
@@ -106,7 +107,7 @@ class GQAAttentionBase {
// attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
- void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
+ void ComputeAttentionProbs(float* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
@@ -120,7 +121,8 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
- ThreadPool* tp) const { // thread pool
+ ThreadPool* tp, // thread pool
+ AllocatorPtr allocator) const { // allocator for temporary buffer
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -131,7 +133,9 @@ class GQAAttentionBase {
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H
if (!past_present_share_buffer) {
- memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+ memset((void*)present_key,
+ 0,
+ batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}
const size_t loop_len = batch_size * num_heads_;
@@ -164,7 +168,7 @@ class GQAAttentionBase {
const size_t past_chunk_length = past_seqlen * head_size;
const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
- T* output = attention_probs + output_offset;
+ float* output = attention_probs + output_offset;
const T* k;
if (packed_qkv) {
@@ -190,12 +194,28 @@ class GQAAttentionBase {
q = Q + q_input_chunk_length * i;
}
- math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
- static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output,
- static_cast<int>(present_buffer_sequence_length), nullptr);
+ if constexpr (std::is_same<T, float>::value) {
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+ static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/,
+ output, static_cast<int>(present_buffer_sequence_length), nullptr);
+ } else {
+ size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+ auto q_k_fp32 = allocator->Alloc(bytes);
+ BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
+
+ float* q_fp32 = static_cast<float*>(q_k_fp32);
+ MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
+
+ float* k_fp32 = q_fp32 + head_size * sequence_length;
+ MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen);
+
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32,
+ static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*bata*/,
+ output, static_cast<int>(present_buffer_sequence_length), nullptr);
+ }
// compute Softmax
- T* output_softmax = output;
+ float* output_softmax = output;
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;
if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
@@ -237,7 +257,7 @@ class GQAAttentionBase {
template <typename T>
void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
- const T* attention_probs, // Attention probs with size BxNxSxT
+ const float* attention_probs, // Attention probs with size BxNxSxT
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const size_t batch_size, // batch size
@@ -251,7 +271,8 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
- ThreadPool* tp) const {
+ ThreadPool* tp,
+ AllocatorPtr allocator) const {
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -261,7 +282,9 @@ class GQAAttentionBase {
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H
if (!past_present_share_buffer) {
- memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+ memset((void*)present_value,
+ 0,
+ batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}
const size_t loop_len = batch_size * num_heads_;
@@ -285,6 +308,13 @@ class GQAAttentionBase {
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;
+ size_t output_fp32_bytes = 0;
+ if constexpr (std::is_same<T, MLFloat16>::value) {
+ output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
+ }
+ auto output_fp32 = allocator->Alloc(output_fp32_bytes);
+ BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));
+
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const size_t batch_index = i / num_heads_;
@@ -305,15 +335,39 @@ class GQAAttentionBase {
i / kv_num_heads_factor);
}
- T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
- math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
- attention_probs + attention_probs_offset,
- static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
- 0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
+ if constexpr (std::is_same<T, float>::value) {
+ T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+ 1.f, /*alpha*/ attention_probs + attention_probs_offset,
+ static_cast<int>(present_buffer_sequence_length), v,
+ static_cast<int>(head_size), 0.0f /*beta*/, output_current,
+ static_cast<int>(hidden_size), nullptr);
+ } else {
+ size_t bytes = head_size * total_seqlen * sizeof(float);
+ auto v_fp32 = allocator->Alloc(bytes);
+ BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
+
+ float* v_fp32_ptr = static_cast<float*>(v_fp32);
+ MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);
+
+ float* output_fp32_current = static_cast<float*>(output_fp32) +
+ (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+ 1.f, /*alpha*/ attention_probs + attention_probs_offset,
+ static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
+ static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
+ static_cast<int>(hidden_size), nullptr);
+ }
}
});
+
+ if constexpr (std::is_same<T, MLFloat16>::value) {
+ MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
+ output,
+ SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
+ }
}
};
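
Every MLFloat16 branch added above follows the same staging pattern: copy the half-precision operands into a
float scratch buffer, run the fp32 GEMM (and softmax) there, then convert the result back to half once at the
end. A rough standalone sketch of that pattern, with plain loops standing in for the MLAS conversion and GEMM
routines the patch actually uses (names and types here are illustrative):

    #include <cstddef>
    #include <vector>

    // "Half" stands in for MLFloat16; any type convertible to/from float works for the sketch.
    template <typename Half>
    void ToFloat(const Half* src, float* dst, size_t n) {
      for (size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
    }

    template <typename Half>
    void ToHalf(const float* src, Half* dst, size_t n) {
      for (size_t i = 0; i < n; ++i) dst[i] = static_cast<Half>(src[i]);
    }

    // Plain fp32 GEMM: C[m x n] = A[m x k] * B[k x n].
    void GemmF32(const float* a, const float* b, float* c, size_t m, size_t n, size_t k) {
      for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j) {
          float acc = 0.f;
          for (size_t p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
          c[i * n + j] = acc;
        }
    }

    // The staging pattern: convert inputs up, accumulate in fp32, convert the result down once.
    template <typename Half>
    void HalfGemmViaFp32(const Half* a_h, const Half* b_h, Half* c_h, size_t m, size_t n, size_t k) {
      std::vector<float> a(m * k), b(k * n), c(m * n);  // scratch, analogous to allocator->Alloc
      ToFloat(a_h, a.data(), a.size());
      ToFloat(b_h, b.data(), b.size());
      GemmF32(a.data(), b.data(), c.data(), m, n, k);
      ToHalf(c.data(), c_h, c.size());
    }
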
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 2a38e4a1ac636..a1ed35e54b008 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -22,16 +22,20 @@ namespace onnxruntime {
namespace contrib {
// These ops are internal-only, so register outside of onnx
-ONNX_OPERATOR_TYPED_KERNEL_EX(
- GroupQueryAttention,
- kMSDomain,
- 1,
- float,
- kCpuExecutionProvider,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType())
- .TypeConstraint("M", DataTypeImpl::GetTensorType()),
- GroupQueryAttention);
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ GroupQueryAttention, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("M", DataTypeImpl::GetTensorType()), \
+ GroupQueryAttention);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 6732f8b96cce2..cbfd2f0949363 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -13,16 +13,20 @@ namespace onnxruntime {
namespace contrib {
// These ops are internal-only, so register outside of onnx
-ONNX_OPERATOR_TYPED_KERNEL_EX(
- RotaryEmbedding,
- kMSDomain,
- 1,
- float,
- kCpuExecutionProvider,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType())
- .TypeConstraint("M", DataTypeImpl::GetTensorType()),
- RotaryEmbedding);
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ RotaryEmbedding, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("M", DataTypeImpl::GetTensorType()), \
+ RotaryEmbedding);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
@@ -75,19 +79,27 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const T* sin_data = sin_cache + cache_offset;
int cache_idx = 0;
- T sign = 0;
+ bool sign = false;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
- sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
- j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
+ sign = i & 1;
+ j = sign ? i - 1 : i + 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
- sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
+ sign = (i >= half_rotary_emb_dim);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
- output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+ float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
+ float input_data_j = static_cast<float>(input_data[j]);
+ float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
+ if (sign) {
+ output_data_i += input_data_j * sin_data_cache_idx;
+ } else {
+ output_data_i -= input_data_j * sin_data_cache_idx;
+ }
+ output_data[i] = static_cast<T>(output_data_i);
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
@@ -102,6 +114,10 @@ template Status RunRotaryEmbedding<float>(concurrency::ThreadPool* tp, RotaryPar
const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output,
bool interleaved);
+template Status RunRotaryEmbedding<MLFloat16>(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input,
+ const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache,
+ MLFloat16* output, bool interleaved);
+
template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
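
The rotary rewrite above replaces the old ±1 multiplier of type T (awkward for MLFloat16) with a bool and
performs the rotation arithmetic in fp32 before casting back to T. A small sketch of that per-element rotation
for the non-interleaved layout (function name and signature are illustrative, not the kernel's exact interface):

    // Rotates one head's first rotary_dim elements: out[i] = in[i]*cos[c] ± in[j]*sin[c],
    // accumulating in float and casting back to T only once per element.
    template <typename T>
    void RotateHalfFp32(const T* in, const T* cos_cache, const T* sin_cache,
                        T* out, int rotary_dim, int head_size) {
      const int half = rotary_dim / 2;
      for (int i = 0; i < rotary_dim; ++i) {
        const int cache_idx = i % half;
        const bool second_half = (i >= half);   // replaces the old ±1 "sign" of type T
        const int j = (i + half) % rotary_dim;  // partner element
        float acc = static_cast<float>(in[i]) * static_cast<float>(cos_cache[cache_idx]);
        const float rotated = static_cast<float>(in[j]) * static_cast<float>(sin_cache[cache_idx]);
        acc = second_half ? acc + rotated : acc - rotated;
        out[i] = static_cast<T>(acc);
      }
      for (int i = rotary_dim; i < head_size; ++i) out[i] = in[i];  // pass-through tail
    }
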
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index dcd1f5ec22b52..6ffe861d19931 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -22,8 +22,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -134,12 +136,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai
// LayerNormalization is now in the ONNX spec. As the contrib op (incorrectly) used kOnnxDomain we need to version it
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
@@ -288,8 +294,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -338,12 +346,16 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/layer_norm.cc b/onnxruntime/contrib_ops/cpu/layer_norm.cc
index 94f32360bd2f4..c949fcddad093 100644
--- a/onnxruntime/contrib_ops/cpu/layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/layer_norm.cc
@@ -25,6 +25,7 @@ namespace contrib {
REGISTER_CONTRIB_KERNELS(float)
REGISTER_CONTRIB_KERNELS(double)
+REGISTER_CONTRIB_KERNELS(MLFloat16)
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index f8f07b6e2827d..67af00beaba06 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -142,6 +142,8 @@ class MatMulNBits final : public OpKernel {
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_{};
size_t packed_b_size_{0};
+ IAllocatorUniquePtr<float> scales_fp32_{};
+ IAllocatorUniquePtr<float> bias_fp32_{};
bool has_zp_input_{false};
#if defined(ORT_NEURAL_SPEED)
@@ -175,30 +177,9 @@ class MatMulNBits final : public OpKernel {
const MatMulComputeHelper& helper) const {
ORT_THROW("ComputeBPacked is not supported for T1 type.");
}
-
- void PackScale(const Tensor& tensor) {
- ORT_THROW("PackScale is not supported for T1 type.");
- }
};
-#ifdef MLAS_TARGET_AMD64_IX86
-template <>
-void MatMulNBits<MLFloat16>::PackScale(const Tensor& tensor) {
- auto sptr = tensor.Data<MLFloat16>();
- std::vector<float> scales_v(static_cast<size_t>(tensor.Shape().Size()));
- MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size());
- MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), &scales_v[0],
- has_zp_input_, nullptr, nullptr);
-}
-
-template <>
-void MatMulNBits<float>::PackScale(const Tensor& tensor) {
- auto sptr = tensor.Data<float>();
- MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
- has_zp_input_, nullptr, nullptr);
-}
-#endif
-
+#if defined(ORT_NEURAL_SPEED)
template <typename T1>
Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
@@ -207,7 +188,6 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
if (has_g_idx_ || has_unquantized_zero_point_) {
return Status::OK();
}
-#if defined(ORT_NEURAL_SPEED)
if (!all_constant_) {
return Status::OK();
@@ -259,8 +239,21 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
is_packed = true;
}
+ return Status::OK();
+}
+
#else // defined(ORT_NEURAL_SPEED)
+
+template <typename T1>
+Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
+ is_packed = false;
+ if (has_g_idx_ || has_unquantized_zero_point_) {
+ return Status::OK();
+ }
+
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
return Status::OK();
}
@@ -276,20 +269,77 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
} else if (compute_type_ == CompInt8) {
#ifdef MLAS_TARGET_AMD64_IX86
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
- PackScale(tensor);
+ auto sptr = tensor.Data<float>();
+ MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
+ has_zp_input_, nullptr, nullptr);
is_packed = false;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);
is_packed = false;
}
-#endif
+#endif // MLAS_TARGET_AMD64_IX86
+ }
+
+ return Status::OK();
+}
+
+template <>
+Status MatMulNBits