From d77b715d38c856202d1d20d5b4ef42e05b8faea5 Mon Sep 17 00:00:00 2001 From: asympro Date: Fri, 25 May 2018 16:33:41 +0300 Subject: [PATCH] Merge class-level and method-level @Sql declarations See gh-1835 --- .../test/context/jdbc/Sql.java | 29 ++++++++- .../jdbc/SqlScriptsTestExecutionListener.java | 63 +++++++++++++------ ...epeatableSqlAnnotationSqlScriptsTests.java | 5 ++ .../test/context/jdbc/SqlMethodMergeTest.java | 30 +++++++++ .../context/jdbc/SqlMethodOverrideTest.java | 30 +++++++++ 5 files changed, 137 insertions(+), 20 deletions(-) create mode 100644 spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java create mode 100644 spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java index 97ba3a428df8..27e7e2924802 100644 --- a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java @@ -31,7 +31,8 @@ * SQL {@link #scripts} and {@link #statements} to be executed against a given * database during integration tests. * - *

Method-level declarations override class-level declarations. + *

Method-level declarations override class-level declarations by default. + * This behaviour can be adjusted via {@link MergeMode} * *

Script execution is performed by the {@link SqlScriptsTestExecutionListener}, * which is enabled by default. @@ -146,6 +147,13 @@ */ SqlConfig config() default @SqlConfig; + /** + * Indicates whether this annotation should be merged with upper-level annotations + * or override them. + *

Defaults to {@link MergeMode#OVERRIDE}. + */ + MergeMode mergeMode() default MergeMode.OVERRIDE; + /** * Enumeration of phases that dictate when SQL scripts are executed. @@ -165,4 +173,23 @@ enum ExecutionPhase { AFTER_TEST_METHOD } + /** + * Enumeration of modes that dictate whether or not + * declared SQL {@link #scripts} and {@link #statements} are merged + * with the upper-level annotations. + */ + enum MergeMode { + + /** + * Indicates that locally declared SQL {@link #scripts} and {@link #statements} + * should override the upper-level (e.g. Class-level) annotations. + */ + OVERRIDE, + + /** + * Indicates that locally declared SQL {@link #scripts} and {@link #statements} + * should be merged the upper-level (e.g. Class-level) annotations. + */ + MERGE + } } diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java index 299df20fd66b..0ad591976188 100644 --- a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java @@ -16,14 +16,17 @@ package org.springframework.test.context.jdbc; +import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; import javax.sql.DataSource; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.jetbrains.annotations.NotNull; import org.springframework.context.ApplicationContext; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.io.ByteArrayResource; @@ -126,19 +129,35 @@ public void afterTestMethod(TestContext testContext) throws Exception { * {@link TestContext} and {@link ExecutionPhase}. */ private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception { - boolean classLevel = false; - - Set sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations( - testContext.getTestMethod(), Sql.class, SqlGroup.class); - if (sqlAnnotations.isEmpty()) { - sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations( - testContext.getTestClass(), Sql.class, SqlGroup.class); - if (!sqlAnnotations.isEmpty()) { - classLevel = true; - } + Set methodLevelSqls = getScriptsFromElement(testContext.getTestMethod()); + List methodLevelOverrides = methodLevelSqls.stream() + .filter(s -> s.executionPhase() == executionPhase) + .filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE) + .collect(Collectors.toList()); + if (methodLevelOverrides.isEmpty()) { + executeScripts(getScriptsFromElement(testContext.getTestClass()), testContext, executionPhase, true); + executeScripts(methodLevelSqls, testContext, executionPhase, false); + } else { + executeScripts(methodLevelOverrides, testContext, executionPhase, false); } + } + + /** + * Get SQL scripts configured via {@link Sql @Sql} for the supplied + * {@link AnnotatedElement}. + */ + private Set getScriptsFromElement(AnnotatedElement annotatedElement) throws Exception { + return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class); + } - for (Sql sql : sqlAnnotations) { + /** + * Execute given {@link Sql @Sql} scripts. + * {@link AnnotatedElement}. + */ + private void executeScripts(Iterable scripts, TestContext testContext, ExecutionPhase executionPhase, + boolean classLevel) + throws Exception { + for (Sql sql : scripts) { executeSqlScripts(sql, executionPhase, testContext, classLevel); } } @@ -166,14 +185,7 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte mergedSqlConfig, executionPhase, testContext)); } - final ResourceDatabasePopulator populator = new ResourceDatabasePopulator(); - populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding()); - populator.setSeparator(mergedSqlConfig.getSeparator()); - populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix()); - populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter()); - populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter()); - populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR); - populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS); + final ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig); String[] scripts = getScripts(sql, testContext, classLevel); scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts); @@ -232,6 +244,19 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte } } + @NotNull + private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) { + final ResourceDatabasePopulator populator = new ResourceDatabasePopulator(); + populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding()); + populator.setSeparator(mergedSqlConfig.getSeparator()); + populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix()); + populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter()); + populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter()); + populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR); + populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS); + return populator; + } + @Nullable private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) { try { diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java index c1b827a1b104..cbc6944a3156 100644 --- a/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java @@ -25,6 +25,7 @@ import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests; +import org.springframework.test.jdbc.JdbcTestUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -58,6 +59,10 @@ public void test02_methodLevelScripts() { assertNumUsers(2); } + protected int countRowsInTable(String tableName) { + return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName); + } + protected void assertNumUsers(int expected) { assertThat(countRowsInTable("user")).as("Number of rows in the 'user' table.").isEqualTo(expected); } diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java new file mode 100644 index 000000000000..3a23a9db8780 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java @@ -0,0 +1,30 @@ +package org.springframework.test.context.jdbc; + +import org.junit.Test; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests; + +import static org.junit.Assert.assertEquals; + +/** + * Test to verify method level merge of @Sql annotations. + * + * @author Dmitry Semukhin + */ +@ContextConfiguration(classes = EmptyDatabaseConfig.class) +@Sql(value = {"schema.sql", "data-add-catbert.sql"}) +@DirtiesContext +public class SqlMethodMergeTest extends AbstractTransactionalJUnit4SpringContextTests { + + @Test + @Sql(value = "data-add-dogbert.sql", mergeMode = Sql.MergeMode.MERGE) + public void testMerge() { + assertNumUsers(2); + } + + protected void assertNumUsers(int expected) { + assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user")); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java new file mode 100644 index 000000000000..6f4d8023c490 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java @@ -0,0 +1,30 @@ +package org.springframework.test.context.jdbc; + +import org.junit.Test; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests; + +import static org.junit.Assert.assertEquals; + +/** + * Test to verify method level override of @Sql annotations. + * + * @author Dmitry Semukhin + */ +@ContextConfiguration(classes = EmptyDatabaseConfig.class) +@Sql(value = {"schema.sql", "data-add-catbert.sql"}) +@DirtiesContext +public class SqlMethodOverrideTest extends AbstractTransactionalJUnit4SpringContextTests { + + @Test + @Sql(value = {"schema.sql", "data.sql", "data-add-dogbert.sql", "data-add-catbert.sql"}, mergeMode = Sql.MergeMode.OVERRIDE) + public void testMerge() { + assertNumUsers(3); + } + + protected void assertNumUsers(int expected) { + assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user")); + } + +}