Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[java] CUDA & TensorRT options fix #20549

Merged
merged 6 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2937,7 +2937,7 @@ struct OrtApi {
*
* Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc
* to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2
* and value should be its related range.
* and value should be its related range. Recreates the options and only sets the supplied values.
*
* For example, key="trt_max_workspace_size" and value="2147483648"
*
Expand Down Expand Up @@ -3433,7 +3433,7 @@ struct OrtApi {
*
* Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
* to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2
* and value should be its related range.
* and value should be its related range. Recreates the options and only sets the supplied values.
*
* For example, key="device_id" and value="0"
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
/** * Conversions between fp16, bfloat16 and fp32. */
public final class Fp16Conversions {
private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName());


private Fp16Conversions() {}

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
Expand Down
9 changes: 8 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProviderOptions.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -53,6 +53,13 @@ protected static long getApiHandle() {
*/
public abstract OrtProvider getProvider();

/**
* Applies the Java side configuration to the native side object.
*
* @throws OrtException If the native call failed.
*/
protected abstract void applyToNative() throws OrtException;

/**
* Is the native object closed?
*
Expand Down
6 changes: 5 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -1022,6 +1022,8 @@ public void addCUDA(int deviceNum) throws OrtException {
public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException {
checkClosed();
if (OnnxRuntime.extractCUDA()) {
// Cast is to make the compiler pick the right overload.
((OrtProviderOptions) cudaOpts).applyToNative();
addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle);
} else {
throw new OrtException(
Expand Down Expand Up @@ -1125,6 +1127,8 @@ public void addTensorrt(int deviceNum) throws OrtException {
public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtException {
checkClosed();
if (OnnxRuntime.extractTensorRT()) {
// Cast is to make the compiler pick the right overload.
((OrtProviderOptions) tensorRTOpts).applyToNative();
addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle);
} else {
throw new OrtException(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -41,7 +41,6 @@ public OrtCUDAProviderOptions(int deviceId) throws OrtException {

String id = "" + deviceId;
this.options.put("device_id", id);
add(getApiHandle(), this.nativeHandle, "device_id", id);
}

@Override
Expand All @@ -59,17 +58,17 @@ public OrtProvider getProvider() {
private static native long create(long apiHandle) throws OrtException;

/**
* Adds an option to this options instance.
* Adds the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param keys The option keys.
* @param values The option values.
* @throws OrtException If the addition failed.
*/
@Override
protected native void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected native void applyToNative(
long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException;

/**
* Closes this options instance.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -41,7 +41,6 @@ public OrtTensorRTProviderOptions(int deviceId) throws OrtException {

String id = "" + deviceId;
this.options.put("device_id", id);
add(getApiHandle(), this.nativeHandle, "device_id", id);
}

@Override
Expand All @@ -59,17 +58,17 @@ public OrtProvider getProvider() {
private static native long create(long apiHandle) throws OrtException;

/**
* Adds an option to this options instance.
* Adds the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param keys The option keys.
* @param values The option values.
* @throws OrtException If the addition failed.
*/
@Override
protected native void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected native void applyToNative(
long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException;

/**
* Closes this options instance.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -36,7 +36,6 @@ public void add(String key, String value) throws OrtException {
Objects.requireNonNull(key, "Key must not be null");
Objects.requireNonNull(value, "Value must not be null");
options.put(key, value);
add(getApiHandle(), nativeHandle, key, value);
}

/**
Expand All @@ -49,7 +48,7 @@ public void add(String key, String value) throws OrtException {
public void parseOptionsString(String serializedForm) throws OrtException {
String[] options = serializedForm.split(";");
for (String o : options) {
if (!o.isEmpty() && o.contains("=")) {
if (o.contains("=")) {
String[] curOption = o.split("=");
if ((curOption.length == 2) && !curOption[0].isEmpty() && !curOption[1].isEmpty()) {
add(curOption[0], curOption[1]);
Expand All @@ -76,15 +75,31 @@ public String getOptionsString() {
.collect(Collectors.joining(";", "", ";"));
}

@Override
protected void applyToNative() throws OrtException {
if (!options.isEmpty()) {
String[] keys = new String[options.size()];
String[] values = new String[options.size()];
int i = 0;
for (Map.Entry<String, String> e : options.entrySet()) {
keys[i] = e.getKey();
values[i] = e.getValue();
i++;
}

applyToNative(getApiHandle(), this.nativeHandle, keys, values);
}
}

/**
* Adds an option to this options instance.
* Add all the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param key The option keys.
* @param value The option values.
* @throws OrtException If the addition failed.
*/
protected abstract void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected abstract void applyToNative(
long apiHandle, long nativeHandle, String[] key, String[] value) throws OrtException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public final class Fp16Conversions {
fp32ToFp16 = tmp32;
}

private Fp16Conversions() {}

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand All @@ -24,19 +24,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_cre

/*
* Class: ai_onnxruntime_providers_OrtCUDAProviderOptions
* Method: add
* Signature: (JJLjava/lang/String;Ljava/lang/String;)V
* Method: applyToNative
* Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_add
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) {
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_applyToNative
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle;
const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, &keyStr, &valueStr, 1));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr);

jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr);
const char** keys = (const char**) allocarray(keyLength, sizeof(const char*));
const char** values = (const char**) allocarray(keyLength, sizeof(const char*));
if ((keys == NULL) || (values == NULL)) {
if (keys != NULL) {
free((void*)keys);
}
if (values != NULL) {
free((void*)values);
}
throwOrtException(jniEnv, 1, "Not enough memory");
} else {
// Copy out strings into UTF-8.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i);
values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
}
// Write to the provider options.
checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, keys, values, keyLength));
// Release allocated strings.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]);
}
free((void*)keys);
free((void*)values);
}
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand All @@ -23,19 +23,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions

/*
* Class: ai_onnxruntime_providers_OrtTensorRTProviderOptions
* Method: add
* Signature: (JJLjava/lang/String;Ljava/lang/String;)V
* Method: applyToNative
* Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_add
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) {
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_applyToNative
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle;
const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, &keyStr, &valueStr, 1));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr);

jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr);
const char** keys = (const char**) allocarray(keyLength, sizeof(const char*));
const char** values = (const char**) allocarray(keyLength, sizeof(const char*));
if ((keys == NULL) || (values == NULL)) {
if (keys != NULL) {
free((void*)keys);
}
if (values != NULL) {
free((void*)values);
}
throwOrtException(jniEnv, 1, "Not enough memory");
} else {
// Copy out strings into UTF-8.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i);
values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
}
// Write to the provider options.
checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength));
// Release allocated strings.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]);
}
free((void*)keys);
free((void*)values);
}
}

/*
Expand Down
5 changes: 4 additions & 1 deletion java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -678,6 +678,9 @@ private void runProvider(OrtProvider provider) throws OrtException {
if (provider == OrtProvider.CORE_ML) {
// CoreML gives slightly different answers on a 2020 13" M1 MBP
assertArrayEquals(expectedOutput, resultArray, 1e-2f);
} else if (provider == OrtProvider.CUDA) {
// CUDA gives slightly different answers on a H100 with CUDA 12.2
assertArrayEquals(expectedOutput, resultArray, 1e-3f);
} else {
assertArrayEquals(expectedOutput, resultArray, 1e-5f);
}
Expand Down
Loading
Loading