Skip to content

Commit

Permalink
Mitigate concurrent classes definition in RunnerClassLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Sep 4, 2024
1 parent 3c731e2 commit 3cb0952
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,25 @@ public interface ClassLoadingResource {
default void resetInternalCaches() {
//no-op
}

/**
* Notifies this ClassLoadingResource that the definition of a class is about to begin.
*
* @param className The name of the class to be defined.
* @return true if the ClassLoader should actually attempt the definition of this class, false if the definition of the same
* class has been already requested by a different thread and then the current thread should wait for that
* definition to be completed and load the class without redefining it.
*/
default boolean definingClass(String className) {
return true;
}

/**
* Notifies this ClassLoadingResource that the definition of a class is terminated.
*
* @param className The name of the class to be defined.
*/
default void classDefined(String className) {
//no-op
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkus.bootstrap.runner;

import static io.quarkus.bootstrap.runner.VirtualThreadSupport.isVirtualThread;

import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
Expand Down Expand Up @@ -30,6 +32,9 @@ public class JarResource implements ClassLoadingResource {
final Path jarPath;
final AtomicReference<CompletableFuture<JarFileReference>> jarFileReference = new AtomicReference<>();

// Single-entry cache of the name of the class currently loaded
private final AtomicReference<String> loadingClass = new AtomicReference<>();

public JarResource(ManifestInfo manifestInfo, Path jarPath) {
this.manifestInfo = manifestInfo;
this.jarPath = jarPath;
Expand All @@ -56,6 +61,51 @@ public byte[] getResourceData(String resource) {
return JarFileReference.withJarFile(this, resource, JarResourceDataProvider.INSTANCE);
}

@Override
public boolean definingClass(String className) {
if (isVirtualThread()) {
// Use full non-blocking algorithm for virtual threads
return ClassLoadingResource.super.definingClass(className);
}

if (loadingClass.compareAndSet(null, className)) {
// First thread trying to load this class, return true to signal that it has to be defined
return true;
}

if (className.equals(loadingClass.get())) {
try {
synchronized (this) {
// Another thread already started the definition of this class, wait for its completion
this.wait();
}
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return false;
}

// The single value cache is already occupied by another class, use the full non-blocking algorithm
return true;
}

@Override
public void classDefined(String className) {
if (isVirtualThread()) {
// Use full non-blocking algorithm for virtual threads
ClassLoadingResource.super.classDefined(className);
return;
}

// The definition of the class has been completed, so make the single value cache available again ...
if (loadingClass.compareAndSet(className, null)) {
synchronized (this) {
// ... and notify other threads eventually waiting that they can now load the defined class
this.notifyAll();
}
}
}

private static class JarResourceDataProvider implements JarFileReference.JarFileConsumer<byte[]> {
private static final JarResourceDataProvider INSTANCE = new JarResourceDataProvider();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static io.quarkus.commons.classloading.ClassLoaderHelper.isInJdkPackage;

import java.net.URL;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
Expand Down Expand Up @@ -92,48 +93,47 @@ public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundExce
if (loaded != null) {
return loaded;
}
final ClassLoadingResource[] resources;
if (packageName == null) {
resources = resourceDirectoryMap.get("");
} else {
String dirName = packageName.replace('.', '/');
resources = resourceDirectoryMap.get(dirName);
}

final ClassLoadingResource[] resources = findLoadingResources(packageName);
if (resources != null) {
String classResource = fromClassNameToResourceName(name);
for (ClassLoadingResource resource : resources) {
accessingResource(resource);
byte[] data = resource.getResourceData(classResource);
if (data == null) {
continue;
if (data != null) {
return findOrDefineClass(resource, packageName, name, data);
}
definePackage(packageName, resources);
return defineClass(name, data, resource);
}
}
return getParent().loadClass(name);
}

private void definePackage(String pkgName, ClassLoadingResource[] resources) {
private ClassLoadingResource[] findLoadingResources(String packageName) {
if (packageName == null) {
return resourceDirectoryMap.get("");
}
String dirName = packageName.replace('.', '/');
return resourceDirectoryMap.get(dirName);
}

private void definePackage(String pkgName, ClassLoadingResource classPathElement) {
if ((pkgName != null) && getDefinedPackage(pkgName) == null) {
for (ClassLoadingResource classPathElement : resources) {
ManifestInfo mf = classPathElement.getManifestInfo();
if (mf != null) {
try {
definePackage(pkgName, mf.getSpecTitle(),
mf.getSpecVersion(),
mf.getSpecVendor(),
mf.getImplTitle(),
mf.getImplVersion(),
mf.getImplVendor(), null);
} catch (IllegalArgumentException e) {
var loaded = getDefinedPackage(pkgName);
if (loaded == null) {
throw e;
}
ManifestInfo mf = classPathElement.getManifestInfo();
if (mf != null) {
try {
definePackage(pkgName, mf.getSpecTitle(),
mf.getSpecVersion(),
mf.getSpecVendor(),
mf.getImplTitle(),
mf.getImplVersion(),
mf.getImplVendor(), null);
} catch (IllegalArgumentException e) {
var loaded = getDefinedPackage(pkgName);
if (loaded == null) {
throw e;
}
return;
}
return;
}
try {
definePackage(pkgName, null, null, null, null, null, null, null);
Expand All @@ -146,12 +146,21 @@ private void definePackage(String pkgName, ClassLoadingResource[] resources) {
}
}

private Class<?> defineClass(String name, byte[] data, ClassLoadingResource resource) {
Class<?> loaded;
private Class<?> findOrDefineClass(ClassLoadingResource resource, String packageName, String name, byte[] data) {
if (resource.definingClass(name)) {
definePackage(packageName, resource);
Class<?> definedClass = defineClass(name, data, resource.getProtectionDomain());
resource.classDefined(name);
return definedClass;
}
return findLoadedClass(name);
}

private Class<?> defineClass(String name, byte[] data, ProtectionDomain protectionDomain) {
try {
return defineClass(name, data, 0, data.length, resource.getProtectionDomain());
return defineClass(name, data, 0, data.length, protectionDomain);
} catch (LinkageError e) {
loaded = findLoadedClass(name);
Class<?> loaded = findLoadedClass(name);
if (loaded != null) {
return loaded;
}
Expand Down

0 comments on commit 3cb0952

Please sign in to comment.