diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategy.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategy.java index 0a0b81211200..39736de39e2d 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategy.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategy.java @@ -18,7 +18,6 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.util.Objects; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -289,10 +288,10 @@ public Object intercept(Object obj, Method method, Object[] args, MethodProxy mp @Nullable private T processReturnType(Method method, @Nullable T returnValue) { Class returnType = method.getReturnType(); - if (returnType != void.class && returnType.isPrimitive()) { - return Objects.requireNonNull(returnValue, () -> "Null return value from replacer does not match primitive return type for: " + method); + if (returnValue == null && returnType != void.class && returnType.isPrimitive()) { + throw new IllegalStateException( + "Null return value from MethodReplacer does not match primitive return type for: " + method); } - return returnValue; } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategyTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategyTests.java index 330a507134bb..4c709272428b 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategyTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategyTests.java @@ -1,90 +1,129 @@ -package org.springframework.beans.factory.support; +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +package org.springframework.beans.factory.support; import java.lang.reflect.Method; import java.util.Map; +import java.util.regex.Pattern; import java.util.stream.Stream; -import org.assertj.core.api.ThrowableAssert; +import org.assertj.core.api.ThrowableAssert.ThrowingCallable; import org.junit.jupiter.api.Test; + import org.springframework.lang.Nullable; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** + * Tests for {@link CglibSubclassingInstantiationStrategy}. + * + * @author Mikaƫl Francoeur + * @author Sam Brannen + * @since 6.2 + */ class CglibSubclassingInstantiationStrategyTests { private final CglibSubclassingInstantiationStrategy strategy = new CglibSubclassingInstantiationStrategy(); - @Nullable - public static Object valueToReturnFromReplacer; @Test - void methodOverride() { + void replaceOverrideMethodInterceptorRejectsNullReturnValueForPrimitives() { + MyReplacer replacer = new MyReplacer(); StaticListableBeanFactory beanFactory = new StaticListableBeanFactory(Map.of( "myBean", new MyBean(), - "replacer", new MyReplacer() + "replacer", replacer )); - RootBeanDefinition bd = new RootBeanDefinition(MyBean.class); MethodOverrides methodOverrides = new MethodOverrides(); - Stream.of("getBoolean", "getShort", "getInt", "getLong", "getFloat", "getDouble", "getByte") - .forEach(methodToOverride -> addOverride(methodOverrides, methodToOverride)); + Stream.of("getBoolean", "getChar", "getByte", "getShort", "getInt", "getLong", "getFloat", "getDouble") + .map(methodToOverride -> new ReplaceOverride(methodToOverride, "replacer")) + .forEach(methodOverrides::addOverride); + + RootBeanDefinition bd = new RootBeanDefinition(MyBean.class); bd.setMethodOverrides(methodOverrides); MyBean bean = (MyBean) strategy.instantiate(bd, "myBean", beanFactory); - valueToReturnFromReplacer = null; + replacer.reset(); assertCorrectExceptionThrownBy(bean::getBoolean); - valueToReturnFromReplacer = true; + replacer.returnValue = true; assertThat(bean.getBoolean()).isTrue(); - valueToReturnFromReplacer = null; + replacer.reset(); + assertCorrectExceptionThrownBy(bean::getChar); + replacer.returnValue = 'x'; + assertThat(bean.getChar()).isEqualTo('x'); + + replacer.reset(); + assertCorrectExceptionThrownBy(bean::getByte); + replacer.returnValue = 123; + assertThat(bean.getByte()).isEqualTo((byte) 123); + + replacer.reset(); assertCorrectExceptionThrownBy(bean::getShort); - valueToReturnFromReplacer = 123; + replacer.returnValue = 123; assertThat(bean.getShort()).isEqualTo((short) 123); - valueToReturnFromReplacer = null; + replacer.reset(); assertCorrectExceptionThrownBy(bean::getInt); - valueToReturnFromReplacer = 123; + replacer.returnValue = 123; assertThat(bean.getInt()).isEqualTo(123); - valueToReturnFromReplacer = null; + replacer.reset(); assertCorrectExceptionThrownBy(bean::getLong); - valueToReturnFromReplacer = 123; + replacer.returnValue = 123; assertThat(bean.getLong()).isEqualTo(123L); - valueToReturnFromReplacer = null; + replacer.reset(); assertCorrectExceptionThrownBy(bean::getFloat); - valueToReturnFromReplacer = 123; + replacer.returnValue = 123; assertThat(bean.getFloat()).isEqualTo(123f); - valueToReturnFromReplacer = null; + replacer.reset(); assertCorrectExceptionThrownBy(bean::getDouble); - valueToReturnFromReplacer = 123; + replacer.returnValue = 123; assertThat(bean.getDouble()).isEqualTo(123d); - - valueToReturnFromReplacer = null; - assertCorrectExceptionThrownBy(bean::getByte); - valueToReturnFromReplacer = 123; - assertThat(bean.getByte()).isEqualTo((byte) 123); } - private void assertCorrectExceptionThrownBy(ThrowableAssert.ThrowingCallable runnable) { - assertThatThrownBy(runnable) - .isInstanceOf(NullPointerException.class) - .hasMessageMatching("Null return value from replacer does not match primitive return type for: " - + "\\w+ org\\.springframework\\.beans\\.factory\\.support\\.CglibSubclassingInstantiationStrategyTests\\$MyBean\\.\\w+\\(\\)"); - } - private void addOverride(MethodOverrides methodOverrides, String methodToOverride) { - methodOverrides.addOverride(new ReplaceOverride(methodToOverride, "replacer")); + private static void assertCorrectExceptionThrownBy(ThrowingCallable runnable) { + assertThatIllegalStateException() + .isThrownBy(runnable) + .withMessageMatching( + "Null return value from MethodReplacer does not match primitive return type for: " + + "\\w+ %s\\.\\w+\\(\\)".formatted(Pattern.quote(MyBean.class.getName()))); } + static class MyBean { + boolean getBoolean() { return true; } + char getChar() { + return 'x'; + } + + byte getByte() { + return 123; + } + short getShort() { return 123; } @@ -104,17 +143,21 @@ float getFloat() { double getDouble() { return 123; } - - byte getByte() { - return 123; - } } static class MyReplacer implements MethodReplacer { + @Nullable + Object returnValue; + + void reset() { + this.returnValue = null; + } + @Override public Object reimplement(Object obj, Method method, Object[] args) { - return CglibSubclassingInstantiationStrategyTests.valueToReturnFromReplacer; + return this.returnValue; } } + }