Skip to content

Commit

Permalink
Merge class-level and method-level @SQL declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
asympro authored and sbrannen committed Jul 21, 2019
1 parent b0939a8 commit d77b715
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
* SQL {@link #scripts} and {@link #statements} to be executed against a given
* database during integration tests.
*
* <p>Method-level declarations override class-level declarations.
* <p>Method-level declarations override class-level declarations by default.
* This behaviour can be adjusted via {@link MergeMode}
*
* <p>Script execution is performed by the {@link SqlScriptsTestExecutionListener},
* which is enabled by default.
Expand Down Expand Up @@ -146,6 +147,13 @@
*/
SqlConfig config() default @SqlConfig;

/**
* Indicates whether this annotation should be merged with upper-level annotations
* or override them.
* <p>Defaults to {@link MergeMode#OVERRIDE}.
*/
MergeMode mergeMode() default MergeMode.OVERRIDE;


/**
* Enumeration of <em>phases</em> that dictate when SQL scripts are executed.
Expand All @@ -165,4 +173,23 @@ enum ExecutionPhase {
AFTER_TEST_METHOD
}

/**
* Enumeration of <em>modes</em> that dictate whether or not
* declared SQL {@link #scripts} and {@link #statements} are merged
* with the upper-level annotations.
*/
enum MergeMode {

/**
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
* should override the upper-level (e.g. Class-level) annotations.
*/
OVERRIDE,

/**
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
* should be merged the upper-level (e.g. Class-level) annotations.
*/
MERGE
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@

package org.springframework.test.context.jdbc;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.jetbrains.annotations.NotNull;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
Expand Down Expand Up @@ -126,19 +129,35 @@ public void afterTestMethod(TestContext testContext) throws Exception {
* {@link TestContext} and {@link ExecutionPhase}.
*/
private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
boolean classLevel = false;

Set<Sql> sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
testContext.getTestMethod(), Sql.class, SqlGroup.class);
if (sqlAnnotations.isEmpty()) {
sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
testContext.getTestClass(), Sql.class, SqlGroup.class);
if (!sqlAnnotations.isEmpty()) {
classLevel = true;
}
Set<Sql> methodLevelSqls = getScriptsFromElement(testContext.getTestMethod());
List<Sql> methodLevelOverrides = methodLevelSqls.stream()
.filter(s -> s.executionPhase() == executionPhase)
.filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE)
.collect(Collectors.toList());
if (methodLevelOverrides.isEmpty()) {
executeScripts(getScriptsFromElement(testContext.getTestClass()), testContext, executionPhase, true);
executeScripts(methodLevelSqls, testContext, executionPhase, false);
} else {
executeScripts(methodLevelOverrides, testContext, executionPhase, false);
}
}

/**
* Get SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link AnnotatedElement}.
*/
private Set<Sql> getScriptsFromElement(AnnotatedElement annotatedElement) throws Exception {
return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class);
}

for (Sql sql : sqlAnnotations) {
/**
* Execute given {@link Sql @Sql} scripts.
* {@link AnnotatedElement}.
*/
private void executeScripts(Iterable<Sql> scripts, TestContext testContext, ExecutionPhase executionPhase,
boolean classLevel)
throws Exception {
for (Sql sql : scripts) {
executeSqlScripts(sql, executionPhase, testContext, classLevel);
}
}
Expand Down Expand Up @@ -166,14 +185,7 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte
mergedSqlConfig, executionPhase, testContext));
}

final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
final ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig);

String[] scripts = getScripts(sql, testContext, classLevel);
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
Expand Down Expand Up @@ -232,6 +244,19 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte
}
}

@NotNull
private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) {
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
return populator;
}

@Nullable
private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import org.springframework.test.jdbc.JdbcTestUtils;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -58,6 +59,10 @@ public void test02_methodLevelScripts() {
assertNumUsers(2);
}

protected int countRowsInTable(String tableName) {
return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName);
}

protected void assertNumUsers(int expected) {
assertThat(countRowsInTable("user")).as("Number of rows in the 'user' table.").isEqualTo(expected);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.springframework.test.context.jdbc;

import org.junit.Test;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;

import static org.junit.Assert.assertEquals;

/**
* Test to verify method level merge of @Sql annotations.
*
* @author Dmitry Semukhin
*/
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
@DirtiesContext
public class SqlMethodMergeTest extends AbstractTransactionalJUnit4SpringContextTests {

@Test
@Sql(value = "data-add-dogbert.sql", mergeMode = Sql.MergeMode.MERGE)
public void testMerge() {
assertNumUsers(2);
}

protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.springframework.test.context.jdbc;

import org.junit.Test;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;

import static org.junit.Assert.assertEquals;

/**
* Test to verify method level override of @Sql annotations.
*
* @author Dmitry Semukhin
*/
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
@DirtiesContext
public class SqlMethodOverrideTest extends AbstractTransactionalJUnit4SpringContextTests {

@Test
@Sql(value = {"schema.sql", "data.sql", "data-add-dogbert.sql", "data-add-catbert.sql"}, mergeMode = Sql.MergeMode.OVERRIDE)
public void testMerge() {
assertNumUsers(3);
}

protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}

}

0 comments on commit d77b715

Please sign in to comment.