Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AOT for GraalVM 22 #283

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions aot-core/src/main/java/io/micronaut/aot/core/AOTContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;

/**
Expand Down Expand Up @@ -115,6 +116,12 @@ public interface AOTContext {
*/
void registerClassNeededAtCompileTime(@NonNull Class<?> clazz);

/**
* Registers a type as a requiring initialization at build time.
* @param className the type
*/
void registerBuildTimeInit(@NonNull String className);

/**
* Generates a java file spec.
* @param typeSpec the type spec of the main class
Expand Down Expand Up @@ -165,4 +172,16 @@ public interface AOTContext {
*/
@NonNull
Runtime getRuntime();

/**
* Returns the set of classes which require build time initialization
* @return the set of classes needing build time init
*/
Set<String> getBuildTimeInitClasses();

/**
* Performs actions which have to be done as late as possible during
* source generation.
*/
void finish();
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public void generate(@NonNull AOTContext context) {
this.context = context;
JavaFile javaFile = generate();
context.registerGeneratedSourceFile(javaFile);
context.registerBuildTimeInit(javaFile.packageName + "." + javaFile.typeSpec.name);
context.registerBuildTimeInit(javaFile.packageName + "." + javaFile.typeSpec.name + "$1");
}

protected final AOTContext getContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public void generate(@NonNull AOTContext context) {
optimizedEntryPoint.addStaticBlock(staticInitializer.build());
context.registerGeneratedSourceFile(context.javaFile(optimizedEntryPoint.build()));
context.registerServiceImplementation(ApplicationContextConfigurer.class, CUSTOMIZER_CLASS_NAME);
context.finish();
}

private void addDiagnostics(AOTContext context, TypeSpec.Builder optimizedEntryPoint) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;

/**
Expand Down Expand Up @@ -110,6 +111,11 @@ public void registerClassNeededAtCompileTime(@NonNull Class<?> clazz) {
delegate.registerClassNeededAtCompileTime(clazz);
}

@Override
public void registerBuildTimeInit(String className) {
delegate.registerBuildTimeInit(className);
}

@Override
@NonNull
public JavaFile javaFile(TypeSpec typeSpec) {
Expand Down Expand Up @@ -144,4 +150,14 @@ public Map<String, List<String>> getDiagnostics() {
public Runtime getRuntime() {
return delegate.getRuntime();
}

@Override
public Set<String> getBuildTimeInitClasses() {
return delegate.getBuildTimeInitClasses();
}

@Override
public void finish() {
delegate.finish();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ public final class DefaultSourceGenerationContext implements AOTContext {
private final List<JavaFile> generatedJavaFiles = new ArrayList<>();
private final List<MethodSpec> initializers = new ArrayList<>();
private final Path generatedResourcesDirectory;
private final Set<String> buildTimeInitClasses = new HashSet<>();
private final List<Runnable> deferredOperations = new ArrayList<>();

public DefaultSourceGenerationContext(String packageName,
ApplicationContextAnalyzer analyzer,
Expand Down Expand Up @@ -163,6 +165,7 @@ public <T> void registerStaticOptimization(String className, Class<T> optimizati
.addSuperinterface(ParameterizedTypeName.get(StaticOptimizations.Loader.class, optimizationKind))
.addMethod(method)
.build();
registerBuildTimeInit(optimizationKind.getName());
registerGeneratedSourceFile(javaFile(generatedType));
registerServiceImplementation(StaticOptimizations.Loader.class, className);
}
Expand Down Expand Up @@ -190,14 +193,16 @@ public List<MethodSpec> getGeneratedStaticInitializers() {
@Override
public void registerGeneratedResource(@NonNull String path, Consumer<? super File> consumer) {
LOGGER.debug("Registering generated resource file: {}", path);
Path relative = generatedResourcesDirectory.resolve(path);
File resourceFile = relative.toFile();
File parent = resourceFile.getParentFile();
if (parent.exists() || parent.mkdirs()) {
consumer.accept(resourceFile);
} else {
throw new RuntimeException("Unable to create parent file " + parent + " for resource " + path);
}
deferredOperations.add(() -> {
Path relative = generatedResourcesDirectory.resolve(path);
File resourceFile = relative.toFile();
File parent = resourceFile.getParentFile();
if (parent.exists() || parent.mkdirs()) {
consumer.accept(resourceFile);
} else {
throw new RuntimeException("Unable to create parent file " + parent + " for resource " + path);
}
});
}

@NonNull
Expand All @@ -217,6 +222,11 @@ public List<File> getExtraClasspath() {
.collect(Collectors.toList());
}

@Override
public void registerBuildTimeInit(String className) {
buildTimeInitClasses.add(className);
}

/**
* Returns the list of resources to be excluded from
* the binary.
Expand Down Expand Up @@ -258,4 +268,14 @@ public <T> Optional<T> get(@NonNull Class<T> type) {
public Map<String, List<String>> getDiagnostics() {
return diagnostics;
}

@Override
public Set<String> getBuildTimeInitClasses() {
return Collections.unmodifiableSet(buildTimeInitClasses);
}

@Override
public void finish() {
deferredOperations.forEach(Runnable::run);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ abstract class AbstractSourceGeneratorSpec extends Specification {

abstract AOTCodeGenerator newGenerator()

void generate() {
final void generate() {
def sourceGenerator = newGenerator()
sourceGenerator.generate(context)
context.finish()
def sources = context.getGeneratedJavaFiles().collectEntries([:]) {
def writer = new StringWriter()
it.writeTo(writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public abstract class AbstractStaticServiceLoaderSourceGenerator extends Abstrac
private Map<String, AbstractCodeGenerator> substitutions;
private Set<String> forceInclude;
private final Substitutes substitutes = new Substitutes();
private final Map<String, TypeSpec> staticServiceClasses = new HashMap<>();
private final Map<String, GeneratedType> staticServiceClasses = new HashMap<>();
private final Set<BeanConfiguration> disabledConfigurations = Collections.synchronizedSet(new HashSet<>());
private final Map<String, List<Class<?>>> serviceClasses = new HashMap<>();
private final Set<Class<?>> disabledServices = new HashSet<>();
Expand Down Expand Up @@ -144,7 +144,10 @@ public void generate(@NonNull AOTContext context) {
LOGGER.debug("Generated static {} service loader substitutions", substitutes.values().size());
staticServiceClasses.values()
.stream()
.map(context::javaFile)
.map(generatedType -> {
context.registerBuildTimeInit(generatedType.className());
return context.javaFile(generatedType.typeSpec());
})
.forEach(context::registerGeneratedSourceFile);
context.registerStaticOptimization("StaticServicesLoader", SoftServiceLoader.Optimizations.class, this::buildOptimization);
}
Expand All @@ -166,7 +169,7 @@ private void generateServiceLoader() {
serviceName,
serviceType,
factory);
staticServiceClasses.put(serviceName, factory.build());
staticServiceClasses.put(serviceName, new GeneratedType(context.getPackageName() + "." + factoryNameFor(serviceName), factory.build()));
}
}

Expand Down Expand Up @@ -234,23 +237,27 @@ protected abstract void generateFindAllMethod(Stream<Class<?>> serviceClasses,
TypeSpec.Builder factory);

private TypeSpec.Builder prepareServiceLoaderType(String serviceName, Class<?> serviceType) {
String name = simpleNameOf(serviceName) + "Factory";
String name = factoryNameFor(serviceName);
TypeSpec.Builder factory = TypeSpec.classBuilder(name)
.addModifiers(PUBLIC)
.addAnnotation(Generated.class)
.addSuperinterface(ParameterizedTypeName.get(SoftServiceLoader.StaticServiceLoader.class, serviceType));
return factory;
}

private static String factoryNameFor(String serviceName) {
return simpleNameOf(serviceName) + "Factory";
}

private void buildOptimization(CodeBlock.Builder body) {
ParameterizedTypeName serviceLoaderType = ParameterizedTypeName.get(
ClassName.get(SoftServiceLoader.StaticServiceLoader.class), WildcardTypeName.subtypeOf(Object.class));
body.addStatement("$T staticServices = new $T()",
ParameterizedTypeName.get(ClassName.get(Map.class), ClassName.get(String.class), serviceLoaderType),
ParameterizedTypeName.get(ClassName.get(HashMap.class), ClassName.get(String.class), serviceLoaderType));

for (Map.Entry<String, TypeSpec> entry : staticServiceClasses.entrySet()) {
body.addStatement("staticServices.put($S, new $T())", entry.getKey(), ClassName.bestGuess(entry.getValue().name));
for (Map.Entry<String, GeneratedType> entry : staticServiceClasses.entrySet()) {
body.addStatement("staticServices.put($S, new $T())", entry.getKey(), ClassName.bestGuess(entry.getValue().typeSpec().name));
}
body.addStatement("return new $T(staticServices)", SoftServiceLoader.Optimizations.class);
}
Expand Down Expand Up @@ -324,5 +331,11 @@ private boolean skipService(Class<?> clazz, Throwable e) {
}
}

private record GeneratedType(
String className,
TypeSpec typeSpec
) {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,40 +28,55 @@
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.stream.Collectors;

/**
* Generates the GraalVM configuration file which is going to configure
* the native image code generation, typically asking to initialize
* the optimized entry point at build time.
*/
@AOTModule(
id = GraalVMOptimizationFeatureSourceGenerator.ID,
description = GraalVMOptimizationFeatureSourceGenerator.DESCRIPTION,
options = {
@Option(
key = "service.types",
description = "The list of service types to be scanned (comma separated)",
sampleValue = "io.micronaut.Service1,io.micronaut.Service2"
)
},
enabledOn = Runtime.NATIVE
id = GraalVMOptimizationFeatureSourceGenerator.ID,
description = GraalVMOptimizationFeatureSourceGenerator.DESCRIPTION,
options = {
@Option(
key = "service.types",
description = "The list of service types to be scanned (comma separated)",
sampleValue = "io.micronaut.Service1,io.micronaut.Service2"
)
},
enabledOn = Runtime.NATIVE
)
public class GraalVMOptimizationFeatureSourceGenerator extends AbstractCodeGenerator {
public static final String ID = "graalvm.config";
public static final String DESCRIPTION = "Generates GraalVM configuration files required to load the AOT optimizations";
public static final String DESCRIPTION =
"Generates GraalVM configuration files required to load the AOT optimizations";
private static final String NEXT_LINE = " \\";

private static final Option OPTION = MetadataUtils.findOption(GraalVMOptimizationFeatureSourceGenerator.class, "service.types");
private static final Option OPTION =
MetadataUtils.findOption(GraalVMOptimizationFeatureSourceGenerator.class, "service.types");

@Override
public void generate(@NonNull AOTContext context) {
List<String> serviceTypes = context.getConfiguration().stringList(OPTION.key());
String path = "META-INF/native-image/" + context.getPackageName() + "/native-image.properties";
String path =
"META-INF/native-image/" + context.getPackageName() + "/native-image.properties";
context.registerGeneratedResource(path, propertiesFile -> {
try (PrintWriter wrt = new PrintWriter(new FileWriter(propertiesFile))) {
wrt.print("Args=");
wrt.println("--initialize-at-build-time=" + context.getPackageName() + "." + ApplicationContextConfigurerGenerator.CUSTOMIZER_CLASS_NAME + NEXT_LINE);
if (context.getConfiguration().isFeatureEnabled(NativeStaticServiceLoaderSourceGenerator.ID)) {
wrt.println("--initialize-at-build-time=io.micronaut.context.ApplicationContextConfigurer$1" + NEXT_LINE);
wrt.println(" --initialize-at-build-time=" + context.getPackageName() + "." +
ApplicationContextConfigurerGenerator.CUSTOMIZER_CLASS_NAME +
NEXT_LINE);
var buildTimeInit = context.getBuildTimeInitClasses()
.stream()
.map(clazz -> " --initialize-at-build-time=" + clazz)
.collect(Collectors.joining(NEXT_LINE + "\n"));
if (!buildTimeInit.isEmpty()) {
wrt.println(buildTimeInit);
}
if (context.getConfiguration()
.isFeatureEnabled(NativeStaticServiceLoaderSourceGenerator.ID)) {
for (int i = 0; i < serviceTypes.size(); i++) {
String serviceType = serviceTypes.get(i);
wrt.print(" -H:ServiceLoaderFeatureExcludeServices=" + serviceType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class GraalVMOptimizationFeatureSourceGeneratorTest extends AbstractSourceGenera
assertThatGeneratedSources {
doesNotCreateInitializer()
generatesMetaInfResource("native-image/$packageName/native-image.properties", """
Args=--initialize-at-build-time=io.micronaut.test.AOTApplicationContextConfigurer \\
Args=--initialize-at-build-time=io.micronaut.context.ApplicationContextConfigurer\$1 \\
--initialize-at-build-time=io.micronaut.test.AOTApplicationContextConfigurer \\
""")
}
}
Expand All @@ -32,7 +33,8 @@ Args=--initialize-at-build-time=io.micronaut.test.AOTApplicationContextConfigure
assertThatGeneratedSources {
doesNotCreateInitializer()
generatesMetaInfResource("native-image/$packageName/native-image.properties", """
Args=--initialize-at-build-time=io.micronaut.test.AOTApplicationContextConfigurer \\
Args=--initialize-at-build-time=io.micronaut.context.ApplicationContextConfigurer\$1 \\
--initialize-at-build-time=io.micronaut.test.AOTApplicationContextConfigurer \\
-H:ServiceLoaderFeatureExcludeServices=A \\
-H:ServiceLoaderFeatureExcludeServices=B \\
-H:ServiceLoaderFeatureExcludeServices=C
Expand Down
Loading