diff --git a/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java b/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java index 73e4ff931591..e1473eb411e1 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java +++ b/spring-tx/src/main/java/org/springframework/transaction/annotation/RestrictedTransactionalEventListenerFactory.java @@ -38,6 +38,11 @@ public class RestrictedTransactionalEventListenerFactory extends TransactionalEv @Override public ApplicationListener createApplicationListener(String beanName, Class type, Method method) { Transactional txAnn = AnnotatedElementUtils.findMergedAnnotation(method, Transactional.class); + + if (txAnn == null) { + txAnn = AnnotatedElementUtils.findMergedAnnotation(type, Transactional.class); + } + if (txAnn != null) { Propagation propagation = txAnn.propagation(); if (propagation != Propagation.REQUIRES_NEW && propagation != Propagation.NOT_SUPPORTED) { diff --git a/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalApplicationListenerMethodAdapterTests.java b/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalApplicationListenerMethodAdapterTests.java index 686e5f9d44f5..0143cbd868fa 100644 --- a/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalApplicationListenerMethodAdapterTests.java +++ b/spring-tx/src/test/java/org/springframework/transaction/event/TransactionalApplicationListenerMethodAdapterTests.java @@ -28,6 +28,7 @@ import org.springframework.transaction.annotation.Propagation; import org.springframework.transaction.annotation.RestrictedTransactionalEventListenerFactory; import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.event.TransactionalApplicationListenerMethodAdapterTests.SampleEvents.SampleEventsWithTransactionalAnnotation; import org.springframework.transaction.support.TransactionSynchronization; import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.util.ClassUtils; @@ -157,6 +158,34 @@ void withAsyncTransactionalAnnotation() { assertThatNoException().isThrownBy(() -> factory.createApplicationListener("test", SampleEvents.class, m)); } + @Test + void withTransactionalAnnotationOnEnclosingClass() { + RestrictedTransactionalEventListenerFactory factory = new RestrictedTransactionalEventListenerFactory(); + Method m = ReflectionUtils.findMethod(SampleEvents.SampleEventsWithTransactionalAnnotation.class, "defaultPhase", String.class); + assertThatIllegalStateException().isThrownBy(() -> factory.createApplicationListener("test", SampleEvents.SampleEventsWithTransactionalAnnotation.class, m)); + } + + @Test + void withTransactionalRequiresNewAnnotationAndTransactionalAnnotationOnEnclosingClass() { + RestrictedTransactionalEventListenerFactory factory = new RestrictedTransactionalEventListenerFactory(); + Method m = ReflectionUtils.findMethod(SampleEvents.SampleEventsWithTransactionalAnnotation.class, "withTransactionalRequiresNewAnnotation", String.class); + assertThatNoException().isThrownBy(() -> factory.createApplicationListener("test", SampleEvents.SampleEventsWithTransactionalAnnotation.class, m)); + } + + @Test + void withTransactionalNotSupportedAnnotationAndTransactionalAnnotationOnEnclosingClass() { + RestrictedTransactionalEventListenerFactory factory = new RestrictedTransactionalEventListenerFactory(); + Method m = ReflectionUtils.findMethod(SampleEvents.SampleEventsWithTransactionalAnnotation.class, "withTransactionalNotSupportedAnnotation", String.class); + assertThatNoException().isThrownBy(() -> factory.createApplicationListener("test", SampleEvents.SampleEventsWithTransactionalAnnotation.class, m)); + } + + @Test + void withAsyncTransactionalAnnotationAndTransactionalAnnotationOnEnclosingClass() { + RestrictedTransactionalEventListenerFactory factory = new RestrictedTransactionalEventListenerFactory(); + Method m = ReflectionUtils.findMethod(SampleEvents.SampleEventsWithTransactionalAnnotation.class, "withAsyncTransactionalAnnotation", String.class); + assertThatNoException().isThrownBy(() -> factory.createApplicationListener("test", SampleEvents.SampleEventsWithTransactionalAnnotation.class, m)); + } + private static void assertPhase(Method method, TransactionPhase expected) { assertThat(method).as("Method must not be null").isNotNull(); @@ -248,6 +277,29 @@ public void withTransactionalNotSupportedAnnotation(String data) { @Async @Transactional(propagation = Propagation.REQUIRES_NEW) public void withAsyncTransactionalAnnotation(String data) { } + + @Transactional + static class SampleEventsWithTransactionalAnnotation { + + @TransactionalEventListener + public void defaultPhase(String data) { + } + + @TransactionalEventListener + @Transactional(propagation = Propagation.REQUIRES_NEW) + public void withTransactionalRequiresNewAnnotation(String data) { + } + + @TransactionalEventListener + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public void withTransactionalNotSupportedAnnotation(String data) { + } + + @TransactionalEventListener + @Async @Transactional(propagation = Propagation.REQUIRES_NEW) + public void withAsyncTransactionalAnnotation(String data) { + } + } } }