From 091b81eea896482420871e5df203a193f3de3b07 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:27:29 -0700 Subject: [PATCH] Add Java API to append QNN EP. --- .../main/java/ai/onnxruntime/OrtProvider.java | 4 +- .../main/java/ai/onnxruntime/OrtSession.java | 48 ++++++++++++++----- .../java/ai/onnxruntime/InferenceTest.java | 16 +++++++ 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtProvider.java b/java/src/main/java/ai/onnxruntime/OrtProvider.java index ae9cb9f908629..0e2883fe23088 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProvider.java +++ b/java/src/main/java/ai/onnxruntime/OrtProvider.java @@ -40,7 +40,9 @@ public enum OrtProvider { /** The XNNPACK execution provider. */ XNNPACK("XnnpackExecutionProvider"), /** The Azure remote endpoint execution provider. */ - AZURE("AzureExecutionProvider"); + AZURE("AzureExecutionProvider"), + /** The QNN execution provider. */ + QNN("QNNExecutionProvider"); private static final Map valueMap = new HashMap<>(values().length); diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 6d146d5857d3c..a441b7b9e26cc 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1271,16 +1271,16 @@ public void addCoreML(EnumSet flags) throws OrtException { } /** - * Adds Xnnpack as an execution backend. Needs to list all options hereif a new option - * supported. current supported options: {} The maximum number of provider options is set to 128 - * (see addExecutionProvider's comment). This number is controlled by - * ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is - * not enough, please increase it or implementing an incremental way to add more options. + * Adds the named execution provider (backend) as an execution backend. This generic function + * only allows a subset of execution providers. * - * @param providerOptions options pass to XNNPACK EP for initialization. + * @param providerName The name of the execution provider. + * @param providerOptions Configuration options for the execution provider. Refer to the + * specific execution provider's documentation. * @throws OrtException If there was an error in native code. */ - public void addXnnpack(Map providerOptions) throws OrtException { + private void addExecutionProvider(String providerName, Map providerOptions) + throws OrtException { checkClosed(); String[] providerOptionKey = new String[providerOptions.size()]; String[] providerOptionVal = new String[providerOptions.size()]; @@ -1291,7 +1291,35 @@ public void addXnnpack(Map providerOptions) throws OrtException i++; } addExecutionProvider( - OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal); + OnnxRuntime.ortApiHandle, + nativeHandle, + providerName, + providerOptionKey, + providerOptionVal); + } + + /** + * Adds XNNPACK as an execution backend. + * + * @param providerOptions Configuration options for the XNNPACK backend. Refer to the XNNPACK + * execution provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addXnnpack(Map providerOptions) throws OrtException { + String xnnpackProviderName = "XNNPACK"; + addExecutionProvider(xnnpackProviderName, providerOptions); + } + + /** + * Adds QNN as an execution backend. + * + * @param providerOptions Configuration options for the QNN backend. Refer to the QNN execution + * provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addQnn(Map providerOptions) throws OrtException { + String qnnProviderName = "QNN"; + addExecutionProvider(qnnProviderName, providerOptions); } private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) @@ -1416,10 +1444,6 @@ private native void addArmNN(long apiHandle, long nativeHandle, int useArena) private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags) throws OrtException; - /* - * The max length of providerOptionKey and providerOptionVal is 128, as specified by - * ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 for its location). - */ private native void addExecutionProvider( long apiHandle, long nativeHandle, diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 11141a3a65a3e..d6c0338ca5a9c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -688,6 +688,19 @@ public void testDirectML() throws OrtException { runProvider(OrtProvider.DIRECT_ML); } + @Test + @EnabledIfSystemProperty(named = "USE_QNN", matches = "1") + public void testQNN() throws OrtException { + // Note: This currently only tests the API call to append the QNN EP. There's some additional + // setup required for the model to actually run with the QNN EP and not fall back to the CPU EP. + runProvider(OrtProvider.QNN); + } + + @Test + public void testDumpSystemProperties() { + System.getProperties().forEach((k, v) -> System.out.println(k + ":" + v)); + } + private void runProvider(OrtProvider provider) throws OrtException { EnumSet providers = OrtEnvironment.getAvailableProviders(); assertTrue(providers.size() > 1); @@ -1986,6 +1999,9 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid case XNNPACK: options.addXnnpack(Collections.emptyMap()); break; + case QNN: + options.addQnn(Collections.emptyMap()); + break; case VITIS_AI: case RK_NPU: case MI_GRAPH_X: