From 9544adc2051ccd52c724094e4be54aa7ad9e10ac Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 2 May 2024 20:08:18 -0400 Subject: [PATCH 1/6] Adding a private constructor to Fp16Conversions to silence a javadoc warning. --- .../main/android/ai/onnxruntime/platform/Fp16Conversions.java | 4 +++- .../src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java index dd7dd07fc1f5d..c5ee8aa5b4648 100644 --- a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java +++ b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java @@ -17,7 +17,9 @@ /** * Conversions between fp16, bfloat16 and fp32. */ public final class Fp16Conversions { private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName()); - + + private Fp16Conversions() {} + /** * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). * diff --git a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java index fce872688aa1f..451c0d9848586 100644 --- a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java +++ b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java @@ -54,6 +54,8 @@ public final class Fp16Conversions { fp32ToFp16 = tmp32; } + private Fp16Conversions() {} + /** * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). * From f33d03f667efa03ffa11571ab9768432745e874d Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 2 May 2024 20:11:25 -0400 Subject: [PATCH 2/6] Fixing OrtCUDAProviderOptions and OrtTensorRTProviderOptions so they call the update method once. --- .../ai/onnxruntime/OrtProviderOptions.java | 9 +++- .../main/java/ai/onnxruntime/OrtSession.java | 6 ++- .../providers/OrtCUDAProviderOptions.java | 13 +++-- .../providers/OrtTensorRTProviderOptions.java | 13 +++-- .../StringConfigProviderOptions.java | 31 ++++++++---- ...runtime_providers_OrtCUDAProviderOptions.c | 45 ++++++++++++++---- ...ime_providers_OrtTensorRTProviderOptions.c | 45 ++++++++++++++---- .../providers/ProviderOptionsTest.java | 47 ++++++++++++------- 8 files changed, 148 insertions(+), 61 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java index 70af10ff8cd79..ca7bf2f317ce4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -53,6 +53,13 @@ protected static long getApiHandle() { */ public abstract OrtProvider getProvider(); + /** + * Applies the Java side configuration to the native side object. + * + * @throws OrtException If the native call failed. + */ + protected abstract void applyToNative() throws OrtException; + /** * Is the native object closed? * diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index fbea13d155507..8ab4a1cb26bb1 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -1022,6 +1022,8 @@ public void addCUDA(int deviceNum) throws OrtException { public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException { checkClosed(); if (OnnxRuntime.extractCUDA()) { + // Cast is to make the compiler pick the right overload. + ((OrtProviderOptions) cudaOpts).applyToNative(); addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle); } else { throw new OrtException( @@ -1125,6 +1127,8 @@ public void addTensorrt(int deviceNum) throws OrtException { public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtException { checkClosed(); if (OnnxRuntime.extractTensorRT()) { + // Cast is to make the compiler pick the right overload. + ((OrtProviderOptions) tensorRTOpts).applyToNative(); addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle); } else { throw new OrtException( diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java index b7a83708a2314..6c1e8f02e90af 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -41,7 +41,6 @@ public OrtCUDAProviderOptions(int deviceId) throws OrtException { String id = "" + deviceId; this.options.put("device_id", id); - add(getApiHandle(), this.nativeHandle, "device_id", id); } @Override @@ -59,17 +58,17 @@ public OrtProvider getProvider() { private static native long create(long apiHandle) throws OrtException; /** - * Adds an option to this options instance. + * Adds the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. */ @Override - protected native void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected native void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; /** * Closes this options instance. diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java index 958d3a9e18f9b..0a69f0b72415b 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -41,7 +41,6 @@ public OrtTensorRTProviderOptions(int deviceId) throws OrtException { String id = "" + deviceId; this.options.put("device_id", id); - add(getApiHandle(), this.nativeHandle, "device_id", id); } @Override @@ -59,17 +58,17 @@ public OrtProvider getProvider() { private static native long create(long apiHandle) throws OrtException; /** - * Adds an option to this options instance. + * Adds the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. */ @Override - protected native void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected native void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; /** * Closes this options instance. diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java index 961163035c9a6..8abc227d23aef 100644 --- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -36,7 +36,6 @@ public void add(String key, String value) throws OrtException { Objects.requireNonNull(key, "Key must not be null"); Objects.requireNonNull(value, "Value must not be null"); options.put(key, value); - add(getApiHandle(), nativeHandle, key, value); } /** @@ -49,7 +48,7 @@ public void add(String key, String value) throws OrtException { public void parseOptionsString(String serializedForm) throws OrtException { String[] options = serializedForm.split(";"); for (String o : options) { - if (!o.isEmpty() && o.contains("=")) { + if (o.contains("=")) { String[] curOption = o.split("="); if ((curOption.length == 2) && !curOption[0].isEmpty() && !curOption[1].isEmpty()) { add(curOption[0], curOption[1]); @@ -76,15 +75,31 @@ public String getOptionsString() { .collect(Collectors.joining(";", "", ";")); } + @Override + protected void applyToNative() throws OrtException { + if (!options.isEmpty()) { + String[] keys = new String[options.size()]; + String[] values = new String[options.size()]; + int i = 0; + for (Map.Entry e : options.entrySet()) { + keys[i] = e.getKey(); + values[i] = e.getValue(); + i++; + } + + applyToNative(getApiHandle(), this.nativeHandle, keys, values); + } + } + /** - * Adds an option to this options instance. + * Add all the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param key The option keys. + * @param value The option values. * @throws OrtException If the addition failed. */ - protected abstract void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected abstract void applyToNative( + long apiHandle, long nativeHandle, String[] key, String[] value) throws OrtException; } diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c index 22907fc65c16c..307b77304a059 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -24,19 +24,44 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_cre /* * Class: ai_onnxruntime_providers_OrtCUDAProviderOptions - * Method: add - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + * Method: applyToNative + * Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_add - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_applyToNative + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) { (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle; - const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); - const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); - checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, &keyStr, &valueStr, 1)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr); + + size_t keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + const char** keys = allocarray(keyLength, sizeof(char*)); + const char** values = allocarray(keyLength, sizeof(char*)); + if ((keys == NULL) || (values == NULL)) { + if (keys != NULL) { + free(keys); + } + if (values != NULL) { + free(values); + } + throwOrtException(jniEnv, 1, "Not enough memory"); + } else { + // Copy out strings into UTF-8. + for (size_t i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); + } + // Write to the provider options. + checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, keys, values, keyLength)); + // Release allocated strings. + for (size_t i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); + } + } } /* diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c index 9146e7dd589aa..96a267347a301 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -23,19 +23,44 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions /* * Class: ai_onnxruntime_providers_OrtTensorRTProviderOptions - * Method: add - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + * Method: applyToNative + * Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_add - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_applyToNative + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) { (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle; - const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); - const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); - checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, &keyStr, &valueStr, 1)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr); + + size_t keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + const char** keys = allocarray(keyLength, sizeof(char*)); + const char** values = allocarray(keyLength, sizeof(char*)); + if ((keys == NULL) || (values == NULL)) { + if (keys != NULL) { + free(keys); + } + if (values != NULL) { + free(values); + } + throwOrtException(jniEnv, 1, "Not enough memory"); + } else { + // Copy out strings into UTF-8. + for (size_t i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); + } + // Write to the provider options. + checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength)); + // Release allocated strings. + for (size_t i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); + } + } } /* diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 0e3bc15ba9c70..ec8edec3fedac 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -41,40 +41,53 @@ public void testCUDAOptions() throws OrtException { OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions(); sessionOpts.addCUDA(cudaOpts); runProvider(OrtProvider.CUDA, sessionOpts); + sessionOpts.close(); + cudaOpts.close(); // Test invalid device num throws - assertThrows(IllegalArgumentException.class, () -> new OrtCUDAProviderOptions(-1)); + try (OrtCUDAProviderOptions invalidIdOpts = new OrtCUDAProviderOptions(-1)) { + assertThrows(IllegalArgumentException.class, invalidIdOpts::applyToNative); + } // Test invalid key name throws - OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidKeyOpts.add("not_a_real_provider_option", "not a number")); + try (OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0)) { + invalidKeyOpts.add("not_a_real_provider_option", "not a number"); + assertThrows(OrtException.class, invalidKeyOpts::applyToNative); + } // Test invalid value throws - OrtCUDAProviderOptions invalidValueOpts = new OrtCUDAProviderOptions(0); - assertThrows(OrtException.class, () -> invalidValueOpts.add("gpu_mem_limit", "not a number")); + try (OrtCUDAProviderOptions invalidValueOpts = new OrtCUDAProviderOptions(0)) { + invalidValueOpts.add("gpu_mem_limit", "not a number"); + assertThrows(OrtException.class, invalidValueOpts::applyToNative); + } } @Test @EnabledIfSystemProperty(named = "USE_TENSORRT", matches = "1") public void testTensorRT() throws OrtException { // Test standard options - OrtTensorRTProviderOptions cudaOpts = new OrtTensorRTProviderOptions(0); - cudaOpts.add("trt_max_workspace_size", "" + (512 * 1024 * 1024)); + OrtTensorRTProviderOptions rtOpts = new OrtTensorRTProviderOptions(0); + rtOpts.add("trt_max_workspace_size", "" + (512 * 1024 * 1024)); OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions(); - sessionOpts.addTensorrt(cudaOpts); + sessionOpts.addTensorrt(rtOpts); runProvider(OrtProvider.TENSOR_RT, sessionOpts); + sessionOpts.close(); + rtOpts.close(); // Test invalid device num throws - assertThrows(IllegalArgumentException.class, () -> new OrtTensorRTProviderOptions(-1)); + try (OrtTensorRTProviderOptions invalidIdOpts = new OrtTensorRTProviderOptions(-1)) { + assertThrows(IllegalArgumentException.class, invalidIdOpts::applyToNative); + } // Test invalid key name throws - OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidKeyOpts.add("not_a_real_provider_option", "not a number")); + try (OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0)) { + invalidKeyOpts.add("not_a_real_provider_option", "not a number"); + assertThrows(OrtException.class, invalidKeyOpts::applyToNative); + } // Test invalid value throws - OrtTensorRTProviderOptions invalidValueOpts = new OrtTensorRTProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidValueOpts.add("trt_max_workspace_size", "not a number")); + try (OrtTensorRTProviderOptions invalidValueOpts = new OrtTensorRTProviderOptions(0)) { + invalidValueOpts.add("trt_max_workspace_size", "not a number"); + assertThrows(OrtException.class, invalidValueOpts::applyToNative); + } } private static void runProvider(OrtProvider provider, OrtSession.SessionOptions options) From 70cb061b00494785bc547d958863bd31c93d983b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 2 May 2024 20:15:28 -0400 Subject: [PATCH 3/6] Updating the documentation on UpdateCUDAProviderOptions and UpdateTensorRTProviderOptions. --- include/onnxruntime/core/session/onnxruntime_c_api.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index de3013484b1ab..b4be501d3f00a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -2937,7 +2937,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="trt_max_workspace_size" and value="2147483648" * @@ -3433,7 +3433,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="device_id" and value="0" * From 9fadc78b59ceb601485762cfcddda31193f0201f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 3 May 2024 10:13:10 -0400 Subject: [PATCH 4/6] Small fixes for the tests, also tidying up native memory management. --- .../ai_onnxruntime_providers_OrtCUDAProviderOptions.c | 10 ++++++---- ..._onnxruntime_providers_OrtTensorRTProviderOptions.c | 10 ++++++---- .../ai/onnxruntime/providers/ProviderOptionsTest.java | 10 +++------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c index 307b77304a059..dc632515c549d 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c @@ -34,14 +34,14 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_appl OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle; size_t keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); - const char** keys = allocarray(keyLength, sizeof(char*)); - const char** values = allocarray(keyLength, sizeof(char*)); + const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); + const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); if ((keys == NULL) || (values == NULL)) { if (keys != NULL) { - free(keys); + free((void*)keys); } if (values != NULL) { - free(values); + free((void*)values); } throwOrtException(jniEnv, 1, "Not enough memory"); } else { @@ -61,6 +61,8 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_appl jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); } + free((void*)keys); + free((void*)values); } } diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c index 96a267347a301..c56722287cad4 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c @@ -33,14 +33,14 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_ OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle; size_t keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); - const char** keys = allocarray(keyLength, sizeof(char*)); - const char** values = allocarray(keyLength, sizeof(char*)); + const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); + const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); if ((keys == NULL) || (values == NULL)) { if (keys != NULL) { - free(keys); + free((void*)keys); } if (values != NULL) { - free(values); + free((void*)values); } throwOrtException(jniEnv, 1, "Not enough memory"); } else { @@ -60,6 +60,8 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_ jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); } + free((void*)keys); + free((void*)values); } } diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index ec8edec3fedac..8dfea92c9ff10 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -45,9 +45,7 @@ public void testCUDAOptions() throws OrtException { cudaOpts.close(); // Test invalid device num throws - try (OrtCUDAProviderOptions invalidIdOpts = new OrtCUDAProviderOptions(-1)) { - assertThrows(IllegalArgumentException.class, invalidIdOpts::applyToNative); - } + assertThrows(IllegalArgumentException.class, () -> new OrtCUDAProviderOptions(-1)); // Test invalid key name throws try (OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0)) { @@ -74,9 +72,7 @@ public void testTensorRT() throws OrtException { rtOpts.close(); // Test invalid device num throws - try (OrtTensorRTProviderOptions invalidIdOpts = new OrtTensorRTProviderOptions(-1)) { - assertThrows(IllegalArgumentException.class, invalidIdOpts::applyToNative); - } + assertThrows(IllegalArgumentException.class, () -> new OrtTensorRTProviderOptions(-1)); // Test invalid key name throws try (OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0)) { @@ -109,7 +105,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-5f); + assertArrayEquals(expectedOutput, resultArray, 1e-3f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); } From 7880beafaa074c3909cf5fd6152778aaf6ba666c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 3 May 2024 10:19:17 -0400 Subject: [PATCH 5/6] Relaxing CUDA precision to make the tests pass on a H100. --- java/src/test/java/ai/onnxruntime/InferenceTest.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index ac65cbab146bf..3340a2e5e9f3a 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -678,6 +678,9 @@ private void runProvider(OrtProvider provider) throws OrtException { if (provider == OrtProvider.CORE_ML) { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); + } else if (provider == OrtProvider.CUDA) { + // CUDA gives slightly different answers on a H100 with CUDA 12.2 + assertArrayEquals(expectedOutput, resultArray, 1e-3f); } else { assertArrayEquals(expectedOutput, resultArray, 1e-5f); } From 281e142a65c1ceb6a48c563dc6c3242e7e671ad1 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sat, 4 May 2024 21:12:33 -0400 Subject: [PATCH 6/6] Fix Windows compile error. --- .../ai_onnxruntime_providers_OrtCUDAProviderOptions.c | 6 +++--- .../ai_onnxruntime_providers_OrtTensorRTProviderOptions.c | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c index dc632515c549d..46df515c2e235 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c @@ -33,7 +33,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_appl const OrtApi* api = (const OrtApi*)apiHandle; OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle; - size_t keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); if ((keys == NULL) || (values == NULL)) { @@ -46,7 +46,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_appl throwOrtException(jniEnv, 1, "Not enough memory"); } else { // Copy out strings into UTF-8. - for (size_t i = 0; i < keyLength; i++) { + for (jsize i = 0; i < keyLength; i++) { jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); @@ -55,7 +55,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_appl // Write to the provider options. checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, keys, values, keyLength)); // Release allocated strings. - for (size_t i = 0; i < keyLength; i++) { + for (jsize i = 0; i < keyLength; i++) { jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c index c56722287cad4..404a80f118306 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c @@ -32,7 +32,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_ const OrtApi* api = (const OrtApi*)apiHandle; OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle; - size_t keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); if ((keys == NULL) || (values == NULL)) { @@ -45,7 +45,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_ throwOrtException(jniEnv, 1, "Not enough memory"); } else { // Copy out strings into UTF-8. - for (size_t i = 0; i < keyLength; i++) { + for (jsize i = 0; i < keyLength; i++) { jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); @@ -54,7 +54,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_ // Write to the provider options. checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength)); // Release allocated strings. - for (size_t i = 0; i < keyLength; i++) { + for (jsize i = 0; i < keyLength; i++) { jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);