diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java
index 62dba381649c..11743c4e80f7 100644
--- a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java
+++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java
@@ -31,7 +31,8 @@
* SQL {@link #scripts} and {@link #statements} to be executed against a given
* database during integration tests.
*
- *
Method-level declarations override class-level declarations.
+ *
Method-level declarations override class-level declarations by default.
+ * This behaviour can be adjusted via {@link MergeMode}
*
*
Script execution is performed by the {@link SqlScriptsTestExecutionListener},
* which is enabled by default.
@@ -146,6 +147,13 @@
*/
SqlConfig config() default @SqlConfig;
+ /**
+ * Indicates whether this annotation should be merged with upper-level annotations
+ * or override them.
+ *
Defaults to {@link MergeMode#OVERRIDE}.
+ */
+ MergeMode mergeMode() default MergeMode.OVERRIDE;
+
/**
* Enumeration of phases that dictate when SQL scripts are executed.
@@ -165,4 +173,23 @@ enum ExecutionPhase {
AFTER_TEST_METHOD
}
+ /**
+ * Enumeration of modes that dictate whether or not
+ * declared SQL {@link #scripts} and {@link #statements} are merged
+ * with the upper-level annotations.
+ */
+ enum MergeMode {
+
+ /**
+ * Indicates that locally declared SQL {@link #scripts} and {@link #statements}
+ * should override the upper-level (e.g. Class-level) annotations.
+ */
+ OVERRIDE,
+
+ /**
+ * Indicates that locally declared SQL {@link #scripts} and {@link #statements}
+ * should be merged the upper-level (e.g. Class-level) annotations.
+ */
+ MERGE
+ }
}
diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
index d34562c24e74..41e925799838 100644
--- a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
+++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
@@ -16,14 +16,17 @@
package org.springframework.test.context.jdbc;
+import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Set;
+import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.jetbrains.annotations.NotNull;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
@@ -126,19 +129,35 @@ public void afterTestMethod(TestContext testContext) throws Exception {
* {@link TestContext} and {@link ExecutionPhase}.
*/
private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
- boolean classLevel = false;
-
- Set sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
- testContext.getTestMethod(), Sql.class, SqlGroup.class);
- if (sqlAnnotations.isEmpty()) {
- sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
- testContext.getTestClass(), Sql.class, SqlGroup.class);
- if (!sqlAnnotations.isEmpty()) {
- classLevel = true;
- }
+ Set methodLevelSqls = getScriptsFromElement(testContext.getTestMethod());
+ List methodLevelOverrides = methodLevelSqls.stream()
+ .filter(s -> s.executionPhase() == executionPhase)
+ .filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE)
+ .collect(Collectors.toList());
+ if (methodLevelOverrides.isEmpty()) {
+ executeScripts(getScriptsFromElement(testContext.getTestClass()), testContext, executionPhase, true);
+ executeScripts(methodLevelSqls, testContext, executionPhase, false);
+ } else {
+ executeScripts(methodLevelOverrides, testContext, executionPhase, false);
}
+ }
+
+ /**
+ * Get SQL scripts configured via {@link Sql @Sql} for the supplied
+ * {@link AnnotatedElement}.
+ */
+ private Set getScriptsFromElement(AnnotatedElement annotatedElement) throws Exception {
+ return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class);
+ }
- for (Sql sql : sqlAnnotations) {
+ /**
+ * Execute given {@link Sql @Sql} scripts.
+ * {@link AnnotatedElement}.
+ */
+ private void executeScripts(Iterable scripts, TestContext testContext, ExecutionPhase executionPhase,
+ boolean classLevel)
+ throws Exception {
+ for (Sql sql : scripts) {
executeSqlScripts(sql, executionPhase, testContext, classLevel);
}
}
@@ -166,14 +185,7 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte
mergedSqlConfig, executionPhase, testContext));
}
- final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
- populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
- populator.setSeparator(mergedSqlConfig.getSeparator());
- populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
- populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
- populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
- populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
- populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
+ final ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig);
String[] scripts = getScripts(sql, testContext, classLevel);
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
@@ -232,6 +244,19 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte
}
}
+ @NotNull
+ private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) {
+ final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
+ populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
+ populator.setSeparator(mergedSqlConfig.getSeparator());
+ populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
+ populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
+ populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
+ populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
+ populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
+ return populator;
+ }
+
@Nullable
private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
try {
diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java
index 6d23fccef418..c2943bee382c 100644
--- a/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java
+++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java
@@ -25,6 +25,7 @@
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
+import org.springframework.test.jdbc.JdbcTestUtils;
import static org.junit.Assert.*;
@@ -58,6 +59,10 @@ public void test02_methodLevelScripts() {
assertNumUsers(2);
}
+ protected int countRowsInTable(String tableName) {
+ return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName);
+ }
+
protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}
diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java
new file mode 100644
index 000000000000..3a23a9db8780
--- /dev/null
+++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodMergeTest.java
@@ -0,0 +1,30 @@
+package org.springframework.test.context.jdbc;
+
+import org.junit.Test;
+import org.springframework.test.annotation.DirtiesContext;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test to verify method level merge of @Sql annotations.
+ *
+ * @author Dmitry Semukhin
+ */
+@ContextConfiguration(classes = EmptyDatabaseConfig.class)
+@Sql(value = {"schema.sql", "data-add-catbert.sql"})
+@DirtiesContext
+public class SqlMethodMergeTest extends AbstractTransactionalJUnit4SpringContextTests {
+
+ @Test
+ @Sql(value = "data-add-dogbert.sql", mergeMode = Sql.MergeMode.MERGE)
+ public void testMerge() {
+ assertNumUsers(2);
+ }
+
+ protected void assertNumUsers(int expected) {
+ assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
+ }
+
+}
diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java
new file mode 100644
index 000000000000..6f4d8023c490
--- /dev/null
+++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlMethodOverrideTest.java
@@ -0,0 +1,30 @@
+package org.springframework.test.context.jdbc;
+
+import org.junit.Test;
+import org.springframework.test.annotation.DirtiesContext;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test to verify method level override of @Sql annotations.
+ *
+ * @author Dmitry Semukhin
+ */
+@ContextConfiguration(classes = EmptyDatabaseConfig.class)
+@Sql(value = {"schema.sql", "data-add-catbert.sql"})
+@DirtiesContext
+public class SqlMethodOverrideTest extends AbstractTransactionalJUnit4SpringContextTests {
+
+ @Test
+ @Sql(value = {"schema.sql", "data.sql", "data-add-dogbert.sql", "data-add-catbert.sql"}, mergeMode = Sql.MergeMode.OVERRIDE)
+ public void testMerge() {
+ assertNumUsers(3);
+ }
+
+ protected void assertNumUsers(int expected) {
+ assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
+ }
+
+}