diff --git a/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java
new file mode 100644
index 0000000000000..cf16e290e20c4
--- /dev/null
+++ b/java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ */
+package ai.onnxruntime;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Objects;
+
+/**
+ * A container for an adapter which can be supplied to {@link
+ * OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific
+ * execution of a model.
+ */
+public final class OrtLoraAdapter implements AutoCloseable {
+  static {
+    try {
+      OnnxRuntime.init();
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to load onnx-runtime library", e);
+    }
+  }
+
+  private final long nativeHandle;
+
+  private boolean closed = false;
+
+  private OrtLoraAdapter(long nativeHandle) {
+    this.nativeHandle = nativeHandle;
+  }
+
+  /**
+   * Creates an instance of OrtLoraAdapter from a byte array.
+   *
+   * @param loraArray The LoRA stored in a byte array.
+   * @throws OrtException If the native call failed.
+   * @return An OrtLoraAdapter instance.
+   */
+  public static OrtLoraAdapter create(byte[] loraArray) throws OrtException {
+    return create(loraArray, null);
+  }
+
+  /**
+   * Creates an instance of OrtLoraAdapter from a byte array.
+   *
+   * @param loraArray The LoRA stored in a byte array.
+   * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
+   *     allocator memory.
+   * @throws OrtException If the native call failed.
+   * @return An OrtLoraAdapter instance.
+   */
+  static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException {
+    Objects.requireNonNull(loraArray, "LoRA array must not be null");
+    long allocatorHandle = allocator == null ? 0 : allocator.handle;
+    return new OrtLoraAdapter(
+        createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle));
+  }
+
+  /**
+   * Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
+   *
+   * @param loraBuffer The direct buffer to load, read from its position to its limit.
+   * @throws OrtException If the native call failed.
+   * @return An OrtLoraAdapter instance.
+   */
+  public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException {
+    return create(loraBuffer, null);
+  }
+
+  /**
+   * Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
+   *
+   * @param loraBuffer The direct buffer to load, read from its position to its limit.
+   * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
+   *     allocator memory.
+   * @throws OrtException If the native call failed.
+   * @return An OrtLoraAdapter instance.
+   */
+  static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException {
+    Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null");
+    if (loraBuffer.remaining() == 0) {
+      throw new OrtException("Invalid LoRA buffer, no elements remaining.");
+    } else if (!loraBuffer.isDirect()) {
+      throw new OrtException("ByteBuffer is not direct.");
+    }
+    long allocatorHandle = allocator == null ? 0 : allocator.handle;
+    return new OrtLoraAdapter(
+        createLoraAdapterFromBuffer(
+            OnnxRuntime.ortApiHandle,
+            loraBuffer,
+            loraBuffer.position(),
+            loraBuffer.remaining(),
+            allocatorHandle));
+  }
+
+  /**
+   * Creates an instance of OrtLoraAdapter.
+   *
+   * @param adapterPath path to the adapter file that is going to be memory mapped.
+   * @throws OrtException If the native call failed.
+   * @return An OrtLoraAdapter instance.
+   */
+  public static OrtLoraAdapter create(String adapterPath) throws OrtException {
+    return create(adapterPath, null);
+  }
+
+  /**
+   * Creates an instance of OrtLoraAdapter.
+   *
+   * @param adapterPath path to the adapter file that is going to be memory mapped, not null.
+   * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
+   *     allocator memory.
+   * @throws OrtException If the native call failed.
+   * @return An OrtLoraAdapter instance.
+   */
+  static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException {
+    Objects.requireNonNull(adapterPath, "LoRA adapter path must not be null");
+    long allocatorHandle = allocator == null ? 0 : allocator.handle;
+    return new OrtLoraAdapter(
+        createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle));
+  }
+
+  /**
+   * Package accessor for native pointer.
+   *
+   * @return The native pointer.
+   */
+  long getNativeHandle() {
+    return nativeHandle;
+  }
+
+  /** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */
+  void checkClosed() {
+    if (closed) {
+      throw new IllegalStateException("Trying to use a closed OrtLoraAdapter");
+    }
+  }
+
+  /** Releases the native adapter; throws {@link IllegalStateException} if already closed. */
+  @Override
+  public void close() {
+    if (!closed) {
+      close(OnnxRuntime.ortApiHandle, nativeHandle);
+      closed = true;
+    } else {
+      throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter");
+    }
+  }
+
+  private static native long createLoraAdapter(
+      long apiHandle, String adapterPath, long allocatorHandle) throws OrtException;
+
+  private static native long createLoraAdapterFromArray(
+      long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException;
+
+  private static native long createLoraAdapterFromBuffer(
+      long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle)
+      throws OrtException;
+
+  private static native void close(long apiHandle, long nativeHandle);
+}
diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java
index
6d146d5857d3c..8380afed1cbd2 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -635,8 +635,8 @@ public static class SessionOptions implements AutoCloseable { * The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum * in the C API. * - *

See Graph + *

See Graph * Optimizations for more details. */ public enum OptLevel { @@ -684,6 +684,7 @@ public enum ExecutionMode { SEQUENTIAL(0), /** Executes some nodes in parallel. */ PARALLEL(1); + private final int id; ExecutionMode(int id) { @@ -1363,17 +1364,19 @@ private native void addConfigEntry( throws OrtException; /* - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session: + * To use additional providers, you must build ORT with the extra providers enabled. Then call + * one of these functions to enable them in the session: + * * OrtSessionOptionsAppendExecutionProvider_CPU * OrtSessionOptionsAppendExecutionProvider_CUDA * OrtSessionOptionsAppendExecutionProvider_ROCM * OrtSessionOptionsAppendExecutionProvider_ - * The order they care called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. * - * If a backend is unavailable then it throws an OrtException + * The order they are called indicates the preference order as well. In other words call this + * method on your most preferred execution provider first followed by the less preferred ones. + * If none are called ORT will use its internal CPU execution provider. + * + * If a backend is unavailable then it throws an OrtException. */ private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException; @@ -1555,6 +1558,18 @@ public void addRunConfigEntry(String key, String value) throws OrtException { addRunConfigEntry(OnnxRuntime.ortApiHandle, nativeHandle, key, value); } + /** + * Adds the specified adapter to the list of active adapters for this run. 
+     *
+     * @param loraAdapter the adapter to activate, which must not have been closed
+     * @throws OrtException if the native library call failed
+     */
+    public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException {
+      checkClosed();
+      loraAdapter.checkClosed();
+      addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, loraAdapter.getNativeHandle());
+    }
+
     /** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
     private void checkClosed() {
       if (closed) {
@@ -1595,6 +1610,9 @@ private native void setTerminate(long apiHandle, long nativeHandle, boolean term
     private native void addRunConfigEntry(
         long apiHandle, long nativeHandle, String key, String value) throws OrtException;
 
+    private native void addActiveLoraAdapter(
+        long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException;
+
     private static native void close(long apiHandle, long nativeHandle);
   }
 
diff --git a/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c b/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c
new file mode 100644
index 0000000000000..8b1ee82614b15
--- /dev/null
+++ b/java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * Licensed under the MIT License.
+ */
+#include <jni.h>
+#include <stdlib.h>
+#include <string.h>
+#include "onnxruntime/core/session/onnxruntime_c_api.h"
+#include "OrtJniUtil.h"
+#include "ai_onnxruntime_OrtLoraAdapter.h"
+
+/*
+ * Class:     ai_onnxruntime_OrtLoraAdapter
+ * Method:    createLoraAdapter
+ * Signature: (JLjava/lang/String;J)J
+ */
+JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter
+  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) {
+  (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
+  const OrtApi* api = (const OrtApi*) apiHandle;
+  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
+  // Start from NULL so a failed creation returns 0 rather than an indeterminate pointer.
+  OrtLoraAdapter* lora = NULL;
+
+#ifdef _WIN32
+  const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, loraPath, NULL);
+  if (cPath == NULL) {
+    // The JVM could not supply the characters; raise an error unless one is already pending.
+    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
+      throwOrtException(jniEnv, 1, "Not enough memory");
+    }
+    return 0;
+  }
+  size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath);
+  wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
+  if (newString == NULL) {
+    (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
+    throwOrtException(jniEnv, 1, "Not enough memory");
+    return 0;
+  }
+  wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
+  checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora));
+  free(newString);
+  (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
+#else
+  const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL);
+  if (cPath == NULL) {
+    // The JVM could not supply the characters; raise an error unless one is already pending.
+    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
+      throwOrtException(jniEnv, 1, "Not enough memory");
+    }
+    return 0;
+  }
+  checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora));
+  (*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath);
+#endif
+
+  return (jlong) lora;
+}
+
+/*
+ * Class:     ai_onnxruntime_OrtLoraAdapter
+ * Method:    createLoraAdapterFromBuffer
+ * Signature: (JLjava/nio/ByteBuffer;IIJ)J
+ */
+JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer
+  (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) {
+  (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
+  const OrtApi* api = (const OrtApi*) apiHandle;
+  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
+  // Start from NULL so a failed creation returns 0 rather than an indeterminate pointer.
+  OrtLoraAdapter* lora = NULL;
+
+  // Extract the buffer
+  char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
+  if (bufferArr == NULL) {
+    // A NULL address means the JVM cannot expose this buffer's memory, so do not offset it.
+    throwOrtException(jniEnv, 2, "Invalid LoRA buffer, could not get the direct buffer address.");
+    return 0;
+  }
+  // Increment by bufferPos bytes
+  bufferArr = bufferArr + bufferPos;
+
+  // Create the adapter
+  checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora));
+
+  return (jlong) lora;
+}
+
+/*
+ * Class:     ai_onnxruntime_OrtLoraAdapter
+ * Method:    createLoraAdapterFromArray
+ * Signature: (J[BJ)J
+ */
+JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray
+  (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) {
+  (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
+  const OrtApi* api = (const OrtApi*) apiHandle;
+  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
+  // Start from NULL so a failed creation returns 0 rather than an indeterminate pointer.
+  OrtLoraAdapter* lora = NULL;
+
+  size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray);
+  if (loraLength == 0) {
+    throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length.");
+    return 0;
+  }
+
+  // Get a reference to the byte array elements
+  jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL);
+  if (loraArr == NULL) {
+    // The JVM could not supply the elements; raise an error unless one is already pending.
+    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
+      throwOrtException(jniEnv, 1, "Not enough memory");
+    }
+    return 0;
+  }
+  checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora));
+  // Release the C array.
+  (*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT);
+
+  return (jlong) lora;
+}
+
+/*
+ * Class:     ai_onnxruntime_OrtLoraAdapter
+ * Method:    close
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close
+  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) {
+  (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object.
+  const OrtApi* api = (const OrtApi*) apiHandle;
+  api->ReleaseLoraAdapter((OrtLoraAdapter*) loraHandle);
+}
+
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
index 0ab9bbb889209..3cbe2643716ad 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
@@ -124,6 +124,18 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addRunConf
   (*jniEnv)->ReleaseStringUTFChars(jniEnv, valueStr, value);
 }
 
+/*
+ * Class:     ai_onnxruntime_OrtSession_RunOptions
+ * Method:    addActiveLoraAdapter
+ * Signature: (JJJ)V
+ */
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addActiveLoraAdapter
+  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong loraHandle) {
+  (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
+  const OrtApi* api = (const OrtApi*) apiHandle;
+  checkOrtStatus(jniEnv, api, api->RunOptionsAddActiveLoraAdapter((OrtRunOptions*) nativeHandle, (OrtLoraAdapter*) loraHandle));
+}
+
 /*
  * Class:     ai_onnxruntime_OrtSession_RunOptions
  * Method:    setTerminate
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 7cb6305923279..5999237bea6bc 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -1294,6 +1294,82 @@ public void testRunOptions() throws OrtException {
     }
   }
 
+  @Test
+  public void testRunWithLoraAdapter() throws IOException, OrtException {
+    Path modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx");
+    Path adapterPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter");
+
+    long[] inputShape = new long[] {4, 4};
+    float[] inputData = new float[16];
+    Arrays.fill(inputData, 1.f);
+    FloatBuffer buf =
+        ByteBuffer.allocateDirect(Float.BYTES * 16).order(ByteOrder.nativeOrder()).asFloatBuffer();
+    buf.put(inputData);
+    buf.rewind();
+
+    float[][] expectedOutput =
+        new float[][] {
+          {28.f, 32.f, 36.f, 40.f},
+          {28.f, 32.f, 36.f, 40.f},
+          {28.f, 32.f, 36.f, 40.f},
+          {28.f, 32.f, 36.f, 40.f}
+        };
+
+    float[][] expectedLoRAOutput =
+        new float[][] {
+          {154.f, 176.f, 198.f, 220.f},
+          {154.f, 176.f, 198.f, 220.f},
+          {154.f, 176.f, 198.f, 220.f},
+          {154.f, 176.f, 198.f, 220.f}
+        };
+
+    try (OrtSession session = env.createSession(modelPath.toString());
+        OnnxTensor tensor = OnnxTensor.createTensor(env, buf, inputShape)) {
+
+      Map<String, OnnxTensor> inputs = Collections.singletonMap("input_x", tensor);
+
+      // Without LoRA
+      try (OrtSession.Result result = session.run(inputs)) {
+        float[][] resultArr = (float[][]) result.get(0).getValue();
+        Assertions.assertArrayEquals(expectedOutput, resultArr);
+      }
+
+      // With LoRA from path
+      try (OrtLoraAdapter adapter = OrtLoraAdapter.create(adapterPath.toString());
+          OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
+        runOptions.addActiveLoraAdapter(adapter);
+        try (OrtSession.Result result = session.run(inputs, runOptions)) {
+          float[][] resultArr = (float[][]) result.get(0).getValue();
+          Assertions.assertArrayEquals(expectedLoRAOutput, resultArr);
+        }
+      }
+
+      // With LoRA from array
+      byte[] loraArray = Files.readAllBytes(adapterPath);
+      try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraArray);
+          OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
+        runOptions.addActiveLoraAdapter(adapter);
+        try (OrtSession.Result result = session.run(inputs, runOptions)) {
+          float[][] resultArr = (float[][]) result.get(0).getValue();
+          Assertions.assertArrayEquals(expectedLoRAOutput, resultArr);
+        }
+      }
+
+      // With LoRA from buffer
+      ByteBuffer loraBuf = ByteBuffer.allocateDirect(loraArray.length);
+      loraBuf.put(loraArray);
+      loraBuf.rewind();
+      try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraBuf);
+          OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
+        runOptions.addActiveLoraAdapter(adapter);
+        try (OrtSession.Result result = session.run(inputs, runOptions)) {
+          float[][] resultArr = (float[][]) result.get(0).getValue();
+          Assertions.assertArrayEquals(expectedLoRAOutput, resultArr);
+        }
+      }
+    }
+  }
+
   @Test
   public void testExtraSessionOptions() throws OrtException, IOException {
     // model takes 1x5 input of fixed type, echoes back