Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[java] Multi-LoRA support #22280

Merged
merged 11 commits into from
Oct 1, 2024
162 changes: 162 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Objects;

/**
* A container for an adapter which can be supplied to {@link
* OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific
* execution of a model.
*/
public final class OrtLoraAdapter implements AutoCloseable {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}

private final long nativeHandle;

private boolean closed = false;

private OrtLoraAdapter(long nativeHandle) {
this.nativeHandle = nativeHandle;
}

/**
* Creates an instance of OrtLoraAdapter from a byte array.
*
* @param loraArray The LoRA stored in a byte array.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
public static OrtLoraAdapter create(byte[] loraArray) throws OrtException {
return create(loraArray, null);
}

/**
* Creates an instance of OrtLoraAdapter from a byte array.
*
* @param loraArray The LoRA stored in a byte array.
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
* allocator memory.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException {
Objects.requireNonNull(loraArray, "LoRA array must not be null");
long allocatorHandle = allocator == null ? 0 : allocator.handle;
return new OrtLoraAdapter(
createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle));
}

/**
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
*
* @param loraBuffer The buffer to load.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException {
return create(loraBuffer, null);
}

/**
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
*
* @param loraBuffer The buffer to load.
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
* allocator memory.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException {
Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null");
if (loraBuffer.remaining() == 0) {
throw new OrtException("Invalid LoRA buffer, no elements remaining.");
} else if (!loraBuffer.isDirect()) {
throw new OrtException("ByteBuffer is not direct.");
}
long allocatorHandle = allocator == null ? 0 : allocator.handle;
return new OrtLoraAdapter(
createLoraAdapterFromBuffer(
OnnxRuntime.ortApiHandle,
loraBuffer,
loraBuffer.position(),
loraBuffer.remaining(),
allocatorHandle));
}

/**
* Creates an instance of OrtLoraAdapter.
*
* @param absoluteAdapterPath path to the adapter file that is going to be memory mapped.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
public static OrtLoraAdapter create(String absoluteAdapterPath) throws OrtException {
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
return create(absoluteAdapterPath, null);
}

/**
* Creates an instance of OrtLoraAdapter.
*
* @param absoluteAdapterPath path to the adapter file that is going to be memory mapped.
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
* allocator memory.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
static OrtLoraAdapter create(String absoluteAdapterPath, OrtAllocator allocator)
throws OrtException {
long allocatorHandle = allocator == null ? 0 : allocator.handle;
return new OrtLoraAdapter(
createLoraAdapter(OnnxRuntime.ortApiHandle, absoluteAdapterPath, allocatorHandle));
}

/**
* Package accessor for native pointer.
*
* @return The native pointer.
*/
long getNativeHandle() {
return nativeHandle;
}

/** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */
void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OrtLoraAdapter");
}
}

@Override
public void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter");
}
}

private static native long createLoraAdapter(
long apiHandle, String adapterPath, long allocatorHandle) throws OrtException;

private static native long createLoraAdapterFromArray(
long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException;

private static native long createLoraAdapterFromBuffer(
long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle)
throws OrtException;

private static native void close(long apiHandle, long nativeHandle);
}
34 changes: 26 additions & 8 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ public static class SessionOptions implements AutoCloseable {
* The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum
* in the C API.
*
* <p>See <a
* href="https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
* <p>See <a href=
* "https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
* Optimizations</a> for more details.
*/
public enum OptLevel {
Expand Down Expand Up @@ -684,6 +684,7 @@ public enum ExecutionMode {
SEQUENTIAL(0),
/** Executes some nodes in parallel. */
PARALLEL(1);

private final int id;

ExecutionMode(int id) {
Expand Down Expand Up @@ -1363,17 +1364,19 @@ private native void addConfigEntry(
throws OrtException;

/*
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* To use additional providers, you must build ORT with the extra providers enabled. Then call
* one of these functions to enable them in the session:
*
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_ROCM
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order they care called indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*
* If a backend is unavailable then it throws an OrtException
* The order they are called indicates the preference order as well. In other words call this
* method on your most preferred execution provider first followed by the less preferred ones.
* If none are called ORT will use its internal CPU execution provider.
*
* If a backend is unavailable then it throws an OrtException.
*/
private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException;

Expand Down Expand Up @@ -1555,6 +1558,18 @@ public void addRunConfigEntry(String key, String value) throws OrtException {
addRunConfigEntry(OnnxRuntime.ortApiHandle, nativeHandle, key, value);
}

/**
 * Registers the supplied adapter as an active adapter for this run.
 *
 * @param loraAdapter The adapter to activate, it must not have been closed.
 * @throws OrtException If the native library call failed.
 */
public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException {
  // Both the run options and the adapter must still be open before their handles are used.
  checkClosed();
  loraAdapter.checkClosed();
  final long adapterHandle = loraAdapter.getNativeHandle();
  addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, adapterHandle);
}

/** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
private void checkClosed() {
if (closed) {
Expand Down Expand Up @@ -1595,6 +1610,9 @@ private native void setTerminate(long apiHandle, long nativeHandle, boolean term
private native void addRunConfigEntry(
long apiHandle, long nativeHandle, String key, String value) throws OrtException;

/*
 * Native binding for adding an active LoRA adapter to these run options. The JNI implementation
 * casts nativeHandle to an OrtRunOptions pointer and loraAdapterHandle to an OrtLoraAdapter
 * pointer, then calls RunOptionsAddActiveLoraAdapter on the OrtApi referenced by apiHandle.
 */
private native void addActiveLoraAdapter(
long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException;

private static native void close(long apiHandle, long nativeHandle);
}

Expand Down
106 changes: 106 additions & 0 deletions java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <string.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtLoraAdapter.h"

/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    createLoraAdapter
 * Signature: (JLjava/lang/String;J)J
 *
 * Creates a native OrtLoraAdapter from the adapter file at loraPath and returns its pointer as a
 * jlong. allocatorHandle may be 0, in which case a NULL allocator is passed to the ORT API.
 * Returns 0 if the path characters could not be obtained from the JVM or the temporary path
 * buffer could not be allocated.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter
  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) {
  (void) jclazz;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
  // Start from NULL so the returned handle is never read from an uninitialised variable when the
  // ORT call fails.
  OrtLoraAdapter* lora = NULL;

#ifdef _WIN32
  // ORT takes a wide, NUL terminated path on Windows, so copy the UTF-16 characters into a
  // terminated buffer as the JNI string characters are not guaranteed to be terminated.
  const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, loraPath, NULL);
  if (cPath == NULL) {
    // JNI returns NULL when the operation fails; there is nothing to release.
    return 0;
  }
  size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath);
  wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
  if (newString == NULL) {
    (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
    throwOrtException(jniEnv, 1, "Not enough memory");
    return 0;
  }
  wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora));
  free(newString);
  (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
#else
  const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL);
  if (cPath == NULL) {
    // JNI returns NULL when the operation fails; there is nothing to release.
    return 0;
  }
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora));
  (*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath);
#endif

  return (jlong) lora;
}

/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    createLoraAdapterFromBuffer
 * Signature: (JLjava/nio/ByteBuffer;IIJ)J
 *
 * Creates a native OrtLoraAdapter from bufferSize bytes of a direct ByteBuffer, starting
 * bufferPos bytes from the buffer's base address, and returns its pointer as a jlong.
 * allocatorHandle may be 0, in which case a NULL allocator is passed to the ORT API.
 * Throws an OrtException and returns 0 if the buffer's address cannot be obtained.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer
  (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) {
  (void) jclazz;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
  // Start from NULL so the returned handle is never read from an uninitialised variable when the
  // ORT call fails.
  OrtLoraAdapter* lora = NULL;

  // Extract the buffer
  char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
  if (bufferArr == NULL) {
    // Without this check the offset below would be applied to a NULL pointer and handed to ORT.
    throwOrtException(jniEnv, 2, "Failed to get the address of the direct ByteBuffer.");
    return 0;
  }
  // Increment by bufferPos bytes
  bufferArr = bufferArr + bufferPos;

  // Create the adapter
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora));

  return (jlong) lora;
}

/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    createLoraAdapterFromArray
 * Signature: (J[BJ)J
 *
 * Creates a native OrtLoraAdapter from the contents of a Java byte array and returns its pointer
 * as a jlong. allocatorHandle may be 0, in which case a NULL allocator is passed to the ORT API.
 * Throws an OrtException and returns 0 if the array is empty, and returns 0 if the array elements
 * could not be obtained from the JVM.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray
  (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) {
  (void) jclazz;  // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
  // Start from NULL so the returned handle is never read from an uninitialised variable when the
  // ORT call fails.
  OrtLoraAdapter* lora = NULL;

  size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray);
  if (loraLength == 0) {
    throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length.");
    return 0;
  }

  // Get a reference to the byte array elements
  jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL);
  if (loraArr == NULL) {
    // JNI returns NULL when the operation fails; there is nothing to release.
    return 0;
  }
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora));
  // Release the C array. JNI_ABORT frees the native view without copying it back to the Java array.
  (*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT);

  return (jlong) lora;
}

/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    close
 * Signature: (JJ)V
 *
 * Releases the native OrtLoraAdapter referenced by loraHandle.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close
  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) {
  // Required JNI parameters not needed by functions which don't need to access their host object.
  (void) jniEnv;
  (void) jclazz;
  OrtLoraAdapter* adapter = (OrtLoraAdapter*) loraHandle;
  const OrtApi* ortApi = (const OrtApi*) apiHandle;
  ortApi->ReleaseLoraAdapter(adapter);
}

12 changes: 12 additions & 0 deletions java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addRunConf
(*jniEnv)->ReleaseStringUTFChars(jniEnv, valueStr, value);
}

/*
 * Class:     ai_onnxruntime_OrtSession_RunOptions
 * Method:    addActiveLoraAdapter
 * Signature: (JJJ)V
 *
 * Adds the adapter referenced by loraHandle to the run options referenced by nativeHandle. The
 * resulting status is passed to checkOrtStatus.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addActiveLoraAdapter
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong loraHandle) {
  (void) jobj;  // Required JNI parameters not needed by functions which don't need to access their host object.
  const OrtApi* ortApi = (const OrtApi*) apiHandle;
  OrtRunOptions* runOptions = (OrtRunOptions*) nativeHandle;
  OrtLoraAdapter* adapter = (OrtLoraAdapter*) loraHandle;
  checkOrtStatus(jniEnv, ortApi, ortApi->RunOptionsAddActiveLoraAdapter(runOptions, adapter));
}

/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setTerminate
Expand Down
Loading
Loading