diff --git a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java index a1c0ad3ac0b..df3fd368348 100644 --- a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java +++ b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java @@ -14,6 +14,7 @@ import ai.djl.repository.zoo.ModelLoader; import ai.djl.repository.zoo.ModelZoo; +import ai.djl.util.ClassLoaderUtils; import java.io.IOException; import java.net.MalformedURLException; import java.net.URI; @@ -137,7 +138,7 @@ public Repository newInstance(String name, URI uri) { if (p.startsWith("/")) { p = p.substring(1); } - URL u = Thread.currentThread().getContextClassLoader().getResource(p); + URL u = ClassLoaderUtils.getContextClassLoader().getResource(p); if (u == null) { throw new IllegalArgumentException("Resource not found: " + uri); } diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java index 39f16d0cdb5..af6d460766b 100644 --- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java +++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java @@ -228,7 +228,7 @@ protected TranslatorFactory getTranslatorFactory( String factoryClass = (String) arguments.get("translatorFactory"); if (factoryClass != null) { - ClassLoader cl = Thread.currentThread().getContextClassLoader(); + ClassLoader cl = ClassLoaderUtils.getContextClassLoader(); factory = ClassLoaderUtils.initClass(cl, factoryClass); } return factory; diff --git a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java index 92f84161f70..11353777b3e 100644 --- a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java +++ b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java @@ -20,6 +20,7 @@ import ai.djl.modality.cv.translator.ImageServingTranslator; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.util.ClassLoaderUtils; import ai.djl.util.Pair; import java.io.File; import java.io.IOException; @@ -107,7 +108,7 @@ private ServingTranslator findTranslator(Path path, String className) { urls.add(p.toUri().toURL()); } - ClassLoader parentCl = Thread.currentThread().getContextClassLoader(); + ClassLoader parentCl = ClassLoaderUtils.getContextClassLoader(); ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl); if (className != null && !className.isEmpty()) { return initTranslator(cl, className); diff --git a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java index e45aaa26611..4c51e262912 100644 --- a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java +++ b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java @@ -68,7 +68,7 @@ public static T findImplementation(Path path, String className) { urls[index++] = p.toUri().toURL(); } - final ClassLoader contextCl = Thread.currentThread().getContextClassLoader(); + final ClassLoader contextCl = getContextClassLoader(); ClassLoader cl = AccessController.doPrivileged( (PrivilegedAction) @@ -154,4 +154,17 @@ public static T initClass(ClassLoader cl, String className) { } return null; } + + /** + * Returns the context class loader if available. + * + * @return the context class loader if available + */ + public static ClassLoader getContextClassLoader() { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) { + return ClassLoaderUtils.class.getClassLoader(); // NOPMD + } + return cl; + } } diff --git a/api/src/main/java/ai/djl/util/Platform.java b/api/src/main/java/ai/djl/util/Platform.java index 363a0ab4252..130021a69ba 100644 --- a/api/src/main/java/ai/djl/util/Platform.java +++ b/api/src/main/java/ai/djl/util/Platform.java @@ -59,7 +59,7 @@ public static Platform detectPlatform(String engine) { String nativeProp = "native/lib/" + engine + ".properties"; Enumeration urls; try { - urls = Thread.currentThread().getContextClassLoader().getResources(nativeProp); + urls = ClassLoaderUtils.getContextClassLoader().getResources(nativeProp); } catch (IOException e) { throw new AssertionError("Failed to list property files.", e); } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java index ef828033ad4..2179fee0650 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java @@ -15,6 +15,7 @@ import ai.djl.modality.cv.transform.ToTensor; import ai.djl.training.dataset.Dataset; import ai.djl.translate.Pipeline; +import ai.djl.util.ClassLoaderUtils; import ai.djl.util.JsonUtils; import ai.djl.util.Progress; import java.io.IOException; @@ -101,7 +102,7 @@ public void prepare(Progress progress) throws IOException { } private void loadSynset() { - ClassLoader cl = Thread.currentThread().getContextClassLoader(); + ClassLoader cl = ClassLoaderUtils.getContextClassLoader(); try (InputStream classStream = cl.getResourceAsStream("imagenet/classes.json")) { if (classStream == null) { throw new AssertionError("Missing imagenet/classes.json in jar resource"); diff --git a/testing/src/main/java/ai/djl/testing/CoverageUtils.java b/testing/src/main/java/ai/djl/testing/CoverageUtils.java index 77da7c3784c..1ae9c62560d 100644 --- a/testing/src/main/java/ai/djl/testing/CoverageUtils.java +++ b/testing/src/main/java/ai/djl/testing/CoverageUtils.java @@ -12,6 +12,7 @@ */ package ai.djl.testing; +import ai.djl.util.ClassLoaderUtils; import java.io.File; import java.io.IOException; import java.lang.reflect.Constructor; @@ -98,13 +99,13 @@ public static void testGetterSetters(Class baseClass) private static List> getClasses(Class clazz) throws IOException, ReflectiveOperationException, URISyntaxException { - ClassLoader appClassLoader = Thread.currentThread().getContextClassLoader(); + ClassLoader appClassLoader = ClassLoaderUtils.getContextClassLoader(); Field field = appClassLoader.getClass().getDeclaredField("ucp"); field.setAccessible(true); Object ucp = field.get(appClassLoader); Method method = ucp.getClass().getDeclaredMethod("getURLs"); URL[] urls = (URL[]) method.invoke(ucp); - ClassLoader cl = new TestClassLoader(urls, Thread.currentThread().getContextClassLoader()); + ClassLoader cl = new TestClassLoader(urls, ClassLoaderUtils.getContextClassLoader()); URL url = clazz.getProtectionDomain().getCodeSource().getLocation(); String path = url.getPath();