From ae706f3954b1823626bee1373f306b20fb9239d3 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Wed, 24 Aug 2022 07:53:38 +0200 Subject: [PATCH] Allow MethodReference to define a more flexible signature This commit moves MethodReference to an interface with a default implementation that relies on a MethodSpec. Such an arrangement avoid the need of specifying attributes of the method such as whether it is static or not. The resolution of the invocation block now takes an ArgumentCodeGenerator rather than the raw arguments. Doing so gives the opportunity to create more flexible signatures. See gh-29005 --- ...roxyBeanRegistrationAotProcessorTests.java | 6 +- .../aot/BeanRegistrationsAotContribution.java | 5 +- .../DefaultBeanRegistrationCodeFragments.java | 3 +- .../aot/InstanceSupplierCodeGenerator.java | 4 +- ...nBeanRegistrationAotContributionTests.java | 6 +- .../BeanDefinitionMethodGeneratorTests.java | 5 +- ...BeanRegistrationsAotContributionTests.java | 6 +- .../MockBeanFactoryInitializationCode.java | 4 + ...ionContextInitializationCodeGenerator.java | 8 +- ...lassPostProcessorAotContributionTests.java | 6 +- .../aot/generate/DefaultMethodReference.java | 134 +++++++++ .../aot/generate/GeneratedMethod.java | 6 +- .../aot/generate/MethodReference.java | 271 ++++++------------ .../generate/DefaultMethodReferenceTests.java | 199 +++++++++++++ .../aot/generate/GeneratedMethodTests.java | 8 +- .../aot/generate/MethodReferenceTests.java | 226 --------------- 16 files changed, 469 insertions(+), 428 deletions(-) create mode 100644 spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java create mode 100644 spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java delete mode 100644 spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index e29a31924929..1ce8b7986809 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -26,6 +26,7 @@ import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; import org.springframework.aot.test.generate.compile.TestCompiler; @@ -139,11 +140,14 @@ private void compile(BiConsumer result) { MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index fc8ca237b8b2..a80db112d350 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; @@ -81,9 +82,11 @@ private void generateRegisterMethod(MethodSpec.Builder method, MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator .generateBeanDefinitionMethod(generationContext, beanRegistrationsCode); + CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock( + ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName()); code.addStatement("$L.registerBeanDefinition($S, $L)", BEAN_FACTORY_PARAMETER_NAME, beanName, - beanDefinitionMethod.toInvokeCodeBlock()); + methodInvocation); }); method.addCode(code.build()); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index bf5719eccdb3..fa04f8621151 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -24,6 +24,7 @@ import org.springframework.aot.generate.AccessVisibility; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; @@ -156,7 +157,7 @@ protected CodeBlock generateValueCode(GenerationContext generationContext, MethodReference generatedMethod = methodGenerator .generateBeanDefinitionMethod(generationContext, this.beanRegistrationsCode); - return generatedMethod.toInvokeCodeBlock(); + return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none()); } return null; } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 1b597a36d379..e6cb5df84f13 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -28,6 +28,7 @@ import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ExecutableMode; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; @@ -297,7 +298,8 @@ private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean, } private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) { - return generatedMethod.toMethodReference().toInvokeCodeBlock(); + return generatedMethod.toMethodReference().toInvokeCodeBlock( + ArgumentCodeGenerator.none(), this.className); } private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 2f4fe187b131..219093424e75 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; @@ -161,13 +162,16 @@ private void compile(RegisteredBean registeredBean, Class target = registeredBean.getBeanClass(); MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0); this.beanRegistrationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"), + this.beanRegistrationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target)); type.addMethod(MethodSpec.methodBuilder("apply") .addModifiers(Modifier.PUBLIC) .addParameter(RegisteredBean.class, "registeredBean") .addParameter(target, "instance").returns(target) - .addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance"))) + .addStatement("return $L", methodInvocation) .build()); }); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 3cc278470a14..9020bf8626a7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -30,6 +30,7 @@ import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess; import org.springframework.aot.test.generate.compile.Compiled; @@ -414,12 +415,14 @@ private RegisteredBean registerBean(RootBeanDefinition beanDefinition) { private void compile(MethodReference method, BiConsumer result) { this.beanRegistrationsCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(), + this.beanRegistrationsCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class)); type.addMethod(MethodSpec.methodBuilder("get") .addModifiers(Modifier.PUBLIC) .returns(BeanDefinition.class) - .addCode("return $L;", method.toInvokeCodeBlock()).build()); + .addCode("return $L;", methodInvocation).build()); }); this.generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled -> diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index 1eee3a83434c..bd2bba145e99 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -31,6 +31,7 @@ import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestTarget; import org.springframework.aot.test.generate.compile.Compiled; @@ -155,11 +156,14 @@ private void compile( MethodReference methodReference = this.beanFactoryInitializationCode .getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index 01a78dda3b47..c6986c7c4b04 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; +import org.springframework.javapoet.ClassName; /** * Mock {@link BeanFactoryInitializationCode} implementation. @@ -46,6 +47,9 @@ public MockBeanFactoryInitializationCode(GenerationContext generationContext) { .addForFeature("TestCode", this.typeBuilder); } + public ClassName getClassName() { + return this.generatedClass.getName(); + } public DeferredTypeBuilder getTypeBuilder() { return this.typeBuilder; diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index 29f502c7353d..b2bf870e2f56 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -25,6 +25,7 @@ import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; @@ -88,12 +89,17 @@ private CodeBlock generateInitializeCode() { BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); + ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { - code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE))); + code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); } return code.build(); } + private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() { + return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE); + } + GeneratedClass getGeneratedClass() { return this.generatedClass; } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 24bef4dcf158..29961c610704 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -27,6 +27,7 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.generate.MethodReference; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ResourcePatternHint; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.compile.Compiled; @@ -162,11 +163,14 @@ private void assertPostProcessorEntry(BeanPostProcessor postProcessor, Class private void compile(BiConsumer, Compiled> result) { MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0); this.beanFactoryInitializationCode.getTypeBuilder().set(type -> { + CodeBlock methodInvocation = methodReference.toInvokeCodeBlock( + ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"), + this.beanFactoryInitializationCode.getClassName()); type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class)); type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(DefaultListableBeanFactory.class, "beanFactory") - .addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory"))) + .addStatement(methodInvocation) .build()); }); this.generationContext.writeGeneratedContent(); diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java new file mode 100644 index 000000000000..b3a3ab117d8e --- /dev/null +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultMethodReference.java @@ -0,0 +1,134 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.util.ArrayList; +import java.util.List; + +import javax.lang.model.element.Modifier; + +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default {@link MethodReference} implementation based on a {@link MethodSpec}. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +public class DefaultMethodReference implements MethodReference { + + private final MethodSpec method; + + @Nullable + private final ClassName declaringClass; + + public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) { + this.method = method; + this.declaringClass = declaringClass; + } + + @Override + public CodeBlock toCodeBlock() { + String methodName = this.method.name; + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + return CodeBlock.of("$T::$L", this.declaringClass, methodName); + } + else { + return CodeBlock.of("this::$L", methodName); + } + } + + public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, + @Nullable ClassName targetClassName) { + String methodName = this.method.name; + CodeBlock.Builder code = CodeBlock.builder(); + if (isStatic()) { + Assert.notNull(this.declaringClass, "static method reference must define a declaring class"); + if (isSameDeclaringClass(targetClassName)) { + code.add("$L", methodName); + } + else { + code.add("$T.$L", this.declaringClass, methodName); + } + } + else { + if (!isSameDeclaringClass(targetClassName)) { + code.add(instantiateDeclaringClass(this.declaringClass)); + } + code.add("$L", methodName); + } + code.add("("); + addArguments(code, argumentCodeGenerator); + code.add(")"); + return code.build(); + } + + /** + * Add the code for the method arguments using the specified + * {@link ArgumentCodeGenerator} if necessary. + * @param code the code builder to use to add method arguments + * @param argumentCodeGenerator the code generator to use + */ + protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) { + List arguments = new ArrayList<>(); + TypeName[] argumentTypes = this.method.parameters.stream() + .map(parameter -> parameter.type).toArray(TypeName[]::new); + for (int i = 0; i < argumentTypes.length; i++) { + TypeName argumentType = argumentTypes[i]; + CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType); + if (argumentCode == null) { + throw new IllegalArgumentException("Could not generate code for " + this + + ": parameter " + i + " of type " + argumentType + " is not supported"); + } + arguments.add(argumentCode); + } + code.add(CodeBlock.join(arguments, ", ")); + } + + protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) { + return CodeBlock.of("new $T().", declaringClass); + } + + private boolean isStatic() { + return this.method.modifiers.contains(Modifier.STATIC); + } + + private boolean isSameDeclaringClass(ClassName declaringClass) { + return this.declaringClass == null || this.declaringClass.equals(declaringClass); + } + + @Override + public String toString() { + String methodName = this.method.name; + if (isStatic()) { + return this.declaringClass + "::" + methodName; + } + else { + return ((this.declaringClass != null) + ? "<" + this.declaringClass + ">" : "") + + "::" + methodName; + } + } + +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java index 7247c212d0c7..b09d36f61f28 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java @@ -18,8 +18,6 @@ import java.util.function.Consumer; -import javax.lang.model.element.Modifier; - import org.springframework.javapoet.ClassName; import org.springframework.javapoet.MethodSpec; import org.springframework.util.Assert; @@ -73,9 +71,7 @@ public String getName() { * @return a method reference */ public MethodReference toMethodReference() { - return (this.methodSpec.modifiers.contains(Modifier.STATIC) - ? MethodReference.ofStatic(this.className, this.name) - : MethodReference.of(this.className, this.name)); + return new DefaultMethodReference(this.methodSpec, this.className); } /** diff --git a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java index 80359dd314b2..f6dda9710077 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/MethodReference.java @@ -16,223 +16,124 @@ package org.springframework.aot.generate; +import java.util.function.Function; + import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.TypeName; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** - * A reference to a static or instance method. + * A reference to a method with convenient code generation for + * referencing, or invoking it. * + * @author Stephane Nicoll * @author Phillip Webb * @since 6.0 */ -public final class MethodReference { - - private final Kind kind; - - @Nullable - private final ClassName declaringClass; - - private final String methodName; - - - private MethodReference(Kind kind, @Nullable ClassName declaringClass, - String methodName) { - this.kind = kind; - this.declaringClass = declaringClass; - this.methodName = methodName; - } - - - /** - * Create a new method reference that refers to the given instance method. - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(String methodName) { - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, null, methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given instance method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference of(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.INSTANCE, declaringClass, methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(Class declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, ClassName.get(declaringClass), - methodName); - } - - /** - * Create a new method reference that refers to the given static method. - * @param declaringClass the declaring class - * @param methodName the method name - * @return a new {@link MethodReference} instance - */ - public static MethodReference ofStatic(ClassName declaringClass, String methodName) { - Assert.notNull(declaringClass, "'declaringClass' must not be null"); - Assert.hasLength(methodName, "'methodName' must not be empty"); - return new MethodReference(Kind.STATIC, declaringClass, methodName); - } - - - /** - * Return the referenced declaring class. - * @return the declaring class - */ - @Nullable - public ClassName getDeclaringClass() { - return this.declaringClass; - } - - /** - * Return the referenced method name. - * @return the method name - */ - public String getMethodName() { - return this.methodName; - } +public interface MethodReference { /** * Return this method reference as a {@link CodeBlock}. If the reference is * to an instance method then {@code this::} will be returned. * @return a code block for the method reference. - * @see #toCodeBlock(String) */ - public CodeBlock toCodeBlock() { - return toCodeBlock(null); - } + CodeBlock toCodeBlock(); /** - * Return this method reference as a {@link CodeBlock}. If the reference is - * to an instance method and {@code instanceVariable} is {@code null} then - * {@code this::} will be returned. No {@code instanceVariable} - * can be specified for static method references. - * @param instanceVariable the instance variable or {@code null} - * @return a code block for the method reference. - * @see #toCodeBlock(String) + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. + * @param argumentCodeGenerator the argument code generator to use + * @return a code block to invoke the method */ - public CodeBlock toCodeBlock(@Nullable String instanceVariable) { - return switch (this.kind) { - case INSTANCE -> toCodeBlockForInstance(instanceVariable); - case STATIC -> toCodeBlockForStatic(instanceVariable); - }; - } - - private CodeBlock toCodeBlockForInstance(@Nullable String instanceVariable) { - instanceVariable = (instanceVariable != null) ? instanceVariable : "this"; - return CodeBlock.of("$L::$L", instanceVariable, this.methodName); - } - - private CodeBlock toCodeBlockForStatic(@Nullable String instanceVariable) { - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - return CodeBlock.of("$T::$L", this.declaringClass, this.methodName); + default CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator) { + return toInvokeCodeBlock(argumentCodeGenerator, null); } /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param arguments the method arguments - * @return a code back to invoke the method + * Return this method reference as a {@link CodeBlock} using the specified + * {@link ArgumentCodeGenerator}. The {@code targetClassName} defines the + * context in which the method invocation is added. + *

If the caller has an instance of the type in which this method is + * defined, it can hint that by specifying the type as a target class. + * @param argumentCodeGenerator the argument code generator to use + * @param targetClassName the target class name + * @return a code block to invoke the method */ - public CodeBlock toInvokeCodeBlock(CodeBlock... arguments) { - return toInvokeCodeBlock(null, arguments); - } + CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator, @Nullable ClassName targetClassName); + /** - * Return this method reference as an invocation {@link CodeBlock}. - * @param instanceVariable the instance variable or {@code null} - * @param arguments the method arguments - * @return a code back to invoke the method + * Strategy for generating code for arguments based on their type. */ - public CodeBlock toInvokeCodeBlock(@Nullable String instanceVariable, - CodeBlock... arguments) { - - return switch (this.kind) { - case INSTANCE -> toInvokeCodeBlockForInstance(instanceVariable, arguments); - case STATIC -> toInvokeCodeBlockForStatic(instanceVariable, arguments); - }; - } - - private CodeBlock toInvokeCodeBlockForInstance(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - CodeBlock.Builder code = CodeBlock.builder(); - if (instanceVariable != null) { - code.add("$L.", instanceVariable); - } - else if (this.declaringClass != null) { - code.add("new $T().", this.declaringClass); + interface ArgumentCodeGenerator { + + /** + * Generate the code for the given argument type. If this type is + * not supported, return {@code null}. + * @param argumentType the argument type + * @return the code for this argument, or {@code null} + */ + @Nullable + CodeBlock generateCode(TypeName argumentType); + + /** + * Factory method that returns an {@link ArgumentCodeGenerator} that + * always returns {@code null}. + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator none() { + return from(type -> null); } - code.add("$L", this.methodName); - addArguments(code, arguments); - return code.build(); - } - private CodeBlock toInvokeCodeBlockForStatic(@Nullable String instanceVariable, - CodeBlock[] arguments) { - - Assert.isTrue(instanceVariable == null, - "'instanceVariable' must be null for static method references"); - CodeBlock.Builder code = CodeBlock.builder(); - code.add("$T.$L", this.declaringClass, this.methodName); - addArguments(code, arguments); - return code.build(); - } + /** + * Factory method that can be used to create an {@link ArgumentCodeGenerator} + * that support only the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new {@link ArgumentCodeGenerator} instance + */ + static ArgumentCodeGenerator of(Class argumentType, String argumentCode) { + return from(candidateType -> (candidateType.equals(ClassName.get(argumentType)) + ? CodeBlock.of(argumentCode) : null)); + } - private void addArguments(CodeBlock.Builder code, CodeBlock[] arguments) { - code.add("("); - for (int i = 0; i < arguments.length; i++) { - if (i != 0) { - code.add(", "); - } - code.add(arguments[i]); + /** + * Factory method that creates a new {@link ArgumentCodeGenerator} from + * a lambda friendly function. The given function is provided with the + * argument type and must provide the code to use or {@code null} if + * the type is not supported. + * @param function the resolver function + * @return a new {@link ArgumentCodeGenerator} instance backed by the function + */ + static ArgumentCodeGenerator from(Function function) { + return function::apply; } - code.add(")"); - } - @Override - public String toString() { - return switch (this.kind) { - case INSTANCE -> ((this.declaringClass != null) ? "<" + this.declaringClass + ">" - : "") + "::" + this.methodName; - case STATIC -> this.declaringClass + "::" + this.methodName; - }; - } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with supporting the given argument type. + * @param argumentType the argument type + * @param argumentCode the code for an argument of that type + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(Class argumentType, String argumentCode) { + return and(ArgumentCodeGenerator.of(argumentType, argumentCode)); + } + /** + * Create a new composed {@link ArgumentCodeGenerator} by combining this + * generator with the given generator. + * @param argumentCodeGenerator the argument generator to add + * @return a new composite {@link ArgumentCodeGenerator} instance + */ + default ArgumentCodeGenerator and(ArgumentCodeGenerator argumentCodeGenerator) { + return from(type -> { + CodeBlock code = generateCode(type); + return (code != null ? code : argumentCodeGenerator.generateCode(type)); + }); + } - private enum Kind { - INSTANCE, STATIC } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java new file mode 100644 index 000000000000..b9643151ca9c --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultMethodReferenceTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.MethodSpec.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link DefaultMethodReference}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class DefaultMethodReferenceTests { + + private static final String EXPECTED_STATIC = "org.springframework.aot.generate.DefaultMethodReferenceTests::someMethod"; + + private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; + + private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; + + private static final ClassName TEST_CLASS_NAME = ClassName.get("com.example", "Test"); + + private static final ClassName INITIALIZER_CLASS_NAME = ClassName.get("com.example", "Initializer"); + + @Test + void createWithStringCreatesMethodReference() { + MethodSpec method = createTestMethod("someMethod", new TypeName[0]); + MethodReference reference = new DefaultMethodReference(method, null); + assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); + } + + @Test + void createWithClassNameAndStringCreateMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createMethodReference("someMethod", new TypeName[0], declaringClass); + assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); + } + + @Test + void createWithStaticAndClassAndStringCreatesMethodReference() { + ClassName declaringClass = ClassName.get(DefaultMethodReferenceTests.class); + MethodReference reference = createStaticMethodReference("someMethod", declaringClass); + assertThat(reference).hasToString(EXPECTED_STATIC); + } + + @Test + void toCodeBlock() { + assertThat(createLocalMethodReference("methodName").toCodeBlock()) + .isEqualTo(CodeBlock.of("this::methodName")); + } + + @Test + void toCodeBlockWithStaticMethod() { + assertThat(createStaticMethodReference("methodName", TEST_CLASS_NAME).toCodeBlock()) + .isEqualTo(CodeBlock.of("com.example.Test::methodName")); + } + + @Test + void toCodeBlockWithStaticMethodRequiresDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0], Modifier.STATIC); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThatIllegalArgumentException().isThrownBy(methodReference::toCodeBlock) + .withMessage("static method reference must define a declaring class"); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithNullDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, null); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithDeclaringClassAndNullTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none())) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingTargetClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + CodeBlock invocation = methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME); + // Assume com.example.Test is in a `test` variable. + assertThat(CodeBlock.of("$L.$L", "test", invocation)).isEqualTo(CodeBlock.of("test.methodName()")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingDeclaringClass() { + MethodSpec method = createTestMethod("methodName", new TypeName[0]); + MethodReference methodReference = new DefaultMethodReference(method, TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("new com.example.Test().methodName()")); + } + + @Test + void toInvokeCodeBlockWithMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(stringArg)")); + } + + @Test + void toInvokeCodeBlockWithMatchingArgs() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(String.class, "stringArg") + .and(Integer.class, "integerArg"); + assertThat(methodReference.toInvokeCodeBlock(argCodeGenerator)) + .isEqualTo(CodeBlock.of("methodName(integerArg, stringArg)")); + } + + @Test + void toInvokeCodeBlockWithNonMatchingArg() { + MethodReference methodReference = createLocalMethodReference("methodName", + ClassName.get(Integer.class), ClassName.get(String.class)); + ArgumentCodeGenerator argCodeGenerator = ArgumentCodeGenerator.of(Integer.class, "integerArg"); + assertThatIllegalArgumentException().isThrownBy(() -> methodReference.toInvokeCodeBlock(argCodeGenerator)) + .withMessageContaining("parameter 1 of type java.lang.String is not supported"); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndMatchingDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("methodName()")); + } + + @Test + void toInvokeCodeBlockWithStaticMethodAndSeparateDeclaringClass() { + MethodReference methodReference = createStaticMethodReference("methodName", TEST_CLASS_NAME); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), INITIALIZER_CLASS_NAME)) + .isEqualTo(CodeBlock.of("com.example.Test.methodName()")); + } + + + private MethodReference createLocalMethodReference(String name, TypeName... argumentTypes) { + return createMethodReference(name, argumentTypes, null); + } + + private MethodReference createMethodReference(String name, TypeName[] argumentTypes, @Nullable ClassName declaringClass) { + MethodSpec method = createTestMethod(name, argumentTypes); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodReference createStaticMethodReference(String name, ClassName declaringClass, TypeName... argumentTypes) { + MethodSpec method = createTestMethod(name, argumentTypes, Modifier.STATIC); + return new DefaultMethodReference(method, declaringClass); + } + + private MethodSpec createTestMethod(String name, TypeName[] argumentTypes, Modifier... modifiers) { + Builder method = MethodSpec.methodBuilder(name); + for (int i = 0; i < argumentTypes.length; i++) { + method.addParameter(argumentTypes[i], "args" + i); + } + method.addModifiers(modifiers); + return method.build(); + } + +} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java index 34ac962746aa..6e865bd4eb8c 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedMethodTests.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; @@ -67,8 +68,8 @@ void toMethodReferenceWithInstanceMethod() { GeneratedMethod generatedMethod = create(emptyMethod); MethodReference methodReference = generatedMethod.toMethodReference(); assertThat(methodReference).isNotNull(); - assertThat(methodReference.toInvokeCodeBlock("test")) - .isEqualTo(CodeBlock.of("test.spring()")); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), TEST_CLASS_NAME)) + .isEqualTo(CodeBlock.of("spring()")); } @Test @@ -76,7 +77,8 @@ void toMethodReferenceWithStaticMethod() { GeneratedMethod generatedMethod = create(method -> method.addModifiers(Modifier.STATIC)); MethodReference methodReference = generatedMethod.toMethodReference(); assertThat(methodReference).isNotNull(); - assertThat(methodReference.toInvokeCodeBlock()) + ClassName anotherDeclaringClass = ClassName.get("com.example", "Another"); + assertThat(methodReference.toInvokeCodeBlock(ArgumentCodeGenerator.none(), anotherDeclaringClass)) .isEqualTo(CodeBlock.of("com.example.Test.spring()")); } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java deleted file mode 100644 index de5c79667b42..000000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/MethodReferenceTests.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import org.junit.jupiter.api.Test; - -import org.springframework.javapoet.ClassName; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; - -/** - * Tests for {@link MethodReference}. - * - * @author Phillip Webb - */ -class MethodReferenceTests { - - private static final String EXPECTED_STATIC = "org.springframework.aot.generate.MethodReferenceTests::someMethod"; - - private static final String EXPECTED_ANONYMOUS_INSTANCE = "::someMethod"; - - private static final String EXPECTED_DECLARED_INSTANCE = "::someMethod"; - - - @Test - void ofWithStringWhenMethodNameIsNullThrowsException() { - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithStringCreatesMethodReference() { - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(methodName); - assertThat(reference).hasToString(EXPECTED_ANONYMOUS_INSTANCE); - } - - @Test - void ofWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassAndStringWhenMethodNameIsNullThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofWithClassNameAndStringWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofWithClassNameAndStringWhenMethodNameIsNullThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.of(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofWithClassNameAndStringCreateMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.of(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_DECLARED_INSTANCE); - } - - @Test - void ofStaticWithClassAndStringWhenDeclaringClassIsNullThrowsException() { - Class declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassAndStringWhenMethodNameIsEmptyThrowsException() { - Class declaringClass = MethodReferenceTests.class; - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassAndStringCreatesMethodReference() { - Class declaringClass = MethodReferenceTests.class; - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenDeclaringClassIsNullThrowsException() { - ClassName declaringClass = null; - String methodName = "someMethod"; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'declaringClass' must not be null"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameWhenMethodNameIsEmptyThrowsException() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = null; - assertThatIllegalArgumentException() - .isThrownBy(() -> MethodReference.ofStatic(declaringClass, methodName)) - .withMessage("'methodName' must not be empty"); - } - - @Test - void ofStaticWithClassNameAndGeneratedMethodNameCreatesMethodReference() { - ClassName declaringClass = ClassName.get(MethodReferenceTests.class); - String methodName = "someMethod"; - MethodReference reference = MethodReference.ofStatic(declaringClass, methodName); - assertThat(reference).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString("this::someMethod"); - } - - @Test - void toCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toCodeBlock("myInstance")) - .hasToString("myInstance::someMethod"); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toCodeBlock(null)).hasToString(EXPECTED_STATIC); - } - - @Test - void toCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString("someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNullAndHasDecalredClass() { - MethodReference reference = MethodReference.of(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "new org.springframework.aot.generate.MethodReferenceTests().someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenInstanceMethodReferenceAndInstanceVariableIsNotNull() { - MethodReference reference = MethodReference.of("someMethod"); - assertThat(reference.toInvokeCodeBlock("myInstance")) - .hasToString("myInstance.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNull() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThat(reference.toInvokeCodeBlock()).hasToString( - "org.springframework.aot.generate.MethodReferenceTests.someMethod()"); - } - - @Test - void toInvokeCodeBlockWhenStaticMethodReferenceAndInstanceVariableIsNotNullThrowsException() { - MethodReference reference = MethodReference.ofStatic(MethodReferenceTests.class, - "someMethod"); - assertThatIllegalArgumentException() - .isThrownBy(() -> reference.toInvokeCodeBlock("myInstance")).withMessage( - "'instanceVariable' must be null for static method references"); - } - -}