Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor JNI native dependency loading to allow returning of library path #15566

Merged
merged 2 commits into from
Apr 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 73 additions & 11 deletions java/src/main/java/ai/rapids/cudf/NativeDepsLoader.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -77,7 +77,7 @@ public class NativeDepsLoader {
public static synchronized void loadNativeDeps() {
if (!loaded) {
try {
loadNativeDeps(loadOrder);
loadNativeDeps(loadOrder, preserveDepsAfterLoad);
loaded = true;
} catch (Throwable t) {
log.error("Could not load cudf jni library...", t);
Expand Down Expand Up @@ -122,21 +122,65 @@ public static synchronized void loadNativeDeps() {
* @throws IOException on any error trying to load the libraries.
*/
public static void loadNativeDeps(String[] loadOrder) throws IOException {
loadNativeDeps(loadOrder, preserveDepsAfterLoad);
}

/**
* Allows other libraries to reuse the same native deps loading logic. Libraries will be searched
* for under ${os.arch}/${os.name}/ in the class path using the class loader for this class.
* <br/>
* Because this just loads the libraries and loading the libraries themselves needs to be a
* singleton operation it is recommended that any library using this provide their own wrapper
* function similar to
* <pre>
* private static boolean loaded = false;
* static synchronized void loadNativeDeps() {
* if (!loaded) {
* try {
* // If you also depend on the cudf liobrary being loaded, be sure it is loaded
* // first
* ai.rapids.cudf.NativeDepsLoader.loadNativeDeps();
* ai.rapids.cudf.NativeDepsLoader.loadNativeDeps(new String[]{...});
* loaded = true;
* } catch (Throwable t) {
* log.error("Could not load ...", t);
* }
* }
* }
* </pre>
* This function should be called from the static initialization block of any class that uses
* JNI. For example
* <pre>
* public class UsesJNI {
* static {
* MyNativeDepsLoader.loadNativeDeps();
* }
* }
* </pre>
* @param loadOrder the base name of the libraries. For example libfoo.so would be passed in as
* "foo". The libraries are loaded in the order provided.
* @param preserveDeps if false the dependencies will be deleted immediately after loading
* rather than on exit.
* @throws IOException on any error trying to load the libraries.
*/
public static void loadNativeDeps(String[] loadOrder, boolean preserveDeps) throws IOException {
String os = System.getProperty("os.name");
String arch = System.getProperty("os.arch");

for (String toLoad : loadOrder) {
loadDep(os, arch, toLoad);
loadDep(os, arch, toLoad, preserveDeps);
}
}

/**
* Load native dependencies in stages, where the dependency libraries in each stage
* are loaded only after all libraries in earlier stages have completed loading.
* @param loadOrder array of stages with an array of dependency library names in each stage
* @param preserveDeps if false the dependencies will be deleted immediately after loading
* rather than on exit.
* @throws IOException on any error trying to load the libraries
*/
private static void loadNativeDeps(String[][] loadOrder) throws IOException {
private static void loadNativeDeps(String[][] loadOrder, boolean preserveDeps) throws IOException {
String os = System.getProperty("os.name");
String arch = System.getProperty("os.arch");

Expand All @@ -161,7 +205,7 @@ private static void loadNativeDeps(String[][] loadOrder) throws IOException {
// Submit all dependencies in the stage to be loaded in parallel
loadCompletionFutures.clear();
for (Future<File> fileFuture : stageFileFutures) {
loadCompletionFutures.add(executor.submit(() -> loadDep(fileFuture)));
loadCompletionFutures.add(executor.submit(() -> loadDep(fileFuture, preserveDeps)));
}

// Wait for all dependencies in this stage to have been loaded
Expand All @@ -177,28 +221,46 @@ private static void loadNativeDeps(String[][] loadOrder) throws IOException {
executor.shutdownNow();
}

private static void loadDep(String os, String arch, String baseName) throws IOException {
/**
* Allows other libraries to reuse the same native deps loading logic. Library will be searched
* for under ${os.arch}/${os.name}/ in the class path using the class loader for this class.
* @param depName the base name of the library. For example libfoo.so would be passed in as
* "foo". The libraries are loaded in the order provided.
* @param preserveDep if false the dependencies will be deleted immediately after loading
* rather than on exit.
* @return path where the dependency was loaded
* @throws IOException on any error trying to load the libraries.
*/
public static File loadNativeDep(String depName, boolean preserveDep) throws IOException {
String os = System.getProperty("os.name");
String arch = System.getProperty("os.arch");
return loadDep(os, arch, depName, preserveDep);
}

private static File loadDep(String os, String arch, String baseName, boolean preserveDep)
throws IOException {
File path = createFile(os, arch, baseName);
loadDep(path);
loadDep(path, preserveDep);
return path;
}

/** Load a library at the specified path */
private static void loadDep(File path) {
private static void loadDep(File path, boolean preserveDep) {
System.load(path.getAbsolutePath());
if (!preserveDepsAfterLoad) {
if (!preserveDep) {
path.delete();
}
}

/** Load a library, waiting for the specified future to produce the path before loading */
private static void loadDep(Future<File> fileFuture) {
private static void loadDep(Future<File> fileFuture, boolean preserveDep) {
File path;
try {
path = fileFuture.get();
} catch (ExecutionException | InterruptedException e) {
throw new RuntimeException("Error loading dependencies", e);
}
loadDep(path);
loadDep(path, preserveDep);
}

/** Extract the contents of a library resource into a temporary file */
Expand Down
Loading