From 017a5f702c828193d4c6dc3069906d56439b9c18 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 26 Sep 2024 16:00:23 -0700 Subject: [PATCH 01/11] Add LoraAdapter class --- .../main/java/ai/onnxruntime/OrtSession.java | 836 +++++++++--------- 1 file changed, 430 insertions(+), 406 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 6d146d5857d3c..216cea247b42b 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -29,11 +29,13 @@ /** * Wraps an ONNX model and allows inference calls. * - *

Allows the inspection of the model's input and output nodes. Produced by an {@link - * OrtEnvironment}. + *

+ * Allows the inspection of the model's input and output nodes. Produced by an + * {@link OrtEnvironment}. * - *

Most instance methods throw {@link IllegalStateException} if the session is closed and the - * methods are called. + *

+ * Most instance methods throw {@link IllegalStateException} if the session is + * closed and the methods are called. */ public class OrtSession implements AutoCloseable { @@ -64,81 +66,70 @@ public class OrtSession implements AutoCloseable { /** * Create a session loading the model from disk. * - * @param env The environment. + * @param env The environment. * @param modelPath The path to the model. * @param allocator The allocator to use. - * @param options Session configuration options. - * @throws OrtException If the file could not be read, or the model was corrupted etc. + * @param options Session configuration options. + * @throws OrtException If the file could not be read, or the model was + * corrupted etc. */ - OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options) - throws OrtException { - this( - createSession( - OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelPath, options.getNativeHandle()), + OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options) throws OrtException { + this(createSession(OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelPath, options.getNativeHandle()), allocator); } /** * Creates a session reading the model from the supplied byte array. * - * @param env The environment. + * @param env The environment. * @param modelArray The model protobuf as a byte array. - * @param allocator The allocator to use. - * @param options Session configuration options. - * @throws OrtException If the model was corrupted or some other error occurred in native code. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred + * in native code. */ OrtSession(OrtEnvironment env, byte[] modelArray, OrtAllocator allocator, SessionOptions options) throws OrtException { - this( - createSession( - OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelArray, options.getNativeHandle()), + this(createSession(OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelArray, options.getNativeHandle()), allocator); } /** * Creates a session reading the model from the supplied byte buffer. * - *

Must be a direct byte buffer. + *

+ * Must be a direct byte buffer. * - * @param env The environment. + * @param env The environment. * @param modelBuffer The model protobuf as a byte buffer. - * @param allocator The allocator to use. - * @param options Session configuration options. - * @throws OrtException If the model was corrupted or some other error occurred in native code. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred + * in native code. */ - OrtSession( - OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + OrtSession(OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) throws OrtException { - this( - createSession( - OnnxRuntime.ortApiHandle, - env.getNativeHandle(), - modelBuffer, - modelBuffer.position(), - modelBuffer.remaining(), - options.getNativeHandle()), - allocator); + this(createSession(OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelBuffer, modelBuffer.position(), + modelBuffer.remaining(), options.getNativeHandle()), allocator); } /** * Private constructor to build the Java object wrapped around a native session. * * @param nativeHandle The pointer to the native session. - * @param allocator The allocator to use. - * @throws OrtException If the model's inputs, outputs or metadata could not be read. + * @param allocator The allocator to use. + * @throws OrtException If the model's inputs, outputs or metadata could not be + * read. */ private OrtSession(long nativeHandle, OrtAllocator allocator) throws OrtException { this.nativeHandle = nativeHandle; this.allocator = allocator; numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle); - inputNames = - new LinkedHashSet<>( - Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); + inputNames = new LinkedHashSet<>( + Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, nativeHandle); - outputNames = - new LinkedHashSet<>( - Arrays.asList( - getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); + outputNames = new LinkedHashSet<>( + Arrays.asList(getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); } /** @@ -168,7 +159,8 @@ public long getNumOutputs() { } /** - * Returns the input names. The underlying collection is sorted based on the input id number. + * Returns the input names. The underlying collection is sorted based on the + * input id number. * * @return The input names. */ @@ -181,7 +173,8 @@ public Set getInputNames() { } /** - * Returns the output names. The underlying collection is sorted based on the output id number. + * Returns the output names. The underlying collection is sorted based on the + * output id number. * * @return The output names. */ @@ -194,8 +187,8 @@ public Set getOutputNames() { } /** - * Returns the info objects for the inputs, including their names and types. The underlying - * collection is sorted based on the input id number. + * Returns the info objects for the inputs, including their names and types. The + * underlying collection is sorted based on the input id number. * * @return The input information. * @throws OrtException If there was an error in native code. @@ -209,8 +202,8 @@ public Map getInputInfo() throws OrtException { } /** - * Returns the info objects for the outputs, including their names and types. The underlying - * collection is sorted based on the output id number. + * Returns the info objects for the outputs, including their names and types. + * The underlying collection is sorted based on the output id number. * * @return The output information. * @throws OrtException If there was an error in native code. @@ -226,12 +219,13 @@ public Map getOutputInfo() throws OrtException { /** * Scores an input feed dict, returning the map of all inferred outputs. * - *

The outputs are sorted based on their id number. + *

+ * The outputs are sorted based on their id number. * * @param inputs The inputs to score. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input names are invalid, or if - * there are zero or too many inputs. + * @throws OrtException If there was an error in native code, the input names + * are invalid, or if there are zero or too many inputs. */ public Result run(Map inputs) throws OrtException { return run(inputs, outputNames); @@ -240,51 +234,51 @@ public Result run(Map inputs) throws OrtExcept /** * Scores an input feed dict, returning the map of all inferred outputs. * - *

The outputs are sorted based on their id number. + *

+ * The outputs are sorted based on their id number. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param runOptions The RunOptions to control this run. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input names are invalid, or if - * there are zero or too many inputs. + * @throws OrtException If there was an error in native code, the input names + * are invalid, or if there are zero or too many inputs. */ - public Result run(Map inputs, RunOptions runOptions) - throws OrtException { + public Result run(Map inputs, RunOptions runOptions) throws OrtException { return run(inputs, outputNames, runOptions); } /** * Scores an input feed dict, returning the map of requested inferred outputs. * - *

The outputs are sorted based on the supplied set traversal order. + *

+ * The outputs are sorted based on the supplied set traversal order. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or output names are - * invalid, or if there are zero or too many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or + * output names are invalid, or if there are zero or too + * many inputs or outputs. */ - public Result run(Map inputs, Set requestedOutputs) - throws OrtException { + public Result run(Map inputs, Set requestedOutputs) throws OrtException { return run(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Scores an input feed dict, returning the map of requested inferred outputs. * - *

The outputs are sorted based on the supplied set traversal order. + *

+ * The outputs are sorted based on the supplied set traversal order. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs. - * @param runOptions The RunOptions to control this run. + * @param runOptions The RunOptions to control this run. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or output names are - * invalid, or if there are zero or too many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or + * output names are invalid, or if there are zero or too + * many inputs or outputs. */ - public Result run( - Map inputs, - Set requestedOutputs, - RunOptions runOptions) + public Result run(Map inputs, Set requestedOutputs, RunOptions runOptions) throws OrtException { return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions); } @@ -292,19 +286,21 @@ public Result run( /** * Scores an input feed dict, returning the map of pinned outputs. * - *

The outputs are sorted based on the supplied map traversal order. + *

+ * The outputs are sorted based on the supplied map traversal order. * - *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed - * when the result object is closed. + *

+ * Note: pinned outputs are not owned by the {@link Result} object, and are + * not closed when the result object is closed. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param pinnedOutputs The requested outputs which the user has allocated. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or output names are - * invalid, or if there are zero or too many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or + * output names are invalid, or if there are zero or too + * many inputs or outputs. */ - public Result run( - Map inputs, Map pinnedOutputs) + public Result run(Map inputs, Map pinnedOutputs) throws OrtException { return run(inputs, Collections.emptySet(), pinnedOutputs, null); } @@ -312,64 +308,61 @@ public Result run( /** * Scores an input feed dict, returning the map of requested and pinned outputs. * - *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, - * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name - * appears in both the requested outputs and the pinned outputs. + *

+ * The outputs are sorted based on the supplied set traversal order with pinned + * outputs first, then requested outputs. An {@link IllegalArgumentException} is + * thrown if the same output name appears in both the requested outputs and the + * pinned outputs. * - *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed - * when the result object is closed. + *

+ * Note: pinned outputs are not owned by the {@link Result} object, and are + * not closed when the result object is closed. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs which ORT will allocate. - * @param pinnedOutputs The requested outputs which the user has allocated. + * @param pinnedOutputs The requested outputs which the user has allocated. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or output names are - * invalid, or if there are zero or too many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or + * output names are invalid, or if there are zero or too + * many inputs or outputs. */ - public Result run( - Map inputs, - Set requestedOutputs, - Map pinnedOutputs) - throws OrtException { + public Result run(Map inputs, Set requestedOutputs, + Map pinnedOutputs) throws OrtException { return run(inputs, requestedOutputs, pinnedOutputs, null); } /** * Scores an input feed dict, returning the map of requested and pinned outputs. * - *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, - * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name - * appears in both the requested outputs and the pinned outputs. + *

+ * The outputs are sorted based on the supplied set traversal order with pinned + * outputs first, then requested outputs. An {@link IllegalArgumentException} is + * thrown if the same output name appears in both the requested outputs and the + * pinned outputs. * - *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed - * when the result object is closed. + *

+ * Note: pinned outputs are not owned by the {@link Result} object, and are + * not closed when the result object is closed. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs which ORT will allocate. - * @param pinnedOutputs The requested outputs which the user has allocated. - * @param runOptions The RunOptions to control this run. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @param runOptions The RunOptions to control this run. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or output names are - * invalid, or if there are zero or too many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or + * output names are invalid, or if there are zero or too + * many inputs or outputs. */ - public Result run( - Map inputs, - Set requestedOutputs, - Map pinnedOutputs, - RunOptions runOptions) - throws OrtException { + public Result run(Map inputs, Set requestedOutputs, + Map pinnedOutputs, RunOptions runOptions) throws OrtException { if (!closed) { if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) { - throw new OrtException( - "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); + throw new OrtException("Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); } int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); if ((totalOutputs == 0) || (totalOutputs > numOutputs)) { - throw new OrtException( - "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," - + numOutputs - + ") found " - + totalOutputs); + throw new OrtException("Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + numOutputs + + ") found " + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -380,8 +373,7 @@ public Result run( inputHandles[i] = t.getValue().getNativeHandle(); i++; } else { - throw new OrtException( - "Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); + throw new OrtException("Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); } } String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; @@ -395,42 +387,28 @@ public Result run( outputHandles[i] = getHandle(e.getValue()); i++; } else { - throw new OrtException( - "Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); + throw new OrtException("Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); } } for (String s : requestedOutputs) { if (outputNames.contains(s)) { if (!pinnedOutputs.containsKey(s)) { outputNamesArray[i] = s; - // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // outputValues and outputHandles can be null/0 for these outputs as ORT will + // allocate // them. i++; } else { - throw new OrtException( - "Output '" - + s - + "' was found in both the requested outputs and the pinned outputs"); + throw new OrtException("Output '" + s + "' was found in both the requested outputs and the pinned outputs"); } } else { - throw new OrtException( - "Unknown output name " + s + ", expected one of " + outputNames.toString()); + throw new OrtException("Unknown output name " + s + ", expected one of " + outputNames.toString()); } } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - boolean[] ownedByResult = - run( - OnnxRuntime.ortApiHandle, - nativeHandle, - allocator.handle, - inputNamesArray, - inputHandles, - inputNamesArray.length, - outputNamesArray, - outputNamesArray.length, - outputValues, - outputHandles, - runOptionsHandle); + boolean[] ownedByResult = run(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle, inputNamesArray, + inputHandles, inputNamesArray.length, outputNamesArray, outputNamesArray.length, outputValues, outputHandles, + runOptionsHandle); return new Result(outputNamesArray, outputValues, ownedByResult); } else { throw new IllegalStateException("Trying to score a closed OrtSession."); @@ -445,9 +423,9 @@ public Result run( */ static long getHandle(OnnxValue v) { /* - * Note this method exists as interface methods are all public, but we do not want users to be - * able to access the native pointer via a public API so can't add a method to OnnxValue which - * exposes it. + * Note this method exists as interface methods are all public, but we do not + * want users to be able to access the native pointer via a public API so can't + * add a method to OnnxValue which exposes it. */ if (v instanceof OnnxTensorLike) { return ((OnnxTensorLike) v).nativeHandle; @@ -457,8 +435,7 @@ static long getHandle(OnnxValue v) { return ((OnnxMap) v).nativeHandle; } else { throw new IllegalArgumentException( - "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " - + v.getClass()); + "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + v.getClass()); } } @@ -488,7 +465,9 @@ public long getProfilingStartTimeInNs() throws OrtException { /** * Ends the profiling session and returns the output of the profiler. * - *

Profiling should be enabled in the {@link SessionOptions} used to construct this {@code + *

+ * Profiling should be enabled in the {@link SessionOptions} used to construct + * this {@code * Session}. * * @return The profiling output. @@ -534,122 +513,106 @@ private static Map wrapInMap(NodeInfo[] infos) { return output; } - private static native long createSession( - long apiHandle, long envHandle, String modelPath, long optsHandle) throws OrtException; - - private static native long createSession( - long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession(long apiHandle, long envHandle, String modelPath, long optsHandle) + throws OrtException; - private static native long createSession( - long apiHandle, - long envHandle, - ByteBuffer modelBuffer, - int bufferPos, - int bufferSize, - long optsHandle) + private static native long createSession(long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession(long apiHandle, long envHandle, ByteBuffer modelBuffer, int bufferPos, + int bufferSize, long optsHandle) throws OrtException; + private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; - private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; - private native NodeInfo[] getInputInfo(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native NodeInfo[] getInputInfo(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; private native long getNumOutputs(long apiHandle, long nativeHandle) throws OrtException; - private native String[] getOutputNames(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native String[] getOutputNames(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; - private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; /** - * The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can - * contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but - * all other handles must be valid pointers. + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), + * outputValues can contain null entries, and outputHandles can contain zero + * values (i.e. the null pointer), but all other handles must be valid pointers. * - * @param apiHandle The pointer to the api. - * @param nativeHandle The pointer to the session. - * @param allocatorHandle The pointer to the allocator. - * @param inputNamesArray The input names. - * @param inputs The input tensors. - * @param numInputs The number of inputs. + * @param apiHandle The pointer to the api. + * @param nativeHandle The pointer to the session. + * @param allocatorHandle The pointer to the allocator. + * @param inputNamesArray The input names. + * @param inputs The input tensors. + * @param numInputs The number of inputs. * @param outputNamesArray The requested output names. - * @param outputValues The OnnxValue output array. - * @param outputHandles The OrtValue output pointer array. - * @param numOutputs The number of requested outputs. + * @param outputValues The OnnxValue output array. + * @param outputHandles The OrtValue output pointer array. + * @param numOutputs The number of requested outputs. * @param runOptionsHandle The (possibly null) pointer to the run options. - * @return A boolean array representing if the OnnxValues were allocated by this run call. + * @return A boolean array representing if the OnnxValues were allocated by this + * run call. * @throws OrtException If the native call failed in some way. */ - private native boolean[] run( - long apiHandle, - long nativeHandle, - long allocatorHandle, - String[] inputNamesArray, - long[] inputs, - long numInputs, - String[] outputNamesArray, - long numOutputs, - OnnxValue[] outputValues, - long[] outputHandles, - long runOptionsHandle) - throws OrtException; + private native boolean[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, + long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs, OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; - private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle) - throws OrtException; + private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle) throws OrtException; - private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; private native void closeSession(long apiHandle, long nativeHandle) throws OrtException; /** * Builds the {@link OnnxModelMetadata} for this session. * - * @param ortApiHandle The api pointer. - * @param nativeHandle The native session pointer. + * @param ortApiHandle The api pointer. + * @param nativeHandle The native session pointer. * @param allocatorHandle The OrtAllocator pointer. * @return The metadata. - * @throws OrtException If the native runtime failed to access or allocate the metadata. + * @throws OrtException If the native runtime failed to access or allocate the + * metadata. */ - private native OnnxModelMetadata constructMetadata( - long ortApiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + private native OnnxModelMetadata constructMetadata(long ortApiHandle, long nativeHandle, long allocatorHandle) + throws OrtException; /** * Represents the options used to construct this session. * - *

Used to set the number of threads, optimisation level, computation backend and other - * options. + *

+ * Used to set the number of threads, optimisation level, computation backend + * and other options. * - *

Modifying this after the session has been constructed will have no effect. + *

+ * Modifying this after the session has been constructed will have no effect. * - *

The SessionOptions object must not be closed until all sessions which use it are closed, as - * otherwise it could release resources that are in use. + *

+ * The SessionOptions object must not be closed until all sessions which use it + * are closed, as otherwise it could release resources that are in use. */ public static class SessionOptions implements AutoCloseable { /** - * The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum - * in the C API. + * The optimisation level to use. Needs to be kept in sync with the + * GraphOptimizationLevel enum in the C API. * - *

See Graph + *

+ * See Graph * Optimizations for more details. */ public enum OptLevel { /** Apply no optimizations to the ONNX graph. */ NO_OPT(0), /** - * Apply basic optimizations such as constant folding, redundant computation elimination and - * node fusions to the ONNX graph. + * Apply basic optimizations such as constant folding, redundant computation + * elimination and node fusions to the ONNX graph. */ BASIC_OPT(1), /** - * Applies all the basic optimizations plus more complex node fusion operations to the ONNX - * graph. + * Applies all the basic optimizations plus more complex node fusion operations + * to the ONNX graph. */ EXTENDED_OPT(2), /** Applies all available optimizations to the ONNX graph. */ @@ -672,18 +635,21 @@ public int getID() { } /** - * The execution mode to use. Needs to be kept in sync with the ExecutionMode enum in the C API. + * The execution mode to use. Needs to be kept in sync with the ExecutionMode + * enum in the C API. */ public enum ExecutionMode { /** * Executes all nodes sequentially. * - *

This is the default, and usually provides the most speedup as intra-op parallelism - * provides the most benefit. + *

+ * This is the default, and usually provides the most speedup as intra-op + * parallelism provides the most benefit. */ SEQUENTIAL(0), /** Executes some nodes in parallel. */ PARALLEL(1); + private final int id; ExecutionMode(int id) { @@ -741,7 +707,10 @@ public void close() { } } - /** Checks if the SessionOptions is closed, if so throws {@link IllegalStateException}. */ + /** + * Checks if the SessionOptions is closed, if so throws + * {@link IllegalStateException}. + */ private void checkClosed() { if (closed) { throw new IllegalStateException("Trying to use a closed SessionOptions"); @@ -769,7 +738,8 @@ public void setExecutionMode(ExecutionMode mode) throws OrtException { } /** - * Sets the optimization level of this options object, overriding the old setting. + * Sets the optimization level of this options object, overriding the old + * setting. * * @param level The optimization level to use. * @throws OrtException If there was an error in native code. @@ -780,8 +750,8 @@ public void setOptimizationLevel(OptLevel level) throws OrtException { } /** - * Sets the size of the CPU thread pool used for executing multiple request concurrently, if - * executing on a CPU. + * Sets the size of the CPU thread pool used for executing multiple request + * concurrently, if executing on a CPU. * * @param numThreads The number of threads to use. * @throws OrtException If there was an error in native code. @@ -792,8 +762,8 @@ public void setInterOpNumThreads(int numThreads) throws OrtException { } /** - * Sets the size of the CPU thread pool used for executing a single graph, if executing on a - * CPU. + * Sets the size of the CPU thread pool used for executing a single graph, if + * executing on a CPU. * * @param numThreads The number of threads to use. * @throws OrtException If there was an error in native code. @@ -847,22 +817,22 @@ public void disableProfiling() throws OrtException { } /** - * Turns on memory pattern optimizations, where memory is preallocated if all shapes are known. + * Turns on memory pattern optimizations, where memory is preallocated if all + * shapes are known. * * @param memoryPatternOptimization If true enable memory pattern optimizations. * @throws OrtException If there was an error in native code. */ - public void setMemoryPatternOptimization(boolean memoryPatternOptimization) - throws OrtException { + public void setMemoryPatternOptimization(boolean memoryPatternOptimization) throws OrtException { checkClosed(); - setMemoryPatternOptimization( - OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization); + setMemoryPatternOptimization(OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization); } /** * Sets the CPU to use an arena memory allocator. * - * @param useArena If true use an arena memory allocator for the CPU execution provider. + * @param useArena If true use an arena memory allocator for the CPU execution + * provider. * @throws OrtException If there was an error in native code. */ public void setCPUArenaAllocator(boolean useArena) throws OrtException { @@ -893,7 +863,8 @@ public void setSessionLogVerbosityLevel(int logLevel) throws OrtException { } /** - * Registers a library of custom ops for use with {@link OrtSession}s using this SessionOptions. + * Registers a library of custom ops for use with {@link OrtSession}s using this + * SessionOptions. * * @param path The path to the library on disk. * @throws OrtException If there was an error loading the library. @@ -906,21 +877,26 @@ public void registerCustomOpLibrary(String path) throws OrtException { } /** - * Registers custom ops for use with {@link OrtSession}s using this SessionOptions by calling - * the specified native function name. The custom ops library must either be linked against, or - * have previously been loaded by the user. + * Registers custom ops for use with {@link OrtSession}s using this + * SessionOptions by calling the specified native function name. The custom ops + * library must either be linked against, or have previously been loaded by the + * user. * - *

The registration function must have the signature: + *

+ * The registration function must have the signature: * - *

 OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); + *

+ *  OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); * - *

See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more - * information on custom ops. See + *

+ * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for + * more information on custom ops. See * https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 * for an example of a custom op library registration function. * * @param registrationFuncName The name of the registration function to call. - * @throws OrtException If there was an error finding or calling the registration function. + * @throws OrtException If there was an error finding or calling the + * registration function. */ public void registerCustomOpsUsingFunction(String registrationFuncName) throws OrtException { checkClosed(); @@ -928,25 +904,25 @@ public void registerCustomOpsUsingFunction(String registrationFuncName) throws O } /** - * Sets the value of a symbolic dimension. Fixed dimension computations may have more - * optimizations applied to them. + * Sets the value of a symbolic dimension. Fixed dimension computations may have + * more optimizations applied to them. * - * @param dimensionName The name of the symbolic dimension. + * @param dimensionName The name of the symbolic dimension. * @param dimensionValue The value to set that dimension to. * @throws OrtException If there was an error in native code. */ - public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) - throws OrtException { + public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) throws OrtException { checkClosed(); - addFreeDimensionOverrideByName( - OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); + addFreeDimensionOverrideByName(OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } /** * Set whether to use deterministic compute. * - *

Default is false. If set to true, this will enable deterministic compute for GPU kernels - * where possible. Note that this most likely will have a performance cost. + *

+ * Default is false. If set to true, this will enable deterministic compute for + * GPU kernels where possible. Note that this most likely will have a + * performance cost. * * @param value Should the compute be deterministic? * @throws OrtException If there was an error in native code. @@ -957,8 +933,8 @@ public void setDeterministicCompute(boolean value) throws OrtException { } /** - * Disables the per session thread pools. Must be used in conjunction with an environment - * containing global thread pools. + * Disables the per session thread pools. Must be used in conjunction with an + * environment containing global thread pools. * * @throws OrtException If there was an error in native code. */ @@ -970,7 +946,7 @@ public void disablePerSessionThreads() throws OrtException { /** * Adds a single session configuration entry as a pair of strings. * - * @param configKey The config key string. + * @param configKey The config key string. * @param configValue The config value string. * @throws OrtException If there was an error in native code. */ @@ -981,7 +957,8 @@ public void addConfigEntry(String configKey, String configValue) throws OrtExcep } /** - * Returns an unmodifiable view of the map contains all session configuration entries. + * Returns an unmodifiable view of the map contains all session configuration + * entries. * * @return All session configuration entries */ @@ -993,17 +970,18 @@ public Map getConfigEntries() { /** * Adds in the supplied externally loaded initializers. * - *

Note the initializers are copied into the session once it has been created, and the native - * references are removed from this {@code SessionOptions}. Once the session has been created - * those initializers can be closed. This is a different lifetime to initializers added via - * {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be created from {@link - * java.nio.Buffer} objects. + *

+ * Note the initializers are copied into the session once it has been created, + * and the native references are removed from this {@code SessionOptions}. Once + * the session has been created those initializers can be closed. This is a + * different lifetime to initializers added via + * {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be + * created from {@link java.nio.Buffer} objects. * * @param initializers The map of names to initializers. * @throws OrtException If the initializers could not be loaded. */ - public void addExternalInitializers(Map initializers) - throws OrtException { + public void addExternalInitializers(Map initializers) throws OrtException { checkClosed(); if (initializers.isEmpty()) { return; @@ -1022,13 +1000,16 @@ public void addExternalInitializers(Map initializers) /** * Adds an initializer to override one from the ONNX model. * - *

Note the initializer lifetime must outlive the session and session options. This is a - * different lifetime to initializers added via {@link #addExternalInitializers(Map)}. The - * initializers must be created from {@link java.nio.Buffer} objects. + *

+ * Note the initializer lifetime must outlive the session and session options. + * This is a different lifetime to initializers added via + * {@link #addExternalInitializers(Map)}. The initializers must be created from + * {@link java.nio.Buffer} objects. * - * @param name The initializer name. + * @param name The initializer name. * @param initializer The initializer value. - * @throws OrtException If the initializer could not be loaded into the session options. + * @throws OrtException If the initializer could not be loaded into the session + * options. */ public void addInitializer(String name, OnnxTensorLike initializer) throws OrtException { checkClosed(); @@ -1058,8 +1039,7 @@ public void addCUDA(int deviceNum) throws OrtException { if (OnnxRuntime.extractCUDA()) { addCUDA(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); } } @@ -1076,8 +1056,7 @@ public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException { ((OrtProviderOptions) cudaOpts).applyToNative(); addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); } } @@ -1101,16 +1080,16 @@ public void addROCM(int deviceNum) throws OrtException { if (OnnxRuntime.extractROCM()) { addROCM(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find ROCM shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find ROCM shared provider"); } } /** * Adds the CPU as an execution backend, using the arena allocator if desired. * - *

By default this backend is used, but if other backends are requested, it should be - * requested last. + *

+ * By default this backend is used, but if other backends are requested, it + * should be requested last. * * @param useArena If true use the arena memory allocator. * @throws OrtException If there was an error in native code. @@ -1131,8 +1110,7 @@ public void addDnnl(boolean useArena) throws OrtException { if (OnnxRuntime.extractDNNL()) { addDnnl(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find DNNL shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find DNNL shared provider"); } } @@ -1147,8 +1125,7 @@ public void addOpenVINO(String deviceId) throws OrtException { if (OnnxRuntime.extractOpenVINO()) { addOpenVINO(OnnxRuntime.ortApiHandle, nativeHandle, deviceId); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find OpenVINO shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find OpenVINO shared provider"); } } @@ -1163,8 +1140,7 @@ public void addTensorrt(int deviceNum) throws OrtException { if (OnnxRuntime.extractTensorRT()) { addTensorrt(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); } } @@ -1181,8 +1157,7 @@ public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtExcep ((OrtProviderOptions) tensorRTOpts).applyToNative(); addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle); } else { - throw new OrtException( - OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); + throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); } } @@ -1271,11 +1246,12 @@ public void addCoreML(EnumSet flags) throws OrtException { } /** - * Adds Xnnpack as an execution backend. Needs to list all options hereif a new option - * supported. current supported options: {} The maximum number of provider options is set to 128 - * (see addExecutionProvider's comment). This number is controlled by - * ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is - * not enough, please increase it or implementing an incremental way to add more options. + * Adds Xnnpack as an execution backend. Needs to list all options hereif a new + * option supported. current supported options: {} The maximum number of + * provider options is set to 128 (see addExecutionProvider's comment). This + * number is controlled by ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in + * ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is not enough, please + * increase it or implementing an incremental way to add more options. * * @param providerOptions options pass to XNNPACK EP for initialization. * @throws OrtException If there was an error in native code. @@ -1290,146 +1266,118 @@ public void addXnnpack(Map providerOptions) throws OrtException providerOptionVal[i] = entry.getValue(); i++; } - addExecutionProvider( - OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal); + addExecutionProvider(OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal); } - private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) - throws OrtException; + private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) throws OrtException; - private native void setOptimizationLevel(long apiHandle, long nativeHandle, int level) - throws OrtException; + private native void setOptimizationLevel(long apiHandle, long nativeHandle, int level) throws OrtException; - private native void setInterOpNumThreads(long apiHandle, long nativeHandle, int numThreads) - throws OrtException; + private native void setInterOpNumThreads(long apiHandle, long nativeHandle, int numThreads) throws OrtException; - private native void setIntraOpNumThreads(long apiHandle, long nativeHandle, int numThreads) - throws OrtException; + private native void setIntraOpNumThreads(long apiHandle, long nativeHandle, int numThreads) throws OrtException; - private native void setOptimizationModelFilePath( - long apiHandle, long nativeHandle, String modelPath) throws OrtException; + private native void setOptimizationModelFilePath(long apiHandle, long nativeHandle, String modelPath) + throws OrtException; private native long createOptions(long apiHandle); - private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId) - throws OrtException; + private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId) throws OrtException; - private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix) - throws OrtException; + private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix) throws OrtException; private native void disableProfiling(long apiHandle, long nativeHandle) throws OrtException; - private native void setMemoryPatternOptimization( - long apiHandle, long nativeHandle, boolean memoryPatternOptimization) throws OrtException; + private native void setMemoryPatternOptimization(long apiHandle, long nativeHandle, + boolean memoryPatternOptimization) throws OrtException; - private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena) - throws OrtException; + private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena) throws OrtException; - private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel) - throws OrtException; + private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; private native void setSessionLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; - private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path) - throws OrtException; + private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path) throws OrtException; - private native void registerCustomOpsUsingFunction( - long apiHandle, long nativeHandle, String registrationFuncName) throws OrtException; + private native void registerCustomOpsUsingFunction(long apiHandle, long nativeHandle, String registrationFuncName) + throws OrtException; private native void closeCustomLibraries(long[] nativeHandle); private native void closeOptions(long apiHandle, long nativeHandle); - private native void setDeterministicCompute( - long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; - - private native void addFreeDimensionOverrideByName( - long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) + private native void setDeterministicCompute(long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; - private native void addExternalInitializers( - long apiHandle, long nativeHandle, String[] names, long[] tensorHandles) - throws OrtException; + private native void addFreeDimensionOverrideByName(long apiHandle, long nativeHandle, String dimensionName, + long dimensionValue) throws OrtException; - private native void addInitializer( - long apiHandle, long nativeHandle, String name, long tensorHandle) throws OrtException; + private native void addExternalInitializers(long apiHandle, long nativeHandle, String[] names, long[] tensorHandles) + throws OrtException; - private native void disablePerSessionThreads(long apiHandle, long nativeHandle) + private native void addInitializer(long apiHandle, long nativeHandle, String name, long tensorHandle) throws OrtException; - private native void addConfigEntry( - long apiHandle, long nativeHandle, String configKey, String configValue) + private native void disablePerSessionThreads(long apiHandle, long nativeHandle) throws OrtException; + + private native void addConfigEntry(long apiHandle, long nativeHandle, String configKey, String configValue) throws OrtException; /* - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session: - * OrtSessionOptionsAppendExecutionProvider_CPU - * OrtSessionOptionsAppendExecutionProvider_CUDA - * OrtSessionOptionsAppendExecutionProvider_ROCM - * OrtSessionOptionsAppendExecutionProvider_ - * The order they care called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. + * To use additional providers, you must build ORT with the extra providers + * enabled. Then call one of these functions to enable them in the session: + * OrtSessionOptionsAppendExecutionProvider_CPU + * OrtSessionOptionsAppendExecutionProvider_CUDA + * OrtSessionOptionsAppendExecutionProvider_ROCM + * OrtSessionOptionsAppendExecutionProvider_ The order + * they care called indicates the preference order as well. In other words call + * this method on your most preferred execution provider first followed by the + * less preferred ones. If none are called Ort will use its internal CPU + * execution provider. * * If a backend is unavailable then it throws an OrtException */ private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException; - private native void addCUDA(long apiHandle, long nativeHandle, int deviceNum) - throws OrtException; + private native void addCUDA(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; - private native void addCUDAV2(long apiHandle, long nativeHandle, long cudaOptsHandle) - throws OrtException; + private native void addCUDAV2(long apiHandle, long nativeHandle, long cudaOptsHandle) throws OrtException; - private native void addROCM(long apiHandle, long nativeHandle, int deviceNum) - throws OrtException; + private native void addROCM(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; - private native void addDnnl(long apiHandle, long nativeHandle, int useArena) - throws OrtException; + private native void addDnnl(long apiHandle, long nativeHandle, int useArena) throws OrtException; - private native void addOpenVINO(long apiHandle, long nativeHandle, String deviceId) - throws OrtException; + private native void addOpenVINO(long apiHandle, long nativeHandle, String deviceId) throws OrtException; - private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum) - throws OrtException; + private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; - private native void addTensorrtV2(long apiHandle, long nativeHandle, long tensorrtOptsHandle) - throws OrtException; + private native void addTensorrtV2(long apiHandle, long nativeHandle, long tensorrtOptsHandle) throws OrtException; - private native void addNnapi(long apiHandle, long nativeHandle, int nnapiFlags) - throws OrtException; + private native void addNnapi(long apiHandle, long nativeHandle, int nnapiFlags) throws OrtException; - private native void addTvm(long apiHandle, long nativeHandle, String settings) - throws OrtException; + private native void addTvm(long apiHandle, long nativeHandle, String settings) throws OrtException; - private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) - throws OrtException; + private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) throws OrtException; - private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) - throws OrtException; + private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) throws OrtException; - private native void addArmNN(long apiHandle, long nativeHandle, int useArena) - throws OrtException; + private native void addArmNN(long apiHandle, long nativeHandle, int useArena) throws OrtException; - private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags) - throws OrtException; + private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags) throws OrtException; /* - * The max length of providerOptionKey and providerOptionVal is 128, as specified by - * ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 for its location). - */ - private native void addExecutionProvider( - long apiHandle, - long nativeHandle, - String epName, - String[] providerOptionKey, - String[] providerOptionVal) - throws OrtException; + * The max length of providerOptionKey and providerOptionVal is 128, as + * specified by ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 + * for its location). + */ + private native void addExecutionProvider(long apiHandle, long nativeHandle, String epName, + String[] providerOptionKey, String[] providerOptionVal) throws OrtException; } - /** Used to control logging and termination of a call to {@link OrtSession#run}. */ + /** + * Used to control logging and termination of a call to {@link OrtSession#run}. + */ public static class RunOptions implements AutoCloseable { static { @@ -1529,8 +1477,10 @@ public String getRunTag() throws OrtException { } /** - * Sets a flag so that all incomplete {@link OrtSession#run} calls using this instance of {@code - * RunOptions} will terminate as soon as possible. If the flag is false, it resets this {@code + * Sets a flag so that all incomplete {@link OrtSession#run} calls using this + * instance of {@code + * RunOptions} will terminate as soon as possible. If the flag is false, it + * resets this {@code * RunOptions} so it can be used with other calls to {@link OrtSession#run}. * * @param terminate If true terminate all runs associated with this RunOptions. @@ -1544,9 +1494,10 @@ public void setTerminate(boolean terminate) throws OrtException { /** * Adds a configuration entry to this {@code RunOptions}. * - *

Setting the same key will overwrite the value. + *

+ * Setting the same key will overwrite the value. * - * @param key The configuration key. + * @param key The configuration key. * @param value The configuration value. * @throws OrtException If the native library call failed. */ @@ -1555,7 +1506,23 @@ public void addRunConfigEntry(String key, String value) throws OrtException { addRunConfigEntry(OnnxRuntime.ortApiHandle, nativeHandle, key, value); } - /** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */ + /** + * Adds the specified adapter to the list of active adapters + * + *

+ * + * @param loraAdapter valid LoraAdapter object + * @throws OrtException of the native library call failed + */ + public void addActiveLoraAdapter(LoraAdapter loraAdapter) throws OrtException { + checkClosed(); + addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, loraAdapter.getNativeHandle()); + } + + /** + * Checks if the RunOptions is closed, if so throws + * {@link IllegalStateException}. + */ private void checkClosed() { if (closed) { throw new IllegalStateException("Trying to use a closed RunOptions"); @@ -1574,40 +1541,97 @@ public void close() { private static native long createRunOptions(long apiHandle) throws OrtException; - private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel) - throws OrtException; + private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; private native int getLogLevel(long apiHandle, long nativeHandle) throws OrtException; - private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) - throws OrtException; + private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; private native int getLogVerbosityLevel(long apiHandle, long nativeHandle) throws OrtException; - private native void setRunTag(long apiHandle, long nativeHandle, String runTag) - throws OrtException; + private native void setRunTag(long apiHandle, long nativeHandle, String runTag) throws OrtException; private native String getRunTag(long apiHandle, long nativeHandle) throws OrtException; - private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate) + private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate) throws OrtException; + + private native void addRunConfigEntry(long apiHandle, long nativeHandle, String key, String value) throws OrtException; - private native void addRunConfigEntry( - long apiHandle, long nativeHandle, String key, String value) throws OrtException; + private native void addActiveLoraAdapter(long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException; + + private static native void close(long apiHandle, long nativeHandle); + } + + public static class LoraAdapter implements AutoCloseable { + static { + try { + OnnxRuntime.init(); + } catch (IOException e) { + throw new RuntimeException("Failed to load onnx-runtime library", e); + } + } + + private final long nativeHandle; + + private boolean closed = false; + + /** + * Creates an instance of LoraAdapter. + * + * @throws OrtException + */ + public LoraAdapter() throws OrtException { + this.nativeHandle = createLoraAdapter(OnnxRuntime.ortApiHandle); + } + + /** + * Package accessor for native pointer. + * + * @return The native pointer. + */ + long getNativeHandle() { + return nativeHandle; + } + + /** + * Checks if the LoraAdapter is closed, if so throws + * {@link IllegalStateException}. + */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed LoraAdapter"); + } + } + + @Override + public void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close an already closed LoraAdapter"); + } + } + + private native long createLoraAdapter(long apiHandle) throws OrtException; private static native void close(long apiHandle, long nativeHandle); } /** - * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. + * An {@link AutoCloseable} wrapper around a {@link Map} containing + * {@link OnnxValue}s. * - *

When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you - * maintain a reference to a value after this object has been closed it will throw an {@link - * IllegalStateException} upon access. + *

+ * When this is closed it closes all the {@link OnnxValue}s owned by the result + * object. If you maintain a reference to a value after this object has been + * closed it will throw an {@link IllegalStateException} upon access. * - *

{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed - * by the {@link Result#close()} method. Ownership of each output can be checked with {@link - * Result#isResultOwner(int)}. + *

+ * {@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call + * are not closed by the {@link Result#close()} method. Ownership of each output + * can be checked with {@link Result#isResultOwner(int)}. */ public static class Result implements AutoCloseable, Iterable> { @@ -1622,20 +1646,17 @@ public static class Result implements AutoCloseable, Iterable Date: Fri, 27 Sep 2024 10:37:07 -0400 Subject: [PATCH 03/11] Adding JNI & tests for OrtLoraAdapter. --- .../java/ai/onnxruntime/OrtLoraAdapter.java | 86 ++ .../main/java/ai/onnxruntime/OrtSession.java | 802 +++++++++--------- .../native/ai_onnxruntime_OrtLoraAdapter.c | 56 ++ .../ai_onnxruntime_OrtSession_RunOptions.c | 12 + .../java/ai/onnxruntime/InferenceTest.java | 91 +- 5 files changed, 596 insertions(+), 451 deletions(-) create mode 100644 java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java create mode 100644 java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java new file mode 100644 index 0000000000000..10b44cfb7def2 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java @@ -0,0 +1,86 @@ +/* + * Copyright © 2024, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.io.IOException; + +/** + * A container for an adapter which can be supplied to {@link + * OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific + * execution of a model. + */ +public final class OrtLoraAdapter implements AutoCloseable { + static { + try { + OnnxRuntime.init(); + } catch (IOException e) { + throw new RuntimeException("Failed to load onnx-runtime library", e); + } + } + + private final long nativeHandle; + + private boolean closed = false; + + private OrtLoraAdapter(long nativeHandle) { + this.nativeHandle = nativeHandle; + } + + /** + * Creates an instance of OrtLoraAdapter. + * + * @param absoluteAdapterPath path to the adapter file that is going to be memory mapped. + * @throws OrtException If the native call failed. + */ + public static OrtLoraAdapter create(String absoluteAdapterPath) throws OrtException { + return create(absoluteAdapterPath, null); + } + + /** + * Creates an instance of OrtLoraAdapter. + * + * @param absoluteAdapterPath path to the adapter file that is going to be memory mapped. + * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the + * allocator memory. + * @throws OrtException If the native call failed. + */ + static OrtLoraAdapter create(String absoluteAdapterPath, OrtAllocator allocator) + throws OrtException { + long allocatorHandle = allocator == null ? 0 : allocator.handle; + return new OrtLoraAdapter( + createLoraAdapter(OnnxRuntime.ortApiHandle, absoluteAdapterPath, allocatorHandle)); + } + + /** + * Package accessor for native pointer. + * + * @return The native pointer. + */ + long getNativeHandle() { + return nativeHandle; + } + + /** Checks if the LoraAdapter is closed, if so throws {@link IllegalStateException}. */ + void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed LoraAdapter"); + } + } + + @Override + public void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close an already closed LoraAdapter"); + } + } + + private static native long createLoraAdapter( + long apiHandle, String adapterPath, long allocatorHandle) throws OrtException; + + private static native void close(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index a5713c680f074..47ee62401fca4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -29,13 +29,11 @@ /** * Wraps an ONNX model and allows inference calls. * - *

- * Allows the inspection of the model's input and output nodes. Produced by an - * {@link OrtEnvironment}. + *

Allows the inspection of the model's input and output nodes. Produced by an {@link + * OrtEnvironment}. * - *

- * Most instance methods throw {@link IllegalStateException} if the session is - * closed and the methods are called. + *

Most instance methods throw {@link IllegalStateException} if the session is closed and the + * methods are called. */ public class OrtSession implements AutoCloseable { @@ -66,70 +64,81 @@ public class OrtSession implements AutoCloseable { /** * Create a session loading the model from disk. * - * @param env The environment. + * @param env The environment. * @param modelPath The path to the model. * @param allocator The allocator to use. - * @param options Session configuration options. - * @throws OrtException If the file could not be read, or the model was - * corrupted etc. + * @param options Session configuration options. + * @throws OrtException If the file could not be read, or the model was corrupted etc. */ - OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options) throws OrtException { - this(createSession(OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelPath, options.getNativeHandle()), + OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options) + throws OrtException { + this( + createSession( + OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelPath, options.getNativeHandle()), allocator); } /** * Creates a session reading the model from the supplied byte array. * - * @param env The environment. + * @param env The environment. * @param modelArray The model protobuf as a byte array. - * @param allocator The allocator to use. - * @param options Session configuration options. - * @throws OrtException If the model was corrupted or some other error occurred - * in native code. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. */ OrtSession(OrtEnvironment env, byte[] modelArray, OrtAllocator allocator, SessionOptions options) throws OrtException { - this(createSession(OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelArray, options.getNativeHandle()), + this( + createSession( + OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelArray, options.getNativeHandle()), allocator); } /** * Creates a session reading the model from the supplied byte buffer. * - *

- * Must be a direct byte buffer. + *

Must be a direct byte buffer. * - * @param env The environment. + * @param env The environment. * @param modelBuffer The model protobuf as a byte buffer. - * @param allocator The allocator to use. - * @param options Session configuration options. - * @throws OrtException If the model was corrupted or some other error occurred - * in native code. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. */ - OrtSession(OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + OrtSession( + OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) throws OrtException { - this(createSession(OnnxRuntime.ortApiHandle, env.getNativeHandle(), modelBuffer, modelBuffer.position(), - modelBuffer.remaining(), options.getNativeHandle()), allocator); + this( + createSession( + OnnxRuntime.ortApiHandle, + env.getNativeHandle(), + modelBuffer, + modelBuffer.position(), + modelBuffer.remaining(), + options.getNativeHandle()), + allocator); } /** * Private constructor to build the Java object wrapped around a native session. * * @param nativeHandle The pointer to the native session. - * @param allocator The allocator to use. - * @throws OrtException If the model's inputs, outputs or metadata could not be - * read. + * @param allocator The allocator to use. + * @throws OrtException If the model's inputs, outputs or metadata could not be read. */ private OrtSession(long nativeHandle, OrtAllocator allocator) throws OrtException { this.nativeHandle = nativeHandle; this.allocator = allocator; numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle); - inputNames = new LinkedHashSet<>( - Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); + inputNames = + new LinkedHashSet<>( + Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, nativeHandle); - outputNames = new LinkedHashSet<>( - Arrays.asList(getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); + outputNames = + new LinkedHashSet<>( + Arrays.asList( + getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); } /** @@ -159,8 +168,7 @@ public long getNumOutputs() { } /** - * Returns the input names. The underlying collection is sorted based on the - * input id number. + * Returns the input names. The underlying collection is sorted based on the input id number. * * @return The input names. */ @@ -173,8 +181,7 @@ public Set getInputNames() { } /** - * Returns the output names. The underlying collection is sorted based on the - * output id number. + * Returns the output names. The underlying collection is sorted based on the output id number. * * @return The output names. */ @@ -187,8 +194,8 @@ public Set getOutputNames() { } /** - * Returns the info objects for the inputs, including their names and types. The - * underlying collection is sorted based on the input id number. + * Returns the info objects for the inputs, including their names and types. The underlying + * collection is sorted based on the input id number. * * @return The input information. * @throws OrtException If there was an error in native code. @@ -202,8 +209,8 @@ public Map getInputInfo() throws OrtException { } /** - * Returns the info objects for the outputs, including their names and types. - * The underlying collection is sorted based on the output id number. + * Returns the info objects for the outputs, including their names and types. The underlying + * collection is sorted based on the output id number. * * @return The output information. * @throws OrtException If there was an error in native code. @@ -219,13 +226,12 @@ public Map getOutputInfo() throws OrtException { /** * Scores an input feed dict, returning the map of all inferred outputs. * - *

- * The outputs are sorted based on their id number. + *

The outputs are sorted based on their id number. * * @param inputs The inputs to score. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input names - * are invalid, or if there are zero or too many inputs. + * @throws OrtException If there was an error in native code, the input names are invalid, or if + * there are zero or too many inputs. */ public Result run(Map inputs) throws OrtException { return run(inputs, outputNames); @@ -234,51 +240,51 @@ public Result run(Map inputs) throws OrtExcept /** * Scores an input feed dict, returning the map of all inferred outputs. * - *

- * The outputs are sorted based on their id number. + *

The outputs are sorted based on their id number. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param runOptions The RunOptions to control this run. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input names - * are invalid, or if there are zero or too many inputs. + * @throws OrtException If there was an error in native code, the input names are invalid, or if + * there are zero or too many inputs. */ - public Result run(Map inputs, RunOptions runOptions) throws OrtException { + public Result run(Map inputs, RunOptions runOptions) + throws OrtException { return run(inputs, outputNames, runOptions); } /** * Scores an input feed dict, returning the map of requested inferred outputs. * - *

- * The outputs are sorted based on the supplied set traversal order. + *

The outputs are sorted based on the supplied set traversal order. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or - * output names are invalid, or if there are zero or too - * many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. */ - public Result run(Map inputs, Set requestedOutputs) throws OrtException { + public Result run(Map inputs, Set requestedOutputs) + throws OrtException { return run(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Scores an input feed dict, returning the map of requested inferred outputs. * - *

- * The outputs are sorted based on the supplied set traversal order. + *

The outputs are sorted based on the supplied set traversal order. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs. - * @param runOptions The RunOptions to control this run. + * @param runOptions The RunOptions to control this run. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or - * output names are invalid, or if there are zero or too - * many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. */ - public Result run(Map inputs, Set requestedOutputs, RunOptions runOptions) + public Result run( + Map inputs, + Set requestedOutputs, + RunOptions runOptions) throws OrtException { return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions); } @@ -286,21 +292,19 @@ public Result run(Map inputs, Set requ /** * Scores an input feed dict, returning the map of pinned outputs. * - *

- * The outputs are sorted based on the supplied map traversal order. + *

The outputs are sorted based on the supplied map traversal order. * - *

- * Note: pinned outputs are not owned by the {@link Result} object, and are - * not closed when the result object is closed. + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param pinnedOutputs The requested outputs which the user has allocated. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or - * output names are invalid, or if there are zero or too - * many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. */ - public Result run(Map inputs, Map pinnedOutputs) + public Result run( + Map inputs, Map pinnedOutputs) throws OrtException { return run(inputs, Collections.emptySet(), pinnedOutputs, null); } @@ -308,61 +312,64 @@ public Result run(Map inputs, Map - * The outputs are sorted based on the supplied set traversal order with pinned - * outputs first, then requested outputs. An {@link IllegalArgumentException} is - * thrown if the same output name appears in both the requested outputs and the - * pinned outputs. + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. * - *

- * Note: pinned outputs are not owned by the {@link Result} object, and are - * not closed when the result object is closed. + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs which ORT will allocate. - * @param pinnedOutputs The requested outputs which the user has allocated. + * @param pinnedOutputs The requested outputs which the user has allocated. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or - * output names are invalid, or if there are zero or too - * many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. */ - public Result run(Map inputs, Set requestedOutputs, - Map pinnedOutputs) throws OrtException { + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs) + throws OrtException { return run(inputs, requestedOutputs, pinnedOutputs, null); } /** * Scores an input feed dict, returning the map of requested and pinned outputs. * - *

- * The outputs are sorted based on the supplied set traversal order with pinned - * outputs first, then requested outputs. An {@link IllegalArgumentException} is - * thrown if the same output name appears in both the requested outputs and the - * pinned outputs. + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. * - *

- * Note: pinned outputs are not owned by the {@link Result} object, and are - * not closed when the result object is closed. + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. * - * @param inputs The inputs to score. + * @param inputs The inputs to score. * @param requestedOutputs The requested outputs which ORT will allocate. - * @param pinnedOutputs The requested outputs which the user has allocated. - * @param runOptions The RunOptions to control this run. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @param runOptions The RunOptions to control this run. * @return The inferred outputs. - * @throws OrtException If there was an error in native code, the input or - * output names are invalid, or if there are zero or too - * many inputs or outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. */ - public Result run(Map inputs, Set requestedOutputs, - Map pinnedOutputs, RunOptions runOptions) throws OrtException { + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs, + RunOptions runOptions) + throws OrtException { if (!closed) { if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) { - throw new OrtException("Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); + throw new OrtException( + "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); } int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); if ((totalOutputs == 0) || (totalOutputs > numOutputs)) { - throw new OrtException("Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + numOutputs - + ") found " + totalOutputs); + throw new OrtException( + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numOutputs + + ") found " + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -373,7 +380,8 @@ public Result run(Map inputs, Set requ inputHandles[i] = t.getValue().getNativeHandle(); i++; } else { - throw new OrtException("Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); + throw new OrtException( + "Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); } } String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; @@ -387,7 +395,8 @@ public Result run(Map inputs, Set requ outputHandles[i] = getHandle(e.getValue()); i++; } else { - throw new OrtException("Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); + throw new OrtException( + "Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); } } for (String s : requestedOutputs) { @@ -399,16 +408,30 @@ public Result run(Map inputs, Set requ // them. i++; } else { - throw new OrtException("Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + throw new OrtException( + "Output '" + + s + + "' was found in both the requested outputs and the pinned outputs"); } } else { - throw new OrtException("Unknown output name " + s + ", expected one of " + outputNames.toString()); + throw new OrtException( + "Unknown output name " + s + ", expected one of " + outputNames.toString()); } } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - boolean[] ownedByResult = run(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle, inputNamesArray, - inputHandles, inputNamesArray.length, outputNamesArray, outputNamesArray.length, outputValues, outputHandles, - runOptionsHandle); + boolean[] ownedByResult = + run( + OnnxRuntime.ortApiHandle, + nativeHandle, + allocator.handle, + inputNamesArray, + inputHandles, + inputNamesArray.length, + outputNamesArray, + outputNamesArray.length, + outputValues, + outputHandles, + runOptionsHandle); return new Result(outputNamesArray, outputValues, ownedByResult); } else { throw new IllegalStateException("Trying to score a closed OrtSession."); @@ -435,7 +458,8 @@ static long getHandle(OnnxValue v) { return ((OnnxMap) v).nativeHandle; } else { throw new IllegalArgumentException( - "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + v.getClass()); + "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + + v.getClass()); } } @@ -465,9 +489,7 @@ public long getProfilingStartTimeInNs() throws OrtException { /** * Ends the profiling session and returns the output of the profiler. * - *

- * Profiling should be enabled in the {@link SessionOptions} used to construct - * this {@code + *

Profiling should be enabled in the {@link SessionOptions} used to construct this {@code * Session}. * * @return The profiling output. @@ -513,92 +535,108 @@ private static Map wrapInMap(NodeInfo[] infos) { return output; } - private static native long createSession(long apiHandle, long envHandle, String modelPath, long optsHandle) - throws OrtException; + private static native long createSession( + long apiHandle, long envHandle, String modelPath, long optsHandle) throws OrtException; - private static native long createSession(long apiHandle, long envHandle, byte[] modelArray, long optsHandle) - throws OrtException; + private static native long createSession( + long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; - private static native long createSession(long apiHandle, long envHandle, ByteBuffer modelBuffer, int bufferPos, - int bufferSize, long optsHandle) throws OrtException; + private static native long createSession( + long apiHandle, + long envHandle, + ByteBuffer modelBuffer, + int bufferPos, + int bufferSize, + long optsHandle) + throws OrtException; private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; - private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) + throws OrtException; - private native NodeInfo[] getInputInfo(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + private native NodeInfo[] getInputInfo(long apiHandle, long nativeHandle, long allocatorHandle) + throws OrtException; private native long getNumOutputs(long apiHandle, long nativeHandle) throws OrtException; - private native String[] getOutputNames(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + private native String[] getOutputNames(long apiHandle, long nativeHandle, long allocatorHandle) + throws OrtException; - private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long allocatorHandle) + throws OrtException; /** - * The native run call. runOptionsHandle can be zero (i.e. the null pointer), - * outputValues can contain null entries, and outputHandles can contain zero - * values (i.e. the null pointer), but all other handles must be valid pointers. + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can + * contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but + * all other handles must be valid pointers. * - * @param apiHandle The pointer to the api. - * @param nativeHandle The pointer to the session. - * @param allocatorHandle The pointer to the allocator. - * @param inputNamesArray The input names. - * @param inputs The input tensors. - * @param numInputs The number of inputs. + * @param apiHandle The pointer to the api. + * @param nativeHandle The pointer to the session. + * @param allocatorHandle The pointer to the allocator. + * @param inputNamesArray The input names. + * @param inputs The input tensors. + * @param numInputs The number of inputs. * @param outputNamesArray The requested output names. - * @param outputValues The OnnxValue output array. - * @param outputHandles The OrtValue output pointer array. - * @param numOutputs The number of requested outputs. + * @param outputValues The OnnxValue output array. + * @param outputHandles The OrtValue output pointer array. + * @param numOutputs The number of requested outputs. * @param runOptionsHandle The (possibly null) pointer to the run options. - * @return A boolean array representing if the OnnxValues were allocated by this - * run call. + * @return A boolean array representing if the OnnxValues were allocated by this run call. * @throws OrtException If the native call failed in some way. */ - private native boolean[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, - long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs, OnnxValue[] outputValues, - long[] outputHandles, long runOptionsHandle) throws OrtException; + private native boolean[] run( + long apiHandle, + long nativeHandle, + long allocatorHandle, + String[] inputNamesArray, + long[] inputs, + long numInputs, + String[] outputNamesArray, + long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, + long runOptionsHandle) + throws OrtException; - private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle) throws OrtException; + private native long getProfilingStartTimeInNs(long apiHandle, long nativeHandle) + throws OrtException; - private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle) throws OrtException; + private native String endProfiling(long apiHandle, long nativeHandle, long allocatorHandle) + throws OrtException; private native void closeSession(long apiHandle, long nativeHandle) throws OrtException; /** * Builds the {@link OnnxModelMetadata} for this session. * - * @param ortApiHandle The api pointer. - * @param nativeHandle The native session pointer. + * @param ortApiHandle The api pointer. + * @param nativeHandle The native session pointer. * @param allocatorHandle The OrtAllocator pointer. * @return The metadata. - * @throws OrtException If the native runtime failed to access or allocate the - * metadata. + * @throws OrtException If the native runtime failed to access or allocate the metadata. */ - private native OnnxModelMetadata constructMetadata(long ortApiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native OnnxModelMetadata constructMetadata( + long ortApiHandle, long nativeHandle, long allocatorHandle) throws OrtException; /** * Represents the options used to construct this session. * - *

- * Used to set the number of threads, optimisation level, computation backend - * and other options. + *

Used to set the number of threads, optimisation level, computation backend and other + * options. * - *

- * Modifying this after the session has been constructed will have no effect. + *

Modifying this after the session has been constructed will have no effect. * - *

- * The SessionOptions object must not be closed until all sessions which use it - * are closed, as otherwise it could release resources that are in use. + *

The SessionOptions object must not be closed until all sessions which use it are closed, as + * otherwise it could release resources that are in use. */ public static class SessionOptions implements AutoCloseable { /** - * The optimisation level to use. Needs to be kept in sync with the - * GraphOptimizationLevel enum in the C API. + * The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum + * in the C API. * - *

- * See See Graph * Optimizations for more details. */ @@ -606,13 +644,13 @@ public enum OptLevel { /** Apply no optimizations to the ONNX graph. */ NO_OPT(0), /** - * Apply basic optimizations such as constant folding, redundant computation - * elimination and node fusions to the ONNX graph. + * Apply basic optimizations such as constant folding, redundant computation elimination and + * node fusions to the ONNX graph. */ BASIC_OPT(1), /** - * Applies all the basic optimizations plus more complex node fusion operations - * to the ONNX graph. + * Applies all the basic optimizations plus more complex node fusion operations to the ONNX + * graph. */ EXTENDED_OPT(2), /** Applies all available optimizations to the ONNX graph. */ @@ -635,16 +673,14 @@ public int getID() { } /** - * The execution mode to use. Needs to be kept in sync with the ExecutionMode - * enum in the C API. + * The execution mode to use. Needs to be kept in sync with the ExecutionMode enum in the C API. */ public enum ExecutionMode { /** * Executes all nodes sequentially. * - *

- * This is the default, and usually provides the most speedup as intra-op - * parallelism provides the most benefit. + *

This is the default, and usually provides the most speedup as intra-op parallelism + * provides the most benefit. */ SEQUENTIAL(0), /** Executes some nodes in parallel. */ @@ -707,10 +743,7 @@ public void close() { } } - /** - * Checks if the SessionOptions is closed, if so throws - * {@link IllegalStateException}. - */ + /** Checks if the SessionOptions is closed, if so throws {@link IllegalStateException}. */ private void checkClosed() { if (closed) { throw new IllegalStateException("Trying to use a closed SessionOptions"); @@ -738,8 +771,7 @@ public void setExecutionMode(ExecutionMode mode) throws OrtException { } /** - * Sets the optimization level of this options object, overriding the old - * setting. + * Sets the optimization level of this options object, overriding the old setting. * * @param level The optimization level to use. * @throws OrtException If there was an error in native code. @@ -750,8 +782,8 @@ public void setOptimizationLevel(OptLevel level) throws OrtException { } /** - * Sets the size of the CPU thread pool used for executing multiple request - * concurrently, if executing on a CPU. + * Sets the size of the CPU thread pool used for executing multiple request concurrently, if + * executing on a CPU. * * @param numThreads The number of threads to use. * @throws OrtException If there was an error in native code. @@ -762,8 +794,8 @@ public void setInterOpNumThreads(int numThreads) throws OrtException { } /** - * Sets the size of the CPU thread pool used for executing a single graph, if - * executing on a CPU. + * Sets the size of the CPU thread pool used for executing a single graph, if executing on a + * CPU. * * @param numThreads The number of threads to use. * @throws OrtException If there was an error in native code. @@ -817,22 +849,22 @@ public void disableProfiling() throws OrtException { } /** - * Turns on memory pattern optimizations, where memory is preallocated if all - * shapes are known. + * Turns on memory pattern optimizations, where memory is preallocated if all shapes are known. * * @param memoryPatternOptimization If true enable memory pattern optimizations. * @throws OrtException If there was an error in native code. */ - public void setMemoryPatternOptimization(boolean memoryPatternOptimization) throws OrtException { + public void setMemoryPatternOptimization(boolean memoryPatternOptimization) + throws OrtException { checkClosed(); - setMemoryPatternOptimization(OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization); + setMemoryPatternOptimization( + OnnxRuntime.ortApiHandle, nativeHandle, memoryPatternOptimization); } /** * Sets the CPU to use an arena memory allocator. * - * @param useArena If true use an arena memory allocator for the CPU execution - * provider. + * @param useArena If true use an arena memory allocator for the CPU execution provider. * @throws OrtException If there was an error in native code. */ public void setCPUArenaAllocator(boolean useArena) throws OrtException { @@ -863,8 +895,7 @@ public void setSessionLogVerbosityLevel(int logLevel) throws OrtException { } /** - * Registers a library of custom ops for use with {@link OrtSession}s using this - * SessionOptions. + * Registers a library of custom ops for use with {@link OrtSession}s using this SessionOptions. * * @param path The path to the library on disk. * @throws OrtException If there was an error loading the library. @@ -877,26 +908,21 @@ public void registerCustomOpLibrary(String path) throws OrtException { } /** - * Registers custom ops for use with {@link OrtSession}s using this - * SessionOptions by calling the specified native function name. The custom ops - * library must either be linked against, or have previously been loaded by the - * user. + * Registers custom ops for use with {@link OrtSession}s using this SessionOptions by calling + * the specified native function name. The custom ops library must either be linked against, or + * have previously been loaded by the user. * - *

- * The registration function must have the signature: + *

The registration function must have the signature: * - *

- *  OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); + *

 OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); * - *

- * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for - * more information on custom ops. See + *

See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more + * information on custom ops. See * https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 * for an example of a custom op library registration function. * * @param registrationFuncName The name of the registration function to call. - * @throws OrtException If there was an error finding or calling the - * registration function. + * @throws OrtException If there was an error finding or calling the registration function. */ public void registerCustomOpsUsingFunction(String registrationFuncName) throws OrtException { checkClosed(); @@ -904,25 +930,25 @@ public void registerCustomOpsUsingFunction(String registrationFuncName) throws O } /** - * Sets the value of a symbolic dimension. Fixed dimension computations may have - * more optimizations applied to them. + * Sets the value of a symbolic dimension. Fixed dimension computations may have more + * optimizations applied to them. * - * @param dimensionName The name of the symbolic dimension. + * @param dimensionName The name of the symbolic dimension. * @param dimensionValue The value to set that dimension to. * @throws OrtException If there was an error in native code. */ - public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) throws OrtException { + public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) + throws OrtException { checkClosed(); - addFreeDimensionOverrideByName(OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); + addFreeDimensionOverrideByName( + OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } /** * Set whether to use deterministic compute. * - *

- * Default is false. If set to true, this will enable deterministic compute for - * GPU kernels where possible. Note that this most likely will have a - * performance cost. + *

Default is false. If set to true, this will enable deterministic compute for GPU kernels + * where possible. Note that this most likely will have a performance cost. * * @param value Should the compute be deterministic? * @throws OrtException If there was an error in native code. @@ -933,8 +959,8 @@ public void setDeterministicCompute(boolean value) throws OrtException { } /** - * Disables the per session thread pools. Must be used in conjunction with an - * environment containing global thread pools. + * Disables the per session thread pools. Must be used in conjunction with an environment + * containing global thread pools. * * @throws OrtException If there was an error in native code. */ @@ -946,7 +972,7 @@ public void disablePerSessionThreads() throws OrtException { /** * Adds a single session configuration entry as a pair of strings. * - * @param configKey The config key string. + * @param configKey The config key string. * @param configValue The config value string. * @throws OrtException If there was an error in native code. */ @@ -957,8 +983,7 @@ public void addConfigEntry(String configKey, String configValue) throws OrtExcep } /** - * Returns an unmodifiable view of the map contains all session configuration - * entries. + * Returns an unmodifiable view of the map contains all session configuration entries. * * @return All session configuration entries */ @@ -970,18 +995,17 @@ public Map getConfigEntries() { /** * Adds in the supplied externally loaded initializers. * - *

- * Note the initializers are copied into the session once it has been created, - * and the native references are removed from this {@code SessionOptions}. Once - * the session has been created those initializers can be closed. This is a - * different lifetime to initializers added via - * {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be - * created from {@link java.nio.Buffer} objects. + *

Note the initializers are copied into the session once it has been created, and the native + * references are removed from this {@code SessionOptions}. Once the session has been created + * those initializers can be closed. This is a different lifetime to initializers added via + * {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be created from {@link + * java.nio.Buffer} objects. * * @param initializers The map of names to initializers. * @throws OrtException If the initializers could not be loaded. */ - public void addExternalInitializers(Map initializers) throws OrtException { + public void addExternalInitializers(Map initializers) + throws OrtException { checkClosed(); if (initializers.isEmpty()) { return; @@ -1000,16 +1024,13 @@ public void addExternalInitializers(Map initializers) th /** * Adds an initializer to override one from the ONNX model. * - *

- * Note the initializer lifetime must outlive the session and session options. - * This is a different lifetime to initializers added via - * {@link #addExternalInitializers(Map)}. The initializers must be created from - * {@link java.nio.Buffer} objects. + *

Note the initializer lifetime must outlive the session and session options. This is a + * different lifetime to initializers added via {@link #addExternalInitializers(Map)}. The + * initializers must be created from {@link java.nio.Buffer} objects. * - * @param name The initializer name. + * @param name The initializer name. * @param initializer The initializer value. - * @throws OrtException If the initializer could not be loaded into the session - * options. + * @throws OrtException If the initializer could not be loaded into the session options. */ public void addInitializer(String name, OnnxTensorLike initializer) throws OrtException { checkClosed(); @@ -1039,7 +1060,8 @@ public void addCUDA(int deviceNum) throws OrtException { if (OnnxRuntime.extractCUDA()) { addCUDA(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); } } @@ -1056,7 +1078,8 @@ public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException { ((OrtProviderOptions) cudaOpts).applyToNative(); addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find CUDA shared provider"); } } @@ -1080,16 +1103,16 @@ public void addROCM(int deviceNum) throws OrtException { if (OnnxRuntime.extractROCM()) { addROCM(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find ROCM shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find ROCM shared provider"); } } /** * Adds the CPU as an execution backend, using the arena allocator if desired. * - *

- * By default this backend is used, but if other backends are requested, it - * should be requested last. + *

By default this backend is used, but if other backends are requested, it should be + * requested last. * * @param useArena If true use the arena memory allocator. * @throws OrtException If there was an error in native code. @@ -1110,7 +1133,8 @@ public void addDnnl(boolean useArena) throws OrtException { if (OnnxRuntime.extractDNNL()) { addDnnl(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find DNNL shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find DNNL shared provider"); } } @@ -1125,7 +1149,8 @@ public void addOpenVINO(String deviceId) throws OrtException { if (OnnxRuntime.extractOpenVINO()) { addOpenVINO(OnnxRuntime.ortApiHandle, nativeHandle, deviceId); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find OpenVINO shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find OpenVINO shared provider"); } } @@ -1140,7 +1165,8 @@ public void addTensorrt(int deviceNum) throws OrtException { if (OnnxRuntime.extractTensorRT()) { addTensorrt(OnnxRuntime.ortApiHandle, nativeHandle, deviceNum); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); } } @@ -1157,7 +1183,8 @@ public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtExcep ((OrtProviderOptions) tensorRTOpts).applyToNative(); addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle); } else { - throw new OrtException(OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); + throw new OrtException( + OrtException.OrtErrorCode.ORT_EP_FAIL, "Failed to find TensorRT shared provider"); } } @@ -1246,12 +1273,11 @@ public void addCoreML(EnumSet flags) throws OrtException { } /** - * Adds Xnnpack as an execution backend. Needs to list all options hereif a new - * option supported. current supported options: {} The maximum number of - * provider options is set to 128 (see addExecutionProvider's comment). This - * number is controlled by ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in - * ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is not enough, please - * increase it or implementing an incremental way to add more options. + * Adds Xnnpack as an execution backend. Needs to list all options hereif a new option + * supported. current supported options: {} The maximum number of provider options is set to 128 + * (see addExecutionProvider's comment). This number is controlled by + * ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is + * not enough, please increase it or implementing an incremental way to add more options. * * @param providerOptions options pass to XNNPACK EP for initialization. * @throws OrtException If there was an error in native code. @@ -1266,62 +1292,76 @@ public void addXnnpack(Map providerOptions) throws OrtException providerOptionVal[i] = entry.getValue(); i++; } - addExecutionProvider(OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal); + addExecutionProvider( + OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal); } - private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) throws OrtException; - - private native void setOptimizationLevel(long apiHandle, long nativeHandle, int level) throws OrtException; + private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) + throws OrtException; - private native void setInterOpNumThreads(long apiHandle, long nativeHandle, int numThreads) throws OrtException; + private native void setOptimizationLevel(long apiHandle, long nativeHandle, int level) + throws OrtException; - private native void setIntraOpNumThreads(long apiHandle, long nativeHandle, int numThreads) throws OrtException; + private native void setInterOpNumThreads(long apiHandle, long nativeHandle, int numThreads) + throws OrtException; - private native void setOptimizationModelFilePath(long apiHandle, long nativeHandle, String modelPath) + private native void setIntraOpNumThreads(long apiHandle, long nativeHandle, int numThreads) throws OrtException; + private native void setOptimizationModelFilePath( + long apiHandle, long nativeHandle, String modelPath) throws OrtException; + private native long createOptions(long apiHandle); - private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId) throws OrtException; + private native void setLoggerId(long apiHandle, long nativeHandle, String loggerId) + throws OrtException; - private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix) throws OrtException; + private native void enableProfiling(long apiHandle, long nativeHandle, String filePrefix) + throws OrtException; private native void disableProfiling(long apiHandle, long nativeHandle) throws OrtException; - private native void setMemoryPatternOptimization(long apiHandle, long nativeHandle, - boolean memoryPatternOptimization) throws OrtException; + private native void setMemoryPatternOptimization( + long apiHandle, long nativeHandle, boolean memoryPatternOptimization) throws OrtException; - private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena) throws OrtException; + private native void setCPUArenaAllocator(long apiHandle, long nativeHandle, boolean useArena) + throws OrtException; - private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; + private native void setSessionLogLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; private native void setSessionLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; - private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path) throws OrtException; - - private native void registerCustomOpsUsingFunction(long apiHandle, long nativeHandle, String registrationFuncName) + private native long registerCustomOpLibrary(long apiHandle, long nativeHandle, String path) throws OrtException; + private native void registerCustomOpsUsingFunction( + long apiHandle, long nativeHandle, String registrationFuncName) throws OrtException; + private native void closeCustomLibraries(long[] nativeHandle); private native void closeOptions(long apiHandle, long nativeHandle); - private native void setDeterministicCompute(long apiHandle, long nativeHandle, boolean isDeterministic) - throws OrtException; - - private native void addFreeDimensionOverrideByName(long apiHandle, long nativeHandle, String dimensionName, - long dimensionValue) throws OrtException; + private native void setDeterministicCompute( + long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; - private native void addExternalInitializers(long apiHandle, long nativeHandle, String[] names, long[] tensorHandles) + private native void addFreeDimensionOverrideByName( + long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) throws OrtException; - private native void addInitializer(long apiHandle, long nativeHandle, String name, long tensorHandle) + private native void addExternalInitializers( + long apiHandle, long nativeHandle, String[] names, long[] tensorHandles) throws OrtException; - private native void disablePerSessionThreads(long apiHandle, long nativeHandle) throws OrtException; + private native void addInitializer( + long apiHandle, long nativeHandle, String name, long tensorHandle) throws OrtException; + + private native void disablePerSessionThreads(long apiHandle, long nativeHandle) + throws OrtException; - private native void addConfigEntry(long apiHandle, long nativeHandle, String configKey, String configValue) + private native void addConfigEntry( + long apiHandle, long nativeHandle, String configKey, String configValue) throws OrtException; /* @@ -1340,44 +1380,60 @@ private native void addConfigEntry(long apiHandle, long nativeHandle, String con */ private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException; - private native void addCUDA(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; + private native void addCUDA(long apiHandle, long nativeHandle, int deviceNum) + throws OrtException; - private native void addCUDAV2(long apiHandle, long nativeHandle, long cudaOptsHandle) throws OrtException; + private native void addCUDAV2(long apiHandle, long nativeHandle, long cudaOptsHandle) + throws OrtException; - private native void addROCM(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; + private native void addROCM(long apiHandle, long nativeHandle, int deviceNum) + throws OrtException; - private native void addDnnl(long apiHandle, long nativeHandle, int useArena) throws OrtException; + private native void addDnnl(long apiHandle, long nativeHandle, int useArena) + throws OrtException; - private native void addOpenVINO(long apiHandle, long nativeHandle, String deviceId) throws OrtException; + private native void addOpenVINO(long apiHandle, long nativeHandle, String deviceId) + throws OrtException; - private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; + private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum) + throws OrtException; - private native void addTensorrtV2(long apiHandle, long nativeHandle, long tensorrtOptsHandle) throws OrtException; + private native void addTensorrtV2(long apiHandle, long nativeHandle, long tensorrtOptsHandle) + throws OrtException; - private native void addNnapi(long apiHandle, long nativeHandle, int nnapiFlags) throws OrtException; + private native void addNnapi(long apiHandle, long nativeHandle, int nnapiFlags) + throws OrtException; - private native void addTvm(long apiHandle, long nativeHandle, String settings) throws OrtException; + private native void addTvm(long apiHandle, long nativeHandle, String settings) + throws OrtException; - private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) throws OrtException; + private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) + throws OrtException; - private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) throws OrtException; + private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) + throws OrtException; - private native void addArmNN(long apiHandle, long nativeHandle, int useArena) throws OrtException; + private native void addArmNN(long apiHandle, long nativeHandle, int useArena) + throws OrtException; - private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags) throws OrtException; + private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags) + throws OrtException; /* * The max length of providerOptionKey and providerOptionVal is 128, as * specified by ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 * for its location). */ - private native void addExecutionProvider(long apiHandle, long nativeHandle, String epName, - String[] providerOptionKey, String[] providerOptionVal) throws OrtException; + private native void addExecutionProvider( + long apiHandle, + long nativeHandle, + String epName, + String[] providerOptionKey, + String[] providerOptionVal) + throws OrtException; } - /** - * Used to control logging and termination of a call to {@link OrtSession#run}. - */ + /** Used to control logging and termination of a call to {@link OrtSession#run}. */ public static class RunOptions implements AutoCloseable { static { @@ -1477,10 +1533,8 @@ public String getRunTag() throws OrtException { } /** - * Sets a flag so that all incomplete {@link OrtSession#run} calls using this - * instance of {@code - * RunOptions} will terminate as soon as possible. If the flag is false, it - * resets this {@code + * Sets a flag so that all incomplete {@link OrtSession#run} calls using this instance of {@code + * RunOptions} will terminate as soon as possible. If the flag is false, it resets this {@code * RunOptions} so it can be used with other calls to {@link OrtSession#run}. * * @param terminate If true terminate all runs associated with this RunOptions. @@ -1494,10 +1548,9 @@ public void setTerminate(boolean terminate) throws OrtException { /** * Adds a configuration entry to this {@code RunOptions}. * - *

- * Setting the same key will overwrite the value. + *

Setting the same key will overwrite the value. * - * @param key The configuration key. + * @param key The configuration key. * @param value The configuration value. * @throws OrtException If the native library call failed. */ @@ -1508,21 +1561,19 @@ public void addRunConfigEntry(String key, String value) throws OrtException { /** * Adds the specified adapter to the list of active adapters - * + * *

- * + * * @param loraAdapter valid LoraAdapter object * @throws OrtException of the native library call failed */ - public void addActiveLoraAdapter(LoraAdapter loraAdapter) throws OrtException { + public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException { checkClosed(); + loraAdapter.checkClosed(); addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, loraAdapter.getNativeHandle()); } - /** - * Checks if the RunOptions is closed, if so throws - * {@link IllegalStateException}. - */ + /** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */ private void checkClosed() { if (closed) { throw new IllegalStateException("Trying to use a closed RunOptions"); @@ -1541,106 +1592,43 @@ public void close() { private static native long createRunOptions(long apiHandle) throws OrtException; - private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; + private native void setLogLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; private native int getLogLevel(long apiHandle, long nativeHandle) throws OrtException; - private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) throws OrtException; + private native void setLogVerbosityLevel(long apiHandle, long nativeHandle, int logLevel) + throws OrtException; private native int getLogVerbosityLevel(long apiHandle, long nativeHandle) throws OrtException; - private native void setRunTag(long apiHandle, long nativeHandle, String runTag) throws OrtException; + private native void setRunTag(long apiHandle, long nativeHandle, String runTag) + throws OrtException; private native String getRunTag(long apiHandle, long nativeHandle) throws OrtException; - private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate) throws OrtException; - - private native void addRunConfigEntry(long apiHandle, long nativeHandle, String key, String value) + private native void setTerminate(long apiHandle, long nativeHandle, boolean terminate) throws OrtException; - private native void addActiveLoraAdapter(long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException; + private native void addRunConfigEntry( + long apiHandle, long nativeHandle, String key, String value) throws OrtException; - private static native void close(long apiHandle, long nativeHandle); - } - - public static class LoraAdapter implements AutoCloseable { - static { - try { - OnnxRuntime.init(); - } catch (IOException e) { - throw new RuntimeException("Failed to load onnx-runtime library", e); - } - } - - private final long nativeHandle; - - private boolean closed = false; - - private LoraAdapter(long nativeHandle) { - this.nativeHandle = nativeHandle; - } - - /** - * Creates an instance of LoraAdapter. - * - * @param absoluteAdapterPath path to the adapter file that is going to be - * memory mapped - * @param allocator optional allocator or null. If supplied, adapter parameters - * are copied to the allocator memory - * @throws OrtException - */ - public static LoraAdapter Create(String absoluteAdapterPath, OrtAllocator allocator) throws OrtException { - return new LoraAdapter(createLoraAdapter(OnnxRuntime.ortApiHandle, absoluteAdapterPath, allocator)); - } - - /** - * Package accessor for native pointer. - * - * @return The native pointer. - */ - long getNativeHandle() { - return nativeHandle; - } - - /** - * Checks if the LoraAdapter is closed, if so throws - * {@link IllegalStateException}. - */ - private void checkClosed() { - if (closed) { - throw new IllegalStateException("Trying to use a closed LoraAdapter"); - } - } - - @Override - public void close() { - if (!closed) { - close(OnnxRuntime.ortApiHandle, nativeHandle); - closed = true; - } else { - throw new IllegalStateException("Trying to close an already closed LoraAdapter"); - } - } - - private static native long createLoraAdapter(long apiHandle, String adapterPath, - OrtAllocator allocator) throws OrtException; + private native void addActiveLoraAdapter( + long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException; private static native void close(long apiHandle, long nativeHandle); } /** - * An {@link AutoCloseable} wrapper around a {@link Map} containing - * {@link OnnxValue}s. + * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. * - *

- * When this is closed it closes all the {@link OnnxValue}s owned by the result - * object. If you maintain a reference to a value after this object has been - * closed it will throw an {@link IllegalStateException} upon access. + *

When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you + * maintain a reference to a value after this object has been closed it will throw an {@link + * IllegalStateException} upon access. * - *

- * {@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call - * are not closed by the {@link Result#close()} method. Ownership of each output - * can be checked with {@link Result#isResultOwner(int)}. + *

{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed + * by the {@link Result#close()} method. Ownership of each output can be checked with {@link + * Result#isResultOwner(int)}. */ public static class Result implements AutoCloseable, Iterable> { @@ -1655,17 +1643,20 @@ public static class Result implements AutoCloseable, IterableGetStringChars(jniEnv, loraPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return 0; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL); + checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath); +#endif + + return (jlong) lora; +} + +/* + * Class: ai_onnxruntime_OrtLoraAdapter + * Method: close + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close + (JNIEnv * env, jclass clazz, jlong apiHandle, jlong loraHandle) { + (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + api->ReleaseLoraAdapter((OrtLoraAdapter*) handle); +} + diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c index 0ab9bbb889209..3cbe2643716ad 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c @@ -124,6 +124,18 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addRunConf (*jniEnv)->ReleaseStringUTFChars(jniEnv, valueStr, value); } +/* + * Class: ai_onnxruntime_OrtSession_RunOptions + * Method: addActiveLoraAdapter + * Signature: (JJJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addActiveLoraAdapter + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong loraHandle) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + checkOrtStatus(jniEnv, api, api->RunOptionsAddActiveLoraAdapter((OrtRunOptions*) nativeHandle, (OrtLoraAdapter*) loraHandle)); +} + /* * Class: ai_onnxruntime_OrtSession_RunOptions * Method: setTerminate diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 0eeddd581c339..045d4d5fd9c71 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1299,51 +1299,54 @@ public void testRunWithLoraAdapter() throws OrtException { // XXX: Not sure how exactly to get paths to native testdata // it seems that it is included into resources String modelPath = TestHelpers.getResourcePath("lora/two_params_lora_model.onnx").toString(); - String adapterPath = TestHelpers.getResourcePath("lora/two_params_lora_model.onnx_adapter").toString(); - - var inputShape = new long[] {4, 4}; - var inputData = new float[16]; - // Fixme - Array.fill(inputData, 1.f); - - var expectedOutput = new float[] { - 154.f, 176.f, 198.f, 220.f, - 154.f, 176.f, 198.f, 220.f, - 154.f, 176.f, 198.f, 220.f, - 154.f, 176.f, 198.f, 220.f}; - - try(var session = new env.createSession(modelPath); - var adapter = new OrtSession.LoraAdapter.Create(adapterPath, null); - var runOptions = new OrtSession.RunOptions()) { - - runOptions.addActiveLoraAdapter(adapter); - session.Run(); - } - } - - @Test - public void testRunWithBaseLoraModel() throws OrtException { - // XXX: Not sure how exactly to get paths to native testdata - // it seems that it is included into resources - String modelPath = TestHelpers.getResourcePath("lora/two_params_lora_model.onnx").toString(); - - var inputShape = new long[] {4, 4}; - var inputData = new float[16]; - // Fixme - Array.fill(inputData, 1.f); - - var expectedOutput = new float[] { - 28.f, 32.f, 36.f, 40.f, - 28.f, 32.f, 36.f, 40.f, - 28.f, 32.f, 36.f, 40.f, - 28.f, 32.f, 36.f, 40.f}; - - // See C# tests - try(var session = new env.createSession(modelPath); - var adapter = new OrtSession.LoraAdapter.Create(adapterPath, null); - { - session.Run(); + String adapterPath = + TestHelpers.getResourcePath("lora/two_params_lora_model.onnx_adapter").toString(); + + long[] inputShape = new long[] {4, 4}; + float[] inputData = new float[16]; + Arrays.fill(inputData, 1.f); + FloatBuffer buf = + ByteBuffer.allocateDirect(Float.BYTES).order(ByteOrder.nativeOrder()).asFloatBuffer(); + buf.put(inputData); + buf.rewind(); + + float[][] expectedOutput = + new float[][] { + {28.f, 32.f, 36.f, 40.f}, + {28.f, 32.f, 36.f, 40.f}, + {28.f, 32.f, 36.f, 40.f}, + {28.f, 32.f, 36.f, 40.f} + }; + + float[][] expectedLoRAOutput = + new float[][] { + {154.f, 176.f, 198.f, 220.f}, + {154.f, 176.f, 198.f, 220.f}, + {154.f, 176.f, 198.f, 220.f}, + {154.f, 176.f, 198.f, 220.f} + }; + + try (OrtSession session = env.createSession(modelPath); + OnnxTensor tensor = OnnxTensor.createTensor(env, buf, inputShape)) { + + Map inputs = Collections.singletonMap("input", tensor); + + // Without LoRA + try (OrtSession.Result result = session.run(inputs)) { + float[][] resultArr = (float[][]) result.get(0).getValue(); + Assertions.assertArrayEquals(expectedOutput, resultArr); + } + + // With LoRA + try (OrtLoraAdapter adapter = OrtLoraAdapter.create(adapterPath); + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + runOptions.addActiveLoraAdapter(adapter); + try (OrtSession.Result result = session.run(inputs, runOptions)) { + float[][] resultArr = (float[][]) result.get(0).getValue(); + Assertions.assertArrayEquals(expectedLoRAOutput, resultArr); + } } + } } @Test From 42e13fd86b600f99aa8b4d0ea673426096132ee2 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 27 Sep 2024 10:56:40 -0400 Subject: [PATCH 04/11] Fixing copyright. --- java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java index 10b44cfb7def2..8b9c253d8b6f2 100644 --- a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java +++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java @@ -1,5 +1,6 @@ /* - * Copyright © 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; From 6a9630f875e17c706d6daf2e31cb566d3d0e47e5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 27 Sep 2024 12:39:31 -0400 Subject: [PATCH 05/11] Fixing compile errors in OrtLoraAdapter.c --- java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c b/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c index 14cb91bdf87b5..d22e7b502e990 100644 --- a/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c +++ b/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c @@ -14,7 +14,7 @@ * Signature: (JLjava/lang/String;J)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter - (JNIEnv * env, jclass clazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) { + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) { (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; @@ -48,9 +48,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter * Signature: (JJ)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close - (JNIEnv * env, jclass clazz, jlong apiHandle, jlong loraHandle) { - (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; - api->ReleaseLoraAdapter((OrtLoraAdapter*) handle); + api->ReleaseLoraAdapter((OrtLoraAdapter*) loraHandle); } From a00fcb19b2e1f092e20aab3f475c56832f23b8d5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 27 Sep 2024 12:45:26 -0400 Subject: [PATCH 06/11] Fixing test path. --- java/src/test/java/ai/onnxruntime/InferenceTest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 045d4d5fd9c71..64b7e11ec332f 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1298,15 +1298,15 @@ public void testRunOptions() throws OrtException { public void testRunWithLoraAdapter() throws OrtException { // XXX: Not sure how exactly to get paths to native testdata // it seems that it is included into resources - String modelPath = TestHelpers.getResourcePath("lora/two_params_lora_model.onnx").toString(); + String modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx").toString(); String adapterPath = - TestHelpers.getResourcePath("lora/two_params_lora_model.onnx_adapter").toString(); + TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter").toString(); long[] inputShape = new long[] {4, 4}; float[] inputData = new float[16]; Arrays.fill(inputData, 1.f); FloatBuffer buf = - ByteBuffer.allocateDirect(Float.BYTES).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ByteBuffer.allocateDirect(Float.BYTES*16).order(ByteOrder.nativeOrder()).asFloatBuffer(); buf.put(inputData); buf.rewind(); @@ -1329,7 +1329,7 @@ public void testRunWithLoraAdapter() throws OrtException { try (OrtSession session = env.createSession(modelPath); OnnxTensor tensor = OnnxTensor.createTensor(env, buf, inputShape)) { - Map inputs = Collections.singletonMap("input", tensor); + Map inputs = Collections.singletonMap("input_x", tensor); // Without LoRA try (OrtSession.Result result = session.run(inputs)) { From dbf2e0ace05afbc1530179ca15317a23477da903 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 27 Sep 2024 15:44:55 -0400 Subject: [PATCH 07/11] Adding support for loading a lora from a byte array. --- .../java/ai/onnxruntime/OrtLoraAdapter.java | 71 ++++++++++++++++++- .../native/ai_onnxruntime_OrtLoraAdapter.c | 50 +++++++++++++ .../java/ai/onnxruntime/InferenceTest.java | 43 ++++++++--- 3 files changed, 151 insertions(+), 13 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java index 8b9c253d8b6f2..da62018319ec8 100644 --- a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java +++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java @@ -6,6 +6,8 @@ package ai.onnxruntime; import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; /** * A container for an adapter which can be supplied to {@link @@ -29,6 +31,63 @@ private OrtLoraAdapter(long nativeHandle) { this.nativeHandle = nativeHandle; } + /** + * Creates an instance of OrtLoraAdapter from a byte array. + * + * @param loraArray The LoRA stored in a byte array. + * @throws OrtException If the native call failed. + */ + public static OrtLoraAdapter create(byte[] loraArray) throws OrtException { + return create(loraArray, null); + } + + /** + * Creates an instance of OrtLoraAdapter from a byte array. + * + * @param loraArray The LoRA stored in a byte array. + * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the + * allocator memory. + * @throws OrtException If the native call failed. + */ + static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) + throws OrtException { + Objects.requireNonNull(loraArray, "LoRA array must not be null"); + long allocatorHandle = allocator == null ? 0 : allocator.handle; + return new OrtLoraAdapter( + createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle)); + } + + /** + * Creates an instance of OrtLoraAdapter from a direct ByteBuffer. + * + * @param loraBuffer The buffer to load. + * @throws OrtException If the native call failed. + */ + public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException { + return create(loraBuffer, null); + } + + /** + * Creates an instance of OrtLoraAdapter from a direct ByteBuffer. + * + * @param loraBuffer The buffer to load. + * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the + * allocator memory. + * @throws OrtException If the native call failed. + */ + static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) + throws OrtException { + Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null"); + if (loraBuffer.remaining() == 0) { + throw new OrtException("Invalid LoRA buffer, no elements remaining."); + } else if (!loraBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + long allocatorHandle = allocator == null ? 0 : allocator.handle; + return new OrtLoraAdapter( + createLoraAdapterFromBuffer(OnnxRuntime.ortApiHandle, loraBuffer, loraBuffer.position(), loraBuffer.remaining(), allocatorHandle)); + } + /** * Creates an instance of OrtLoraAdapter. * @@ -63,10 +122,10 @@ long getNativeHandle() { return nativeHandle; } - /** Checks if the LoraAdapter is closed, if so throws {@link IllegalStateException}. */ + /** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */ void checkClosed() { if (closed) { - throw new IllegalStateException("Trying to use a closed LoraAdapter"); + throw new IllegalStateException("Trying to use a closed OrtLoraAdapter"); } } @@ -76,12 +135,18 @@ public void close() { close(OnnxRuntime.ortApiHandle, nativeHandle); closed = true; } else { - throw new IllegalStateException("Trying to close an already closed LoraAdapter"); + throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter"); } } private static native long createLoraAdapter( long apiHandle, String adapterPath, long allocatorHandle) throws OrtException; + private static native long createLoraAdapterFromArray( + long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException; + + private static native long createLoraAdapterFromBuffer( + long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle) throws OrtException; + private static native void close(long apiHandle, long nativeHandle); } diff --git a/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c b/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c index d22e7b502e990..8b1ee82614b15 100644 --- a/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c +++ b/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c @@ -42,6 +42,56 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter return (jlong) lora; } +/* + * Class: ai_onnxruntime_OrtLoraAdapter + * Method: createLoraAdapterFromBuffer + * Signature: (JLjava/nio/ByteBuffer;IIJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer + (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + OrtLoraAdapter* lora; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + + // Create the adapter + checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora)); + + return (jlong) lora; +} + +/* + * Class: ai_onnxruntime_OrtLoraAdapter + * Method: createLoraAdapterFromArray + * Signature: (J[BJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray + (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + OrtLoraAdapter* lora; + + size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray); + if (loraLength == 0) { + throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length."); + return 0; + } + + // Get a reference to the byte array elements + jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL); + checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora)); + // Release the C array. + (*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT); + + return (jlong) lora; +} + /* * Class: ai_onnxruntime_OrtLoraAdapter * Method: close diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 64b7e11ec332f..bf501dfaea0cb 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1295,12 +1295,10 @@ public void testRunOptions() throws OrtException { } @Test - public void testRunWithLoraAdapter() throws OrtException { - // XXX: Not sure how exactly to get paths to native testdata - // it seems that it is included into resources - String modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx").toString(); - String adapterPath = - TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter").toString(); + public void testRunWithLoraAdapter() throws IOException, OrtException { + Path modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx"); + Path adapterPath = + TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter"); long[] inputShape = new long[] {4, 4}; float[] inputData = new float[16]; @@ -1326,7 +1324,7 @@ public void testRunWithLoraAdapter() throws OrtException { {154.f, 176.f, 198.f, 220.f} }; - try (OrtSession session = env.createSession(modelPath); + try (OrtSession session = env.createSession(modelPath.toString()); OnnxTensor tensor = OnnxTensor.createTensor(env, buf, inputShape)) { Map inputs = Collections.singletonMap("input_x", tensor); @@ -1337,15 +1335,40 @@ public void testRunWithLoraAdapter() throws OrtException { Assertions.assertArrayEquals(expectedOutput, resultArr); } - // With LoRA - try (OrtLoraAdapter adapter = OrtLoraAdapter.create(adapterPath); - OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + // With LoRA from path + try (OrtLoraAdapter adapter = OrtLoraAdapter.create(adapterPath.toString()); + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { runOptions.addActiveLoraAdapter(adapter); try (OrtSession.Result result = session.run(inputs, runOptions)) { float[][] resultArr = (float[][]) result.get(0).getValue(); Assertions.assertArrayEquals(expectedLoRAOutput, resultArr); } } + + // With LoRA from array + byte[] loraArray = Files.readAllBytes(adapterPath); + try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraArray); + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + runOptions.addActiveLoraAdapter(adapter); + try (OrtSession.Result result = session.run(inputs, runOptions)) { + float[][] resultArr = (float[][]) result.get(0).getValue(); + Assertions.assertArrayEquals(expectedLoRAOutput, resultArr); + } + } + + // With LoRA from buffer + ByteBuffer loraBuf = ByteBuffer.allocateDirect(loraArray.length); + loraBuf.put(loraArray); + loraBuf.rewind(); + try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraBuf); + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + runOptions.addActiveLoraAdapter(adapter); + try (OrtSession.Result result = session.run(inputs, runOptions)) { + float[][] resultArr = (float[][]) result.get(0).getValue(); + Assertions.assertArrayEquals(expectedLoRAOutput, resultArr); + } + } + } } From e329cee15c322a8361c655058b066a9b66027d88 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 27 Sep 2024 15:57:06 -0400 Subject: [PATCH 08/11] Tidying up javadoc and formatting --- .../java/ai/onnxruntime/OrtLoraAdapter.java | 22 ++++++++++++++----- .../main/java/ai/onnxruntime/OrtSession.java | 6 ++--- .../java/ai/onnxruntime/InferenceTest.java | 12 +++++----- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java index da62018319ec8..d9f2579e31d22 100644 --- a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java +++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java @@ -36,6 +36,7 @@ private OrtLoraAdapter(long nativeHandle) { * * @param loraArray The LoRA stored in a byte array. * @throws OrtException If the native call failed. + * @return An OrtLoraAdapter instance. */ public static OrtLoraAdapter create(byte[] loraArray) throws OrtException { return create(loraArray, null); @@ -48,9 +49,9 @@ public static OrtLoraAdapter create(byte[] loraArray) throws OrtException { * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the * allocator memory. * @throws OrtException If the native call failed. + * @return An OrtLoraAdapter instance. */ - static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) - throws OrtException { + static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException { Objects.requireNonNull(loraArray, "LoRA array must not be null"); long allocatorHandle = allocator == null ? 0 : allocator.handle; return new OrtLoraAdapter( @@ -62,6 +63,7 @@ static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) * * @param loraBuffer The buffer to load. * @throws OrtException If the native call failed. + * @return An OrtLoraAdapter instance. */ public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException { return create(loraBuffer, null); @@ -74,9 +76,9 @@ public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException { * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the * allocator memory. * @throws OrtException If the native call failed. + * @return An OrtLoraAdapter instance. */ - static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) - throws OrtException { + static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException { Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null"); if (loraBuffer.remaining() == 0) { throw new OrtException("Invalid LoRA buffer, no elements remaining."); @@ -85,7 +87,12 @@ static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) } long allocatorHandle = allocator == null ? 0 : allocator.handle; return new OrtLoraAdapter( - createLoraAdapterFromBuffer(OnnxRuntime.ortApiHandle, loraBuffer, loraBuffer.position(), loraBuffer.remaining(), allocatorHandle)); + createLoraAdapterFromBuffer( + OnnxRuntime.ortApiHandle, + loraBuffer, + loraBuffer.position(), + loraBuffer.remaining(), + allocatorHandle)); } /** @@ -93,6 +100,7 @@ static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) * * @param absoluteAdapterPath path to the adapter file that is going to be memory mapped. * @throws OrtException If the native call failed. + * @return An OrtLoraAdapter instance. */ public static OrtLoraAdapter create(String absoluteAdapterPath) throws OrtException { return create(absoluteAdapterPath, null); @@ -105,6 +113,7 @@ public static OrtLoraAdapter create(String absoluteAdapterPath) throws OrtExcept * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the * allocator memory. * @throws OrtException If the native call failed. + * @return An OrtLoraAdapter instance. */ static OrtLoraAdapter create(String absoluteAdapterPath, OrtAllocator allocator) throws OrtException { @@ -146,7 +155,8 @@ private static native long createLoraAdapterFromArray( long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException; private static native long createLoraAdapterFromBuffer( - long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle) throws OrtException; + long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle) + throws OrtException; private static native void close(long apiHandle, long nativeHandle); } diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 47ee62401fca4..02fb47a421ba7 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1560,11 +1560,9 @@ public void addRunConfigEntry(String key, String value) throws OrtException { } /** - * Adds the specified adapter to the list of active adapters + * Adds the specified adapter to the list of active adapters for this run. * - *

- * - * @param loraAdapter valid LoraAdapter object + * @param loraAdapter valid OrtLoraAdapter object * @throws OrtException of the native library call failed */ public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException { diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index bf501dfaea0cb..5999237bea6bc 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1297,14 +1297,13 @@ public void testRunOptions() throws OrtException { @Test public void testRunWithLoraAdapter() throws IOException, OrtException { Path modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx"); - Path adapterPath = - TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter"); + Path adapterPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter"); long[] inputShape = new long[] {4, 4}; float[] inputData = new float[16]; Arrays.fill(inputData, 1.f); FloatBuffer buf = - ByteBuffer.allocateDirect(Float.BYTES*16).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ByteBuffer.allocateDirect(Float.BYTES * 16).order(ByteOrder.nativeOrder()).asFloatBuffer(); buf.put(inputData); buf.rewind(); @@ -1337,7 +1336,7 @@ public void testRunWithLoraAdapter() throws IOException, OrtException { // With LoRA from path try (OrtLoraAdapter adapter = OrtLoraAdapter.create(adapterPath.toString()); - OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { runOptions.addActiveLoraAdapter(adapter); try (OrtSession.Result result = session.run(inputs, runOptions)) { float[][] resultArr = (float[][]) result.get(0).getValue(); @@ -1348,7 +1347,7 @@ public void testRunWithLoraAdapter() throws IOException, OrtException { // With LoRA from array byte[] loraArray = Files.readAllBytes(adapterPath); try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraArray); - OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { runOptions.addActiveLoraAdapter(adapter); try (OrtSession.Result result = session.run(inputs, runOptions)) { float[][] resultArr = (float[][]) result.get(0).getValue(); @@ -1361,14 +1360,13 @@ public void testRunWithLoraAdapter() throws IOException, OrtException { loraBuf.put(loraArray); loraBuf.rewind(); try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraBuf); - OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { + OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { runOptions.addActiveLoraAdapter(adapter); try (OrtSession.Result result = session.run(inputs, runOptions)) { float[][] resultArr = (float[][]) result.get(0).getValue(); Assertions.assertArrayEquals(expectedLoRAOutput, resultArr); } } - } } From 7d1a394f6b01d2885e3ac96ef29c2041ced7e190 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 30 Sep 2024 20:48:03 -0400 Subject: [PATCH 09/11] Fixing some comment reformatting. --- .../main/java/ai/onnxruntime/OrtSession.java | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 02fb47a421ba7..8380afed1cbd2 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -403,8 +403,7 @@ public Result run( if (outputNames.contains(s)) { if (!pinnedOutputs.containsKey(s)) { outputNamesArray[i] = s; - // outputValues and outputHandles can be null/0 for these outputs as ORT will - // allocate + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate // them. i++; } else { @@ -446,9 +445,9 @@ public Result run( */ static long getHandle(OnnxValue v) { /* - * Note this method exists as interface methods are all public, but we do not - * want users to be able to access the native pointer via a public API so can't - * add a method to OnnxValue which exposes it. + * Note this method exists as interface methods are all public, but we do not want users to be + * able to access the native pointer via a public API so can't add a method to OnnxValue which + * exposes it. */ if (v instanceof OnnxTensorLike) { return ((OnnxTensorLike) v).nativeHandle; @@ -1365,18 +1364,19 @@ private native void addConfigEntry( throws OrtException; /* - * To use additional providers, you must build ORT with the extra providers - * enabled. Then call one of these functions to enable them in the session: - * OrtSessionOptionsAppendExecutionProvider_CPU - * OrtSessionOptionsAppendExecutionProvider_CUDA - * OrtSessionOptionsAppendExecutionProvider_ROCM - * OrtSessionOptionsAppendExecutionProvider_ The order - * they care called indicates the preference order as well. In other words call - * this method on your most preferred execution provider first followed by the - * less preferred ones. If none are called Ort will use its internal CPU - * execution provider. - * - * If a backend is unavailable then it throws an OrtException + * To use additional providers, you must build ORT with the extra providers enabled. Then call + * one of these functions to enable them in the session: + * + * OrtSessionOptionsAppendExecutionProvider_CPU + * OrtSessionOptionsAppendExecutionProvider_CUDA + * OrtSessionOptionsAppendExecutionProvider_ROCM + * OrtSessionOptionsAppendExecutionProvider_ + * + * The order they are called indicates the preference order as well. In other words call this + * method on your most preferred execution provider first followed by the less preferred ones. + * If none are called ORT will use its internal CPU execution provider. + * + * If a backend is unavailable then it throws an OrtException. */ private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException; @@ -1420,9 +1420,8 @@ private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags throws OrtException; /* - * The max length of providerOptionKey and providerOptionVal is 128, as - * specified by ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 - * for its location). + * The max length of providerOptionKey and providerOptionVal is 128, as specified by + * ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 for its location). */ private native void addExecutionProvider( long apiHandle, From 635710d8c516525a1640f3ec9b061e78ea88bf1e Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 30 Sep 2024 21:53:42 -0400 Subject: [PATCH 10/11] Fix argument names. --- .../src/main/java/ai/onnxruntime/OrtLoraAdapter.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java index d9f2579e31d22..13f225dedc768 100644 --- a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java +++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java @@ -98,28 +98,28 @@ static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) thro /** * Creates an instance of OrtLoraAdapter. * - * @param absoluteAdapterPath path to the adapter file that is going to be memory mapped. + * @param adapterPath path to the adapter file that is going to be memory mapped. * @throws OrtException If the native call failed. * @return An OrtLoraAdapter instance. */ - public static OrtLoraAdapter create(String absoluteAdapterPath) throws OrtException { - return create(absoluteAdapterPath, null); + public static OrtLoraAdapter create(String adapterPath) throws OrtException { + return create(adapterPath, null); } /** * Creates an instance of OrtLoraAdapter. * - * @param absoluteAdapterPath path to the adapter file that is going to be memory mapped. + * @param adapterPath path to the adapter file that is going to be memory mapped. * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the * allocator memory. * @throws OrtException If the native call failed. * @return An OrtLoraAdapter instance. */ - static OrtLoraAdapter create(String absoluteAdapterPath, OrtAllocator allocator) + static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException { long allocatorHandle = allocator == null ? 0 : allocator.handle; return new OrtLoraAdapter( - createLoraAdapter(OnnxRuntime.ortApiHandle, absoluteAdapterPath, allocatorHandle)); + createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle)); } /** From 44dfb49680b2b384e59bce826a083b36dc5a3b7f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 30 Sep 2024 23:04:54 -0400 Subject: [PATCH 11/11] Fixing spotless error. --- java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java index 13f225dedc768..cf16e290e20c4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java +++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java @@ -115,8 +115,7 @@ public static OrtLoraAdapter create(String adapterPath) throws OrtException { * @throws OrtException If the native call failed. * @return An OrtLoraAdapter instance. */ - static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) - throws OrtException { + static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException { long allocatorHandle = allocator == null ? 0 : allocator.handle; return new OrtLoraAdapter( createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle));