Skip to content

Commit

Permalink
Allow TestCompiler to load cglib generated classes
Browse files Browse the repository at this point in the history
Update `TestCompiler` so that it can now load cglib generated classes.
This commit adds support to `DynamicJavaFileManager` so that it can
reference generated classes and adds a new lookup function to
`CompileWithTargetClassAccessClassLoader` to that it can load the
bytecode of generated classes directly.

See gh-29141

Co-authored-by: Phillip Webb <[email protected]>
  • Loading branch information
wilkinsona and philwebb committed Sep 20, 2022
1 parent 5277b1d commit 7168141
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.InputStream;
import java.net.URL;
import java.util.Enumeration;
import java.util.function.Function;

import org.springframework.lang.Nullable;

Expand All @@ -34,12 +35,19 @@ final class CompileWithTargetClassAccessClassLoader extends ClassLoader {

private final ClassLoader testClassLoader;

private Function<String, byte[]> classResourceLookup = name -> null;


public CompileWithTargetClassAccessClassLoader(ClassLoader testClassLoader) {
super(testClassLoader.getParent());
this.testClassLoader = testClassLoader;
}

// Invoked reflectively by DynamicClassLoader constructor
@SuppressWarnings("unused")
void setClassResourceLookup(Function<String, byte[]> classResourceLookup) {
this.classResourceLookup = classResourceLookup;
}

@Override
public Class<?> loadClass(String name) throws ClassNotFoundException {
Expand All @@ -51,25 +59,36 @@ public Class<?> loadClass(String name) throws ClassNotFoundException {

@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
byte[] bytes = findClassBytes(name);
return (bytes != null) ? defineClass(name, bytes, 0, bytes.length, null) : super.findClass(name);
}

@Nullable
private byte[] findClassBytes(String name) {
byte[] bytes = this.classResourceLookup.apply(name);
if (bytes != null) {
return bytes;
}
String resourceName = name.replace(".", "/") + ".class";
InputStream stream = this.testClassLoader.getResourceAsStream(resourceName);
if (stream != null) {
try (stream) {
byte[] bytes = stream.readAllBytes();
return defineClass(name, bytes, 0, bytes.length, null);
return stream.readAllBytes();
}
catch (IOException ex) {
// ignore
}
}
return super.findClass(name);
return null;
}


// Invoked reflectively by DynamicClassLoader.findDefineClassMethod(ClassLoader)
@SuppressWarnings("unused")
Class<?> defineClassWithTargetAccess(String name, byte[] b, int off, int len) {
return super.defineClass(name, b, off, len);
}


@Override
protected Enumeration<URL> findResources(String name) throws IOException {
return this.testClassLoader.getResources(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package org.springframework.aot.test.generate.compile;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;

Expand All @@ -31,23 +34,42 @@
*/
class DynamicClassFileObject extends SimpleJavaFileObject {

private volatile byte[] bytes = new byte[0];
private static final byte[] NO_BYTES = new byte[0];

private final String className;

private volatile byte[] bytes;


DynamicClassFileObject(String className) {
this(className, NO_BYTES);
}

DynamicClassFileObject(String className, byte[] bytes) {
super(URI.create("class:///" + className.replace('.', '/') + ".class"), Kind.CLASS);
this.className = className;
this.bytes = bytes;
}


@Override
public OutputStream openOutputStream() {
return new JavaClassOutputStream();
String getClassName() {
return this.className;
}

byte[] getBytes() {
return this.bytes;
}

@Override
public InputStream openInputStream() throws IOException {
return new ByteArrayInputStream(this.bytes);
}

@Override
public OutputStream openOutputStream() {
return new JavaClassOutputStream();
}


class JavaClassOutputStream extends ByteArrayOutputStream {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.net.URLStreamHandler;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.Map;
import java.util.function.Function;

import org.springframework.aot.test.generate.file.ClassFile;
import org.springframework.aot.test.generate.file.ClassFiles;
Expand All @@ -47,47 +49,65 @@ public class DynamicClassLoader extends ClassLoader {

private final ClassFiles classFiles;

private final Map<String, DynamicClassFileObject> compiledClasses;

@Nullable
private final Method defineClassMethod;


public DynamicClassLoader(ClassLoader parent, ResourceFiles resourceFiles,
ClassFiles classFiles) {
ClassFiles classFiles, Map<String, DynamicClassFileObject> compiledClasses) {

super(parent);
this.resourceFiles = resourceFiles;
this.classFiles = classFiles;
this.defineClassMethod = findDefineClassMethod(parent);
if (this.defineClassMethod != null) {
classFiles.forEach(this::defineClass);
}
}

@Nullable
private Method findDefineClassMethod(ClassLoader parent) {
this.compiledClasses = compiledClasses;
Class<? extends ClassLoader> parentClass = parent.getClass();
if (parentClass.getName().equals(CompileWithTargetClassAccessClassLoader.class.getName())) {
Method defineClassMethod = ReflectionUtils.findMethod(parentClass,
Method setClassResourceLookupMethod = ReflectionUtils.findMethod(parentClass,
"setClassResourceLookup", Function.class);
ReflectionUtils.makeAccessible(setClassResourceLookupMethod);
ReflectionUtils.invokeMethod(setClassResourceLookupMethod,
getParent(), (Function<String, byte[]>) this::findClassBytes);
this.defineClassMethod = ReflectionUtils.findMethod(parentClass,
"defineClassWithTargetAccess", String.class, byte[].class, int.class, int.class);
ReflectionUtils.makeAccessible(defineClassMethod);
return defineClassMethod;
ReflectionUtils.makeAccessible(this.defineClassMethod);
this.compiledClasses.forEach((name, file) -> defineClass(name, file.getBytes()));
}
else {
this.defineClassMethod = null;
}
return null;
}


@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
byte[] bytes = findClassBytes(name);
if (bytes != null) {
return defineClass(name, bytes);
}
return super.findClass(name);
}

@Nullable
private byte[] findClassBytes(String name) {
DynamicClassFileObject compiledClass = this.compiledClasses.get(name);
if(compiledClass != null) {
return compiledClass.getBytes();
}
return findClassFileBytes(name);
}

@Nullable
private byte[] findClassFileBytes(String name) {
ClassFile classFile = this.classFiles.get(name);
if (classFile != null) {
return defineClass(classFile);
return classFile.getContent();
}
return super.findClass(name);
return null;
}

private Class<?> defineClass(ClassFile classFile) {
String name = classFile.getName();
byte[] bytes = classFile.getContent();
private Class<?> defineClass(String name, byte[] bytes) {
if (this.defineClassMethod != null) {
return (Class<?>) ReflectionUtils.invokeMethod(this.defineClassMethod,
getParent(), name, bytes, 0, bytes.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@

package org.springframework.aot.test.generate.compile;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
Expand All @@ -32,7 +29,6 @@
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.JavaFileObject.Kind;
import javax.tools.SimpleJavaFileObject;

import org.springframework.aot.test.generate.file.ClassFile;
import org.springframework.aot.test.generate.file.ClassFiles;
Expand All @@ -48,19 +44,18 @@
*/
class DynamicJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {

private final ClassFiles existingClasses;

private final ClassLoader classLoader;

private final ClassFiles classFiles;

private final Map<String, DynamicClassFileObject> compiledClasses = Collections.synchronizedMap(
new LinkedHashMap<>());


DynamicJavaFileManager(JavaFileManager fileManager, ClassLoader classLoader,
ClassFiles existingClasses) {
DynamicJavaFileManager(JavaFileManager fileManager, ClassLoader classLoader, ClassFiles classFiles) {
super(fileManager);
this.classLoader = classLoader;
this.existingClasses = existingClasses;
this.classFiles = classFiles;
}


Expand All @@ -84,49 +79,27 @@ public Iterable<JavaFileObject> list(Location location, String packageName,
Set<Kind> kinds, boolean recurse) throws IOException {
List<JavaFileObject> result = new ArrayList<>();
if (kinds.contains(Kind.CLASS)) {
for (ClassFile existingClass : this.existingClasses) {
String existingPackageName = ClassUtils.getPackageName(existingClass.getName());
for (ClassFile candidate : this.classFiles) {
String existingPackageName = ClassUtils.getPackageName(candidate.getName());
if (existingPackageName.equals(packageName) || (recurse && existingPackageName.startsWith(packageName + "."))) {
result.add(new ClassFileJavaFileObject(existingClass));
result.add(new DynamicClassFileObject(candidate.getName(), candidate.getContent()));
}
}
}
Iterable<JavaFileObject> listed = super.list(location, packageName, kinds, recurse);
listed.forEach(result::add);
super.list(location, packageName, kinds, recurse).forEach(result::add);
return result;
}

@Override
public String inferBinaryName(Location location, JavaFileObject file) {
if (file instanceof ClassFileJavaFileObject classFile) {
return classFile.getClassName();
if (file instanceof DynamicClassFileObject dynamicClassFileObject) {
return dynamicClassFileObject.getClassName();
}
return super.inferBinaryName(location, file);
}

ClassFiles getClassFiles() {
return this.existingClasses.and(this.compiledClasses.entrySet().stream().map(entry ->
ClassFile.of(entry.getKey(), entry.getValue().getBytes())).toList());
}

private static final class ClassFileJavaFileObject extends SimpleJavaFileObject {

private final ClassFile classFile;

private ClassFileJavaFileObject(ClassFile classFile) {
super(URI.create("class:///" + classFile.getName().replace('.', '/') + ".class"), Kind.CLASS);
this.classFile = classFile;
}

public String getClassName() {
return this.classFile.getName();
}

@Override
public InputStream openInputStream() {
return new ByteArrayInputStream(this.classFile.getContent());
}

Map<String, DynamicClassFileObject> getCompiledClasses() {
return this.compiledClasses;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ private DynamicClassLoader compile() {
throw new CompilationException(errors.toString(), this.sourceFiles, this.resourceFiles);
}
}
return new DynamicClassLoader(classLoaderToUse, this.resourceFiles, fileManager.getClassFiles());
return new DynamicClassLoader(classLoaderToUse, this.resourceFiles, this.classFiles, fileManager.getCompiledClasses());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ void getClassFilesReturnsClassFiles() throws Exception {
Kind.CLASS, null);
this.fileManager.getJavaFileForOutput(this.location, "com.example.MyClass2",
Kind.CLASS, null);
assertThat(this.fileManager.getClassFiles().stream().map(ClassFile::getName))
.contains("com.example.MyClass1", "com.example.MyClass2");
assertThat(this.fileManager.getCompiledClasses()).containsKeys(
"com.example.MyClass1", "com.example.MyClass2");
}

@Test
Expand Down

0 comments on commit 7168141

Please sign in to comment.