Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Fix NPE bug in getContextClassLoader(). #1445

Merged
merged 1 commit into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.util.ClassLoaderUtils;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
Expand Down Expand Up @@ -137,7 +138,7 @@ public Repository newInstance(String name, URI uri) {
if (p.startsWith("/")) {
p = p.substring(1);
}
URL u = Thread.currentThread().getContextClassLoader().getResource(p);
URL u = ClassLoaderUtils.getContextClassLoader().getResource(p);
if (u == null) {
throw new IllegalArgumentException("Resource not found: " + uri);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ protected TranslatorFactory getTranslatorFactory(

String factoryClass = (String) arguments.get("translatorFactory");
if (factoryClass != null) {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
factory = ClassLoaderUtils.initClass(cl, factoryClass);
}
return factory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.djl.modality.cv.translator.ImageServingTranslator;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -107,7 +108,7 @@ private ServingTranslator findTranslator(Path path, String className) {
urls.add(p.toUri().toURL());
}

ClassLoader parentCl = Thread.currentThread().getContextClassLoader();
ClassLoader parentCl = ClassLoaderUtils.getContextClassLoader();
ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl);
if (className != null && !className.isEmpty()) {
return initTranslator(cl, className);
Expand Down
15 changes: 14 additions & 1 deletion api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public static <T> T findImplementation(Path path, String className) {
urls[index++] = p.toUri().toURL();
}

final ClassLoader contextCl = Thread.currentThread().getContextClassLoader();
final ClassLoader contextCl = getContextClassLoader();
ClassLoader cl =
AccessController.doPrivileged(
(PrivilegedAction<ClassLoader>)
Expand Down Expand Up @@ -154,4 +154,17 @@ public static <T> T initClass(ClassLoader cl, String className) {
}
return null;
}

/**
* Returns the context class loader if available.
*
* @return the context class loader if available
*/
public static ClassLoader getContextClassLoader() {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
if (cl == null) {
return ClassLoaderUtils.class.getClassLoader(); // NOPMD
}
return cl;
}
}
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/util/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static Platform detectPlatform(String engine) {
String nativeProp = "native/lib/" + engine + ".properties";
Enumeration<URL> urls;
try {
urls = Thread.currentThread().getContextClassLoader().getResources(nativeProp);
urls = ClassLoaderUtils.getContextClassLoader().getResources(nativeProp);
} catch (IOException e) {
throw new AssertionError("Failed to list property files.", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.training.dataset.Dataset;
import ai.djl.translate.Pipeline;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import java.io.IOException;
Expand Down Expand Up @@ -101,7 +102,7 @@ public void prepare(Progress progress) throws IOException {
}

private void loadSynset() {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
try (InputStream classStream = cl.getResourceAsStream("imagenet/classes.json")) {
if (classStream == null) {
throw new AssertionError("Missing imagenet/classes.json in jar resource");
Expand Down
5 changes: 3 additions & 2 deletions testing/src/main/java/ai/djl/testing/CoverageUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.testing;

import ai.djl.util.ClassLoaderUtils;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
Expand Down Expand Up @@ -98,13 +99,13 @@ public static void testGetterSetters(Class<?> baseClass)

private static List<Class<?>> getClasses(Class<?> clazz)
throws IOException, ReflectiveOperationException, URISyntaxException {
ClassLoader appClassLoader = Thread.currentThread().getContextClassLoader();
ClassLoader appClassLoader = ClassLoaderUtils.getContextClassLoader();
Field field = appClassLoader.getClass().getDeclaredField("ucp");
field.setAccessible(true);
Object ucp = field.get(appClassLoader);
Method method = ucp.getClass().getDeclaredMethod("getURLs");
URL[] urls = (URL[]) method.invoke(ucp);
ClassLoader cl = new TestClassLoader(urls, Thread.currentThread().getContextClassLoader());
ClassLoader cl = new TestClassLoader(urls, ClassLoaderUtils.getContextClassLoader());

URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
String path = url.getPath();
Expand Down