Skip to content

Commit

Permalink
Add Java API to append QNN EP.
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Sep 24, 2024
1 parent 8d2d407 commit 091b81e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
4 changes: 3 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ public enum OrtProvider {
/** The XNNPACK execution provider. */
XNNPACK("XnnpackExecutionProvider"),
/** The Azure remote endpoint execution provider. */
AZURE("AzureExecutionProvider");
AZURE("AzureExecutionProvider"),
/** The QNN execution provider. */
QNN("QNNExecutionProvider");

private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);

Expand Down
48 changes: 36 additions & 12 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -1271,16 +1271,16 @@ public void addCoreML(EnumSet<CoreMLFlags> flags) throws OrtException {
}

/**
* Adds Xnnpack as an execution backend. Needs to list all options hereif a new option
* supported. current supported options: {} The maximum number of provider options is set to 128
* (see addExecutionProvider's comment). This number is controlled by
* ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is
* not enough, please increase it or implementing an incremental way to add more options.
* Adds the named execution provider (backend) as an execution backend. This generic function
* only allows a subset of execution providers.
*
* @param providerOptions options pass to XNNPACK EP for initialization.
* @param providerName The name of the execution provider.
* @param providerOptions Configuration options for the execution provider. Refer to the
* specific execution provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addXnnpack(Map<String, String> providerOptions) throws OrtException {
private void addExecutionProvider(String providerName, Map<String, String> providerOptions)
throws OrtException {
checkClosed();
String[] providerOptionKey = new String[providerOptions.size()];
String[] providerOptionVal = new String[providerOptions.size()];
Expand All @@ -1291,7 +1291,35 @@ public void addXnnpack(Map<String, String> providerOptions) throws OrtException
i++;
}
addExecutionProvider(
OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal);
OnnxRuntime.ortApiHandle,
nativeHandle,
providerName,
providerOptionKey,
providerOptionVal);
}

/**
 * Registers the XNNPACK execution provider as a backend for this session.
 *
 * @param providerOptions Configuration options for the XNNPACK backend. Refer to the XNNPACK
 *     execution provider's documentation.
 * @throws OrtException If there was an error in native code.
 */
public void addXnnpack(Map<String, String> providerOptions) throws OrtException {
  // Delegates to the generic provider-registration path with the XNNPACK provider name.
  addExecutionProvider("XNNPACK", providerOptions);
}

/**
 * Registers the QNN execution provider as a backend for this session.
 *
 * @param providerOptions Configuration options for the QNN backend. Refer to the QNN execution
 *     provider's documentation.
 * @throws OrtException If there was an error in native code.
 */
public void addQnn(Map<String, String> providerOptions) throws OrtException {
  // Delegates to the generic provider-registration path with the QNN provider name.
  addExecutionProvider("QNN", providerOptions);
}

private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
Expand Down Expand Up @@ -1416,10 +1444,6 @@ private native void addArmNN(long apiHandle, long nativeHandle, int useArena)
private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags)
throws OrtException;

/*
* The max length of providerOptionKey and providerOptionVal is 128, as specified by
* ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 for its location).
*/
private native void addExecutionProvider(
long apiHandle,
long nativeHandle,
Expand Down
16 changes: 16 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,19 @@ public void testDirectML() throws OrtException {
runProvider(OrtProvider.DIRECT_ML);
}

@Test
@EnabledIfSystemProperty(named = "USE_QNN", matches = "1")
public void testQNN() throws OrtException {
  // This only verifies that appending the QNN EP succeeds. Without additional QNN-specific
  // setup the model may still fall back to the CPU EP rather than execute on QNN.
  runProvider(OrtProvider.QNN);
}

/**
 * Dumps all JVM system properties to stdout as a debugging aid for environment-dependent tests.
 *
 * <p>Gated behind an opt-in system property so that normal test runs are not flooded with
 * unconditional diagnostic output (this has no assertions and exists purely for debugging).
 */
@Test
@EnabledIfSystemProperty(named = "DUMP_SYSTEM_PROPERTIES", matches = "1")
public void testDumpSystemProperties() {
  System.getProperties().forEach((k, v) -> System.out.println(k + ":" + v));
}

private void runProvider(OrtProvider provider) throws OrtException {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
assertTrue(providers.size() > 1);
Expand Down Expand Up @@ -1986,6 +1999,9 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet<OrtProvider> provid
case XNNPACK:
options.addXnnpack(Collections.emptyMap());
break;
case QNN:
options.addQnn(Collections.emptyMap());
break;
case VITIS_AI:
case RK_NPU:
case MI_GRAPH_X:
Expand Down

0 comments on commit 091b81e

Please sign in to comment.