diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java index 325eb5524ce9..f3edbf16b9af 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java @@ -337,36 +337,36 @@ protected Mono doCleanupAfterCompletion(TransactionSynchronizationManager return Mono.defer(() -> { ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + if (txObject.hasSavepoint()) { + // Just release the savepoint, keeping the transactional connection. + return txObject.releaseSavepoint(); + } + // Remove the connection holder from the context, if exposed. if (txObject.isNewConnectionHolder()) { synchronizationManager.unbindResource(obtainConnectionFactory()); } // Reset connection. - Connection con = txObject.getConnectionHolder().getConnection(); - - Mono afterCleanup = Mono.empty(); - - Mono releaseConnectionStep = Mono.defer(() -> { - try { - if (txObject.isNewConnectionHolder()) { - if (logger.isDebugEnabled()) { - logger.debug("Releasing R2DBC Connection [" + con + "] after transaction"); - } - Mono releaseMono = ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory()); - if (logger.isDebugEnabled()) { - releaseMono = releaseMono.doOnError( - ex -> logger.debug(String.format("Error ignored during cleanup: %s", ex))); - } - return releaseMono.onErrorComplete(); + try { + if (txObject.isNewConnectionHolder()) { + Connection con = txObject.getConnectionHolder().getConnection(); + if (logger.isDebugEnabled()) { + logger.debug("Releasing R2DBC Connection [" + con + "] after transaction"); } + Mono releaseMono = ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory()); + if (logger.isDebugEnabled()) { + releaseMono = releaseMono.doOnError( + ex -> logger.debug(String.format("Error ignored during cleanup: %s", ex))); + } + return releaseMono.onErrorComplete(); } - finally { - txObject.getConnectionHolder().clear(); - } - return Mono.empty(); - }); - return afterCleanup.then(releaseConnectionStep); + } + finally { + txObject.getConnectionHolder().clear(); + } + + return Mono.empty(); }); } @@ -511,23 +511,36 @@ public boolean isTransactionActive() { return (this.connectionHolder != null && this.connectionHolder.isTransactionActive()); } + public boolean hasSavepoint() { + return (this.savepointName != null); + } + public Mono createSavepoint() { ConnectionHolder holder = getConnectionHolder(); - this.savepointName = holder.nextSavepoint(); - return Mono.from(holder.getConnection().createSavepoint(this.savepointName)); + String currentSavepoint = holder.nextSavepoint(); + this.savepointName = currentSavepoint; + return Mono.from(holder.getConnection().createSavepoint(currentSavepoint)); + } + + public Mono releaseSavepoint() { + String currentSavepoint = this.savepointName; + if (currentSavepoint == null) { + return Mono.empty(); + } + this.savepointName = null; + return Mono.from(getConnectionHolder().getConnection().releaseSavepoint(currentSavepoint)); } public Mono commit() { - Connection connection = getConnectionHolder().getConnection(); - return (this.savepointName != null ? - Mono.from(connection.releaseSavepoint(this.savepointName)) : - Mono.from(connection.commitTransaction())); + return (hasSavepoint() ? Mono.empty() : + Mono.from(getConnectionHolder().getConnection().commitTransaction())); } public Mono rollback() { Connection connection = getConnectionHolder().getConnection(); - return (this.savepointName != null ? - Mono.from(connection.rollbackTransactionToSavepoint(this.savepointName)) : + String currentSavepoint = this.savepointName; + return (currentSavepoint != null ? + Mono.from(connection.rollbackTransactionToSavepoint(currentSavepoint)) : Mono.from(connection.rollbackTransaction())); } diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java index 99cf809646a1..05cc75cfc04a 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java @@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -44,6 +46,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.inOrder; import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.never; import static org.mockito.BDDMockito.reset; @@ -365,53 +368,110 @@ void testPropagationNeverWithExistingTransaction() { @Test void testPropagationNestedWithExistingTransaction() { - when(connectionMock.createSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); - when(connectionMock.releaseSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); + when(connectionMock.createSavepoint(anyString())).thenReturn(Mono.empty()); + when(connectionMock.rollbackTransactionToSavepoint(anyString())).thenReturn(Mono.empty()); + when(connectionMock.releaseSavepoint(anyString())).thenReturn(Mono.empty()); when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); TransactionalOperator operator = TransactionalOperator.create(tm, definition); - operator.execute(tx1 -> { - assertThat(tx1.isNewTransaction()).isTrue(); - definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); - return operator.execute(tx2 -> { - assertThat(tx2.isNewTransaction()).isTrue(); - return Mono.empty(); - }); - }).as(StepVerifier::create).verifyComplete(); - - verify(connectionMock).createSavepoint("SAVEPOINT_1"); - verify(connectionMock).releaseSavepoint("SAVEPOINT_1"); - verify(connectionMock).commitTransaction(); - verify(connectionMock).close(); - } - - @Test - void testPropagationNestedWithExistingTransactionAndRollback() { - when(connectionMock.createSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); - when(connectionMock.rollbackTransactionToSavepoint("SAVEPOINT_1")).thenReturn(Mono.empty()); - when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); - - DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); - definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); - - TransactionalOperator operator = TransactionalOperator.create(tm, definition); - operator.execute(tx1 -> { - assertThat(tx1.isNewTransaction()).isTrue(); + operator.execute(tx -> { + assertThat(tx.isNewTransaction()).isTrue(); definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NESTED); - return operator.execute(tx2 -> { - assertThat(tx2.isNewTransaction()).isTrue(); - tx2.setRollbackOnly(); - return Mono.empty(); - }); + return Flux.concat( + TransactionalOperator.create(tm, definition).execute(ntx1 -> { + assertThat(ntx1.isNewTransaction()).as("ntx1.isNewTransaction()").isTrue(); + assertThat(ntx1.isRollbackOnly()).as("ntx1.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx2 -> { + assertThat(ntx2.isNewTransaction()).as("ntx2.isNewTransaction()").isTrue(); + assertThat(ntx2.isRollbackOnly()).as("ntx2.isRollbackOnly()").isFalse(); + ntx2.setRollbackOnly(); + assertThat(ntx2.isRollbackOnly()).isTrue(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx3 -> { + assertThat(ntx3.isNewTransaction()).as("ntx3.isNewTransaction()").isTrue(); + assertThat(ntx3.isRollbackOnly()).as("ntx3.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx4 -> { + assertThat(ntx4.isNewTransaction()).as("ntx4.isNewTransaction()").isTrue(); + assertThat(ntx4.isRollbackOnly()).as("ntx4.isRollbackOnly()").isFalse(); + ntx4.setRollbackOnly(); + assertThat(ntx4.isRollbackOnly()).isTrue(); + return Flux.concat( + TransactionalOperator.create(tm, definition).execute(ntx4n1 -> { + assertThat(ntx4n1.isNewTransaction()).as("ntx4n1.isNewTransaction()").isTrue(); + assertThat(ntx4n1.isRollbackOnly()).as("ntx4n1.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx4n2 -> { + assertThat(ntx4n2.isNewTransaction()).as("ntx4n2.isNewTransaction()").isTrue(); + assertThat(ntx4n2.isRollbackOnly()).as("ntx4n2.isRollbackOnly()").isFalse(); + ntx4n2.setRollbackOnly(); + assertThat(ntx4n2.isRollbackOnly()).isTrue(); + return Mono.empty(); + }) + ); + }), + TransactionalOperator.create(tm, definition).execute(ntx5 -> { + assertThat(ntx5.isNewTransaction()).as("ntx5.isNewTransaction()").isTrue(); + assertThat(ntx5.isRollbackOnly()).as("ntx5.isRollbackOnly()").isFalse(); + ntx5.setRollbackOnly(); + assertThat(ntx5.isRollbackOnly()).isTrue(); + return Flux.concat( + TransactionalOperator.create(tm, definition).execute(ntx5n1 -> { + assertThat(ntx5n1.isNewTransaction()).as("ntx5n1.isNewTransaction()").isTrue(); + assertThat(ntx5n1.isRollbackOnly()).as("ntx5n1.isRollbackOnly()").isFalse(); + return Mono.empty(); + }), + TransactionalOperator.create(tm, definition).execute(ntx5n2 -> { + assertThat(ntx5n2.isNewTransaction()).as("ntx5n2.isNewTransaction()").isTrue(); + assertThat(ntx5n2.isRollbackOnly()).as("ntx5n2.isRollbackOnly()").isFalse(); + ntx5n2.setRollbackOnly(); + assertThat(ntx5n2.isRollbackOnly()).isTrue(); + return Mono.empty(); + }) + ); + }) + ); }).as(StepVerifier::create).verifyComplete(); - verify(connectionMock).createSavepoint("SAVEPOINT_1"); - verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_1"); - verify(connectionMock).commitTransaction(); - verify(connectionMock).close(); + InOrder inOrder = inOrder(connectionMock); + // ntx1 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_1"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_1"); + // ntx2 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_2"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_2"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_2"); + // ntx3 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_3"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_3"); + // ntx4 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_4"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_5"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_5"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_6"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_6"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_6"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_4"); + // ntx5 + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_7"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_8"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_8"); + inOrder.verify(connectionMock).createSavepoint("SAVEPOINT_9"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_9"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_9"); + inOrder.verify(connectionMock).rollbackTransactionToSavepoint("SAVEPOINT_7"); + inOrder.verify(connectionMock).releaseSavepoint("SAVEPOINT_7"); + // tx + inOrder.verify(connectionMock).commitTransaction(); + inOrder.verify(connectionMock).close(); } @Test @@ -452,7 +512,9 @@ void testPropagationSupportsAndNestedWithRollback() { TransactionalOperator inner = TransactionalOperator.create(tm, innerDef); return inner.execute(tx2 -> { assertThat(tx2.isNewTransaction()).isTrue(); + assertThat(tx2.isRollbackOnly()).isFalse(); tx2.setRollbackOnly(); + assertThat(tx2.isRollbackOnly()).isTrue(); return Mono.empty(); }); }).as(StepVerifier::create).verifyComplete(); @@ -499,7 +561,9 @@ void testPropagationSupportsAndRequiresNewWithRollback() { TransactionalOperator inner = TransactionalOperator.create(tm, innerDef); return inner.execute(tx2 -> { assertThat(tx2.isNewTransaction()).isTrue(); + assertThat(tx2.isRollbackOnly()).isFalse(); tx2.setRollbackOnly(); + assertThat(tx2.isRollbackOnly()).isTrue(); return Mono.empty(); }); }).as(StepVerifier::create).verifyComplete();