Skip to content

Commit

Permalink
Register runtime hints for @Sql scripts
Browse files Browse the repository at this point in the history
SqlScriptsTestExecutionListener now implements AotTestExecutionListener
in order to register run-time hints for SQL scripts used by test
classes and test methods annotated with @Sql if the configured or
detected SQL scripts are classpath resources.

Closes gh-29027
  • Loading branch information
sbrannen committed Sep 5, 2022
1 parent e85e769 commit e57b5f1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
package org.springframework.test.context.jdbc;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.aot.hint.RuntimeHints;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
Expand All @@ -35,6 +38,7 @@
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextAnnotationUtils;
import org.springframework.test.context.aot.AotTestExecutionListener;
import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
Expand All @@ -52,9 +56,11 @@
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ResourceUtils;
import org.springframework.util.ReflectionUtils.MethodFilter;
import org.springframework.util.StringUtils;

import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX;

/**
* {@code TestExecutionListener} that provides support for executing SQL
* {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
Expand Down Expand Up @@ -90,18 +96,22 @@
* @since 4.1
* @see Sql
* @see SqlConfig
* @see SqlMergeMode
* @see SqlGroup
* @see org.springframework.test.context.transaction.TestContextTransactionUtils
* @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
* @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
* @see org.springframework.jdbc.datasource.init.ScriptUtils
*/
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener implements AotTestExecutionListener {

private static final String SLASH = "/";

private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);

private static final MethodFilter sqlMethodFilter = ReflectionUtils.USER_DECLARED_METHODS
.and(method -> AnnotatedElementUtils.hasAnnotation(method, Sql.class));


/**
* Returns {@code 5000}.
Expand Down Expand Up @@ -129,6 +139,21 @@ public void afterTestMethod(TestContext testContext) {
executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
}

/**
* Process the supplied test class and its methods and register run-time
* hints for any SQL scripts configured or detected as classpath resources
* via {@link Sql @Sql}.
* @since 6.0
*/
@Override
public void processAheadOfTime(Class<?> testClass, RuntimeHints runtimeHints, ClassLoader classLoader) {
getSqlAnnotationsFor(testClass).forEach(sql ->
registerClasspathResources(runtimeHints, getScripts(sql, testClass, null, true)));
getSqlMethods(testClass).forEach(testMethod ->
getSqlAnnotationsFor(testMethod).forEach(sql ->
registerClasspathResources(runtimeHints, getScripts(sql, testClass, testMethod, false))));
}

/**
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link TestContext} and {@link ExecutionPhase}.
Expand Down Expand Up @@ -226,8 +251,7 @@ private void executeSqlScripts(
mergedSqlConfig, executionPhase, testContext));
}

String[] scripts = getScripts(sql, testContext, classLevel);
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel);
List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
testContext.getApplicationContext(), scripts);
for (String stmt : sql.statements()) {
Expand Down Expand Up @@ -321,31 +345,29 @@ private DataSource getDataSourceFromTransactionManager(PlatformTransactionManage
return null;
}

private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
private String[] getScripts(Sql sql, Class<?> testClass, Method testMethod, boolean classLevel) {
String[] scripts = sql.scripts();
if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
scripts = new String[] {detectDefaultScript(testContext, classLevel)};
scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)};
}
return scripts;
return TestContextResourceUtils.convertToClasspathResourcePaths(testClass, scripts);
}

/**
* Detect a default SQL script by implementing the algorithm defined in
* {@link Sql#scripts}.
*/
private String detectDefaultScript(TestContext testContext, boolean classLevel) {
Class<?> clazz = testContext.getTestClass();
Method method = testContext.getTestMethod();
private String detectDefaultScript(Class<?> testClass, Method testMethod, boolean classLevel) {
String elementType = (classLevel ? "class" : "method");
String elementName = (classLevel ? clazz.getName() : method.toString());
String elementName = (classLevel ? testClass.getName() : testMethod.toString());

String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
String resourcePath = ClassUtils.convertClassNameToResourcePath(testClass.getName());
if (!classLevel) {
resourcePath += "." + method.getName();
resourcePath += "." + testMethod.getName();
}
resourcePath += ".sql";

String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + SLASH + resourcePath;
String prefixedResourcePath = CLASSPATH_URL_PREFIX + SLASH + resourcePath;
ClassPathResource classPathResource = new ClassPathResource(resourcePath);

if (classPathResource.exists()) {
Expand All @@ -364,4 +386,23 @@ private String detectDefaultScript(TestContext testContext, boolean classLevel)
}
}

private Stream<Method> getSqlMethods(Class<?> testClass) {
return Arrays.stream(ReflectionUtils.getUniqueDeclaredMethods(testClass, sqlMethodFilter));
}

private void registerClasspathResources(RuntimeHints runtimeHints, String... locations) {
Arrays.stream(locations)
.filter(location -> location.startsWith(CLASSPATH_URL_PREFIX))
.map(this::cleanClasspathResource)
.forEach(runtimeHints.resources()::registerPattern);
}

private String cleanClasspathResource(String location) {
location = location.substring(CLASSPATH_URL_PREFIX.length());
if (!location.startsWith(SLASH)) {
location = SLASH + location;
}
return location;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ private static void assertRuntimeHints(RuntimeHints runtimeHints) {
// @WebAppConfiguration(value = ...)
assertThat(resource().forResource("/META-INF/web-resources/resources/Spring.js")).accepts(runtimeHints);
assertThat(resource().forResource("/META-INF/web-resources/WEB-INF/views/home.jsp")).accepts(runtimeHints);

// @Sql(scripts = ...)
assertThat(resource().forResource("/org/springframework/test/context/jdbc/schema.sql"))
.accepts(runtimeHints);
assertThat(resource().forResource("/org/springframework/test/context/aot/samples/jdbc/SqlScriptsSpringJupiterTests.test.sql"))
.accepts(runtimeHints);
}

private static void assertReflectionRegistered(RuntimeHints runtimeHints, String type, MemberCategory memberCategory) {
Expand Down

0 comments on commit e57b5f1

Please sign in to comment.