Skip to content

Commit

Permalink
Allow MethodReference to define a more flexible signature
Browse files Browse the repository at this point in the history
This commit moves MethodReference to an interface with a default
implementation that relies on a MethodSpec. Such an arrangement avoid
the need of specifying attributes of the method such as whether it is
static or not.

The resolution of the invocation block now takes an
ArgumentCodeGenerator rather than the raw arguments. Doing so gives
the opportunity to create more flexible signatures.

See gh-29005
  • Loading branch information
snicoll committed Sep 12, 2022
1 parent 8a4a89b commit ae706f3
Show file tree
Hide file tree
Showing 16 changed files with 469 additions and 428 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import org.springframework.aop.framework.AopInfrastructureBean;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.Compiled;
import org.springframework.aot.test.generate.compile.TestCompiler;
Expand Down Expand Up @@ -139,11 +140,14 @@ private void compile(BiConsumer<DefaultListableBeanFactory, Compiled> result) {
MethodReference methodReference = this.beanFactoryInitializationCode
.getInitializers().get(0);
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
this.beanFactoryInitializationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
.addStatement(methodInvocation)
.build());
});
this.generationContext.writeGeneratedContent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
Expand Down Expand Up @@ -81,9 +82,11 @@ private void generateRegisterMethod(MethodSpec.Builder method,
MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator
.generateBeanDefinitionMethod(generationContext,
beanRegistrationsCode);
CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock(
ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName());
code.addStatement("$L.registerBeanDefinition($S, $L)",
BEAN_FACTORY_PARAMETER_NAME, beanName,
beanDefinitionMethod.toInvokeCodeBlock());
methodInvocation);
});
method.addCode(code.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.aot.generate.AccessVisibility;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
Expand Down Expand Up @@ -156,7 +157,7 @@ protected CodeBlock generateValueCode(GenerationContext generationContext,
MethodReference generatedMethod = methodGenerator
.generateBeanDefinitionMethod(generationContext,
this.beanRegistrationsCode);
return generatedMethod.toInvokeCodeBlock();
return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none());
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean;
Expand Down Expand Up @@ -297,7 +298,8 @@ private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean,
}

private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) {
return generatedMethod.toMethodReference().toInvokeCodeBlock();
return generatedMethod.toMethodReference().toInvokeCodeBlock(
ArgumentCodeGenerator.none(), this.className);
}

private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.junit.jupiter.api.Test;

import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
Expand Down Expand Up @@ -161,13 +162,16 @@ private void compile(RegisteredBean registeredBean,
Class<?> target = registeredBean.getBeanClass();
MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0);
this.beanRegistrationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"),
this.beanRegistrationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target));
type.addMethod(MethodSpec.methodBuilder("apply")
.addModifiers(Modifier.PUBLIC)
.addParameter(RegisteredBean.class, "registeredBean")
.addParameter(target, "instance").returns(target)
.addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance")))
.addStatement("return $L", methodInvocation)
.build());

});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
import org.springframework.aot.test.generate.compile.Compiled;
Expand Down Expand Up @@ -414,12 +415,14 @@ private RegisteredBean registerBean(RootBeanDefinition beanDefinition) {
private void compile(MethodReference method,
BiConsumer<RootBeanDefinition, Compiled> result) {
this.beanRegistrationsCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(),
this.beanRegistrationsCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class));
type.addMethod(MethodSpec.methodBuilder("get")
.addModifiers(Modifier.PUBLIC)
.returns(BeanDefinition.class)
.addCode("return $L;", method.toInvokeCodeBlock()).build());
.addCode("return $L;", methodInvocation).build());
});
this.generationContext.writeGeneratedContent();
TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.TestTarget;
import org.springframework.aot.test.generate.compile.Compiled;
Expand Down Expand Up @@ -155,11 +156,14 @@ private void compile(
MethodReference methodReference = this.beanFactoryInitializationCode
.getInitializers().get(0);
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
this.beanFactoryInitializationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
.addStatement(methodInvocation)
.build());
});
this.generationContext.writeGeneratedContent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.javapoet.ClassName;

/**
* Mock {@link BeanFactoryInitializationCode} implementation.
Expand All @@ -46,6 +47,9 @@ public MockBeanFactoryInitializationCode(GenerationContext generationContext) {
.addForFeature("TestCode", this.typeBuilder);
}

public ClassName getClassName() {
return this.generatedClass.getName();
}

public DeferredTypeBuilder getTypeBuilder() {
return this.typeBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContextInitializer;
Expand Down Expand Up @@ -88,12 +89,17 @@ private CodeBlock generateInitializeCode() {
BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class);
code.addStatement("$L.setDependencyComparator($T.INSTANCE)",
BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class);
ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator();
for (MethodReference initializer : this.initializers) {
code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE)));
code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName()));
}
return code.build();
}

private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() {
return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE);
}

GeneratedClass getGeneratedClass() {
return this.generatedClass;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.junit.jupiter.api.Test;

import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ResourcePatternHint;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.Compiled;
Expand Down Expand Up @@ -162,11 +163,14 @@ private void assertPostProcessorEntry(BeanPostProcessor postProcessor, Class<?>
private void compile(BiConsumer<Consumer<DefaultListableBeanFactory>, Compiled> result) {
MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0);
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
this.beanFactoryInitializationCode.getClassName());
type.addModifiers(Modifier.PUBLIC);
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
.addStatement(methodInvocation)
.build());
});
this.generationContext.writeGeneratedContent();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.aot.generate;

import java.util.ArrayList;
import java.util.List;

import javax.lang.model.element.Modifier;

import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Default {@link MethodReference} implementation based on a {@link MethodSpec}.
*
* @author Stephane Nicoll
* @author Phillip Webb
* @since 6.0
*/
public class DefaultMethodReference implements MethodReference {

private final MethodSpec method;

@Nullable
private final ClassName declaringClass;

public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) {
this.method = method;
this.declaringClass = declaringClass;
}

@Override
public CodeBlock toCodeBlock() {
String methodName = this.method.name;
if (isStatic()) {
Assert.notNull(this.declaringClass, "static method reference must define a declaring class");
return CodeBlock.of("$T::$L", this.declaringClass, methodName);
}
else {
return CodeBlock.of("this::$L", methodName);
}
}

public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator,
@Nullable ClassName targetClassName) {
String methodName = this.method.name;
CodeBlock.Builder code = CodeBlock.builder();
if (isStatic()) {
Assert.notNull(this.declaringClass, "static method reference must define a declaring class");
if (isSameDeclaringClass(targetClassName)) {
code.add("$L", methodName);
}
else {
code.add("$T.$L", this.declaringClass, methodName);
}
}
else {
if (!isSameDeclaringClass(targetClassName)) {
code.add(instantiateDeclaringClass(this.declaringClass));
}
code.add("$L", methodName);
}
code.add("(");
addArguments(code, argumentCodeGenerator);
code.add(")");
return code.build();
}

/**
* Add the code for the method arguments using the specified
* {@link ArgumentCodeGenerator} if necessary.
* @param code the code builder to use to add method arguments
* @param argumentCodeGenerator the code generator to use
*/
protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) {
List<CodeBlock> arguments = new ArrayList<>();
TypeName[] argumentTypes = this.method.parameters.stream()
.map(parameter -> parameter.type).toArray(TypeName[]::new);
for (int i = 0; i < argumentTypes.length; i++) {
TypeName argumentType = argumentTypes[i];
CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType);
if (argumentCode == null) {
throw new IllegalArgumentException("Could not generate code for " + this
+ ": parameter " + i + " of type " + argumentType + " is not supported");
}
arguments.add(argumentCode);
}
code.add(CodeBlock.join(arguments, ", "));
}

protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) {
return CodeBlock.of("new $T().", declaringClass);
}

private boolean isStatic() {
return this.method.modifiers.contains(Modifier.STATIC);
}

private boolean isSameDeclaringClass(ClassName declaringClass) {
return this.declaringClass == null || this.declaringClass.equals(declaringClass);
}

@Override
public String toString() {
String methodName = this.method.name;
if (isStatic()) {
return this.declaringClass + "::" + methodName;
}
else {
return ((this.declaringClass != null)
? "<" + this.declaringClass + ">" : "<instance>")
+ "::" + methodName;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

import java.util.function.Consumer;

import javax.lang.model.element.Modifier;

import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.MethodSpec;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -73,9 +71,7 @@ public String getName() {
* @return a method reference
*/
public MethodReference toMethodReference() {
return (this.methodSpec.modifiers.contains(Modifier.STATIC)
? MethodReference.ofStatic(this.className, this.name)
: MethodReference.of(this.className, this.name));
return new DefaultMethodReference(this.methodSpec, this.className);
}

/**
Expand Down
Loading

0 comments on commit ae706f3

Please sign in to comment.