diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java index 325eb5524ce9..0a4c6f101ceb 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java @@ -337,6 +337,11 @@ protected Mono doCleanupAfterCompletion(TransactionSynchronizationManager return Mono.defer(() -> { ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + if (txObject.hasSavepoint()) { + // Just release the savepoint + return Mono.defer(txObject::releaseSavepoint); + } + // Remove the connection holder from the context, if exposed. if (txObject.isNewConnectionHolder()) { synchronizationManager.unbindResource(obtainConnectionFactory()); @@ -511,17 +516,25 @@ public boolean isTransactionActive() { return (this.connectionHolder != null && this.connectionHolder.isTransactionActive()); } + public boolean hasSavepoint() { + return this.savepointName != null; + } + public Mono createSavepoint() { ConnectionHolder holder = getConnectionHolder(); this.savepointName = holder.nextSavepoint(); return Mono.from(holder.getConnection().createSavepoint(this.savepointName)); } + public Mono releaseSavepoint() { + String currentSavepointName = this.savepointName; + this.savepointName = null; + return Mono.from(getConnectionHolder().getConnection().releaseSavepoint(currentSavepointName)); + } + public Mono commit() { Connection connection = getConnectionHolder().getConnection(); - return (this.savepointName != null ? - Mono.from(connection.releaseSavepoint(this.savepointName)) : - Mono.from(connection.commitTransaction())); + return (this.savepointName != null ? Mono.empty() : Mono.from(connection.commitTransaction())); } public Mono rollback() { diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java index 99257abecbb1..cc0e938b1ef7 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java @@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -44,6 +46,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.inOrder; import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.never; import static org.mockito.BDDMockito.reset; @@ -382,63 +385,130 @@ void testPropagationNeverWithExistingTransaction() { @Test void testPropagationNestedWithExistingTransaction() { - when(connectionMock.createSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); - when(connectionMock.releaseSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); + when(connectionMock.createSavepoint(anyString())).thenReturn(Mono.empty()); + when(connectionMock.rollbackTransactionToSavepoint(anyString())).thenReturn(Mono.empty()); + when(connectionMock.releaseSavepoint(anyString())).thenReturn(Mono.empty()); when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); TransactionalOperator operator = TransactionalOperator.create(tm, definition); - operator.execute(tx1 -> { - assertThat(tx1.hasTransaction()).isTrue(); - assertThat(tx1.isNewTransaction()).isTrue(); - assertThat(tx1.isNested()).isFalse(); - definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); - return operator.execute(tx2 -> { - assertThat(tx2.hasTransaction()).isTrue(); - assertThat(tx2.isNewTransaction()).isTrue(); - assertThat(tx2.isNested()).isTrue(); - return Mono.empty(); - }); - }).as(StepVerifier::create).verifyComplete(); - - verify(connectionMock).createSavepoint("SAVEPOINT_1"); - verify(connectionMock).releaseSavepoint("SAVEPOINT_1"); - verify(connectionMock).commitTransaction(); - verify(connectionMock).close(); - } - - @Test - void testPropagationNestedWithExistingTransactionAndRollback() { - when(connectionMock.createSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); - when(connectionMock.rollbackTransactionToSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); - when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); - - DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); - definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); - - TransactionalOperator operator = TransactionalOperator.create(tm, definition); - operator.execute(tx1 -> { - assertThat(tx1.hasTransaction()).isTrue(); - assertThat(tx1.isNewTransaction()).isTrue(); - assertThat(tx1.isNested()).isFalse(); + operator.execute(tx -> { + assertThat(tx.hasTransaction()).isTrue(); + assertThat(tx.isNewTransaction()).isTrue(); + assertThat(tx.isNested()).isFalse(); definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); - return operator.execute(tx2 -> { - assertThat(tx2.hasTransaction()).isTrue(); - assertThat(tx2.isNewTransaction()).isTrue(); - assertThat(tx2.isNested()).isTrue(); - assertThat(tx2.isRollbackOnly()).isFalse(); - tx2.setRollbackOnly(); - assertThat(tx2.isRollbackOnly()).isTrue(); - return Mono.empty(); - }); + return Flux.concat( + TransactionalOperator.create(tm, definition).execute(ntx1 -> { + assertThat(ntx1.hasTransaction()).as("ntx1.hasTransaction()").isTrue(); + assertThat(ntx1.isNewTransaction()).as("ntx1.isNewTransaction()").isTrue(); + assertThat(ntx1.isNested()).as("ntx1.isNested()").isTrue(); + assertThat(ntx1.isRollbackOnly()).as("ntx1.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx2 -> { + assertThat(ntx2.hasTransaction()).as("ntx2.hasTransaction()").isTrue(); + assertThat(ntx2.isNewTransaction()).as("ntx2.isNewTransaction()").isTrue(); + assertThat(ntx2.isNested()).as("ntx2.isNested()").isTrue(); + assertThat(ntx2.isRollbackOnly()).as("ntx2.isRollbackOnly()").isFalse(); + ntx2.setRollbackOnly(); + assertThat(ntx2.isRollbackOnly()).isTrue(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx3 -> { + assertThat(ntx3.hasTransaction()).as("ntx3.hasTransaction()").isTrue(); + assertThat(ntx3.isNewTransaction()).as("ntx3.isNewTransaction()").isTrue(); + assertThat(ntx3.isNested()).as("ntx3.isNested()").isTrue(); + assertThat(ntx3.isRollbackOnly()).as("ntx3.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx4 -> { + assertThat(ntx4.hasTransaction()).as("ntx4.hasTransaction()").isTrue(); + assertThat(ntx4.isNewTransaction()).as("ntx4.isNewTransaction()").isTrue(); + assertThat(ntx4.isNested()).as("ntx4.isNested()").isTrue(); + assertThat(ntx4.isRollbackOnly()).as("ntx4.isRollbackOnly()").isFalse(); + ntx4.setRollbackOnly(); + assertThat(ntx4.isRollbackOnly()).isTrue(); + return Flux.concat( + TransactionalOperator.create(tm, definition).execute(ntx4n1 -> { + assertThat(ntx4n1.hasTransaction()).as("ntx4n1.hasTransaction()").isTrue(); + assertThat(ntx4n1.isNewTransaction()).as("ntx4n1.isNewTransaction()").isTrue(); + assertThat(ntx4n1.isNested()).as("ntx4n1.isNested()").isTrue(); + assertThat(ntx4n1.isRollbackOnly()).as("ntx4n1.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx4n2 -> { + assertThat(ntx4n2.hasTransaction()).as("ntx4n2.hasTransaction()").isTrue(); + assertThat(ntx4n2.isNewTransaction()).as("ntx4n2.isNewTransaction()").isTrue(); + assertThat(ntx4n2.isNested()).as("ntx4n2.isNested()").isTrue(); + assertThat(ntx4n2.isRollbackOnly()).as("ntx4n2.isRollbackOnly()").isFalse(); + ntx4n2.setRollbackOnly(); + assertThat(ntx4n2.isRollbackOnly()).isTrue(); + return Mono.empty(); + }) + ); + }), + TransactionalOperator.create(tm, definition).execute(ntx5 -> { + assertThat(ntx5.hasTransaction()).as("ntx5.hasTransaction()").isTrue(); + assertThat(ntx5.isNewTransaction()).as("ntx5.isNewTransaction()").isTrue(); + assertThat(ntx5.isNested()).as("ntx5.isNested()").isTrue(); + assertThat(ntx5.isRollbackOnly()).as("ntx5.isRollbackOnly()").isFalse(); + ntx5.setRollbackOnly(); + assertThat(ntx5.isRollbackOnly()).isTrue(); + return Flux.concat( + TransactionalOperator.create(tm, definition).execute(ntx5n1 -> { + assertThat(ntx5n1.hasTransaction()).as("ntx5n1.hasTransaction()").isTrue(); + assertThat(ntx5n1.isNewTransaction()).as("ntx5n1.isNewTransaction()").isTrue(); + assertThat(ntx5n1.isNested()).as("ntx5n1.isNested()").isTrue(); + assertThat(ntx5n1.isRollbackOnly()).as("ntx5n1.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx5n2 -> { + assertThat(ntx5n2.hasTransaction()).as("ntx5n2.hasTransaction()").isTrue(); + assertThat(ntx5n2.isNewTransaction()).as("ntx5n2.isNewTransaction()").isTrue(); + assertThat(ntx5n2.isNested()).as("ntx5n2.isNested()").isTrue(); + assertThat(ntx5n2.isRollbackOnly()).as("ntx5n2.isRollbackOnly()").isFalse(); + ntx5n2.setRollbackOnly(); + assertThat(ntx5n2.isRollbackOnly()).isTrue(); + return Mono.empty(); + }) + ); + }) + ); }).as(StepVerifier::create).verifyComplete(); - verify(connectionMock).createSavepoint("SAVEPOINT_1"); - verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_1"); - verify(connectionMock).commitTransaction(); - verify(connectionMock).close(); + InOrder inOrder = inOrder(connectionMock); + // ntx1 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_1"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_1"); + // ntx2 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_2"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_2"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_2"); + // ntx3 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_3"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_3"); + // ntx4 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_4"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_5"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_5"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_6"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_6"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_6"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_4"); + // ntx5 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_7"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_8"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_8"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_9"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_9"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_9"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_7"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_7"); + // tx + inOrder.verify(connectionMock).commitTransaction(); + inOrder.verify(connectionMock).close(); } @Test