diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/R2dbcExecutor.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/R2dbcExecutor.scala
index fa118ef6..e89126aa 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/R2dbcExecutor.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/R2dbcExecutor.scala
@@ -53,6 +53,13 @@ import reactor.core.publisher.Mono
     result.getRowsUpdated.asFuture().map(_.longValue())(ExecutionContexts.parasitic)
   }
 
+  def updateOneReturningInTx[A](stmt: Statement, mapRow: Row => A)(implicit ec: ExecutionContext): Future[A] =
+    stmt.execute().asFuture().flatMap { result =>
+      Mono
+        .from[A](result.map((row, _) => mapRow(row)))
+        .asFuture()
+    }
+
   def updateBatchInTx(stmt: Statement)(implicit ec: ExecutionContext): Future[Long] = {
     val consumer: BiConsumer[Long, java.lang.Long] = (acc, elem) => acc + elem.longValue()
     Flux
@@ -195,12 +202,7 @@ class R2dbcExecutor(
   def updateOneReturning[A](
       logPrefix: String)(statementFactory: Connection => Statement, mapRow: Row => A): Future[A] = {
     withAutoCommitConnection(logPrefix) { connection =>
-      val stmt = statementFactory(connection)
-      stmt.execute().asFuture().flatMap { result =>
-        Mono
-          .from[A](result.map((row, _) => mapRow(row)))
-          .asFuture()
-      }
+      updateOneReturningInTx(statementFactory(connection), mapRow)
     }
   }
 
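Note on the new companion-object helper: the statement executes on whatever connection it was created from, so it joins the caller's open transaction instead of a pooled auto-commit connection, and the instance method updateOneReturning now just delegates to it. A minimal usage sketch, assuming an already-open transactional Connection; the table name example_table, the column db_timestamp, and the Postgres-style $1 placeholder are hypothetical:

import java.time.Instant
import scala.concurrent.{ ExecutionContext, Future }
import io.r2dbc.spi.Connection
import akka.persistence.r2dbc.internal.R2dbcExecutor

def insertReturningTimestamp(connection: Connection)(implicit ec: ExecutionContext): Future[Instant] = {
  // created from the caller's connection, so it runs inside the caller's transaction
  val stmt = connection
    .createStatement("INSERT INTO example_table (name) VALUES ($1) RETURNING db_timestamp")
    .bind(0, "example")
  // executes the statement and maps the single returned row
  R2dbcExecutor.updateOneReturningInTx(stmt, row => row.get(0, classOf[Instant]))
}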
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2JournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2JournalDao.scala
index e1356ece..7eb530ec 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2JournalDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2JournalDao.scala
@@ -16,11 +16,15 @@
 import io.r2dbc.spi.ConnectionFactory
 import io.r2dbc.spi.Statement
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-
 import java.time.Instant
+
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
+import io.r2dbc.spi.Connection
+
+import akka.persistence.r2dbc.internal.R2dbcExecutor
+
 /**
  * INTERNAL API
  */
@@ -35,7 +39,7 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact
   require(journalSettings.useAppTimestamp)
   require(journalSettings.dbTimestampMonotonicIncreasing)
 
-  val insertSql = sql"INSERT INTO $journalTable " +
+  private val insertSql = sql"INSERT INTO $journalTable " +
     "(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
     "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
 
@@ -54,66 +58,74 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact
     // it's always the same persistenceId for all events
     val persistenceId = events.head.persistenceId
-    val previousSeqNr = events.head.seqNr - 1
-
-    def bind(stmt: Statement, write: SerializedJournalRow): Statement = {
-      stmt
-        .bind(0, write.slice)
-        .bind(1, write.entityType)
-        .bind(2, write.persistenceId)
-        .bind(3, write.seqNr)
-        .bind(4, write.writerUuid)
-        .bind(5, "") // FIXME event adapter
-        .bind(6, write.serId)
-        .bind(7, write.serManifest)
-        .bindPayload(8, write.payload.get)
-
-      if (write.tags.isEmpty)
-        stmt.bindNull(9, classOf[Array[String]])
-      else
-        stmt.bind(9, write.tags.toArray)
-
-      // optional metadata
-      write.metadata match {
-        case Some(m) =>
-          stmt
-            .bind(10, m.serId)
-            .bind(11, m.serManifest)
-            .bind(12, m.payload)
-        case None =>
-          stmt
-            .bindNull(10, classOf[Integer])
-            .bindNull(11, classOf[String])
-            .bindNull(12, classOf[Array[Byte]])
+
+    val totalEvents = events.size
+    val result =
+      if (totalEvents == 1) {
+        r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
+          bindInsertStatement(connection.createStatement(insertSql), events.head))
+      } else {
+        r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
+          events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
+            stmt.add()
+            bindInsertStatement(stmt, write)
+          })
       }
-      stmt.bind(13, write.dbTimestamp)
+    if (log.isDebugEnabled())
+      result.foreach { _ =>
+        log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId)
+      }
+    result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
+  }
 
-      stmt
-    }
+  override def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant] = {
+    val persistenceId = event.persistenceId
 
-    val totalEvents = events.size
-    if (totalEvents == 1) {
-      val result = r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
-        bind(connection.createStatement(insertSql), events.head))
-      if (log.isDebugEnabled())
-        result.foreach { _ =>
-          log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId)
-        }
-      result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
-    } else {
-      val result = r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
-        events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
-          stmt.add()
-          bind(stmt, write)
-        })
-      if (log.isDebugEnabled()) {
-        result.foreach { _ =>
-          log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId)
-        }
+    val stmt = bindInsertStatement(connection.createStatement(insertSql), event)
+    val result = R2dbcExecutor.updateOneInTx(stmt)
+
+    if (log.isDebugEnabled())
+      result.foreach { _ =>
+        log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
       }
-      result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
+    result.map(_ => event.dbTimestamp)(ExecutionContexts.parasitic)
+  }
+
+  private def bindInsertStatement(stmt: Statement, write: SerializedJournalRow): Statement = {
+    stmt
+      .bind(0, write.slice)
+      .bind(1, write.entityType)
+      .bind(2, write.persistenceId)
+      .bind(3, write.seqNr)
+      .bind(4, write.writerUuid)
+      .bind(5, "") // FIXME event adapter
+      .bind(6, write.serId)
+      .bind(7, write.serManifest)
+      .bindPayload(8, write.payload.get)
+
+    if (write.tags.isEmpty)
+      stmt.bindNull(9, classOf[Array[String]])
+    else
+      stmt.bind(9, write.tags.toArray)
+
+    // optional metadata
+    write.metadata match {
+      case Some(m) =>
+        stmt
+          .bind(10, m.serId)
+          .bind(11, m.serManifest)
+          .bind(12, m.payload)
+      case None =>
+        stmt
+          .bindNull(10, classOf[Integer])
+          .bindNull(11, classOf[String])
+          .bindNull(12, classOf[Array[Byte]])
     }
+
+    stmt.bind(13, write.dbTimestamp)
+
+    stmt
   }
 }
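The new writeEventInTx is the piece that lets a durable state change event be stored atomically with the state row: the journal insert runs on a connection supplied by the caller rather than one obtained by the executor. A hedged sketch of the intended call pattern; the state statement, the dao instance, and the begin/commit/rollback handling around it are assumptions, not part of this diff (and H2JournalDao is internal API, so this would only compile inside the akka.persistence.r2dbc package):

import java.time.Instant
import scala.concurrent.{ ExecutionContext, Future }
import io.r2dbc.spi.{ Connection, Statement }
import akka.persistence.r2dbc.internal.R2dbcExecutor
import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
import akka.persistence.r2dbc.internal.h2.H2JournalDao

def persistStateAndChangeEvent(
    connection: Connection,
    stateStmt: Statement,
    dao: H2JournalDao,
    changeEvent: SerializedJournalRow)(implicit ec: ExecutionContext): Future[Instant] =
  for {
    // state upsert on the shared connection
    _ <- R2dbcExecutor.updateOneInTx(stateStmt)
    // journal insert on the same connection; if it fails (for example with a
    // unique constraint violation) the caller's rollback discards both writes
    timestamp <- dao.writeEventInTx(changeEvent, connection)
  } yield timestamp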
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala
index 6af05f85..6358c5bf 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala
@@ -10,7 +10,6 @@ import akka.annotation.InternalApi
 import akka.dispatch.ExecutionContexts
 import akka.persistence.Persistence
 import akka.persistence.r2dbc.R2dbcSettings
-import akka.persistence.r2dbc.internal.BySliceQuery
 import akka.persistence.r2dbc.internal.JournalDao
 import akka.persistence.r2dbc.internal.PayloadCodec
 import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement
@@ -185,10 +184,8 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
       if (useTimestampFromDb) insertEventWithTransactionTimestampSql
       else insertEventWithParameterTimestampSql
 
-    val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")(
-      connection =>
-        bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr),
-      row => row.get(0, classOf[Instant]))
+    val stmt = bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr)
+    val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.get(0, classOf[Instant]))
     if (log.isDebugEnabled())
       result.foreach { _ =>
         log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
diff --git a/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateUpdateWithChangeEventStoreSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateUpdateWithChangeEventStoreSpec.scala
index 205c7f1b..c7578c3d 100644
--- a/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateUpdateWithChangeEventStoreSpec.scala
+++ b/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateUpdateWithChangeEventStoreSpec.scala
@@ -4,6 +4,8 @@
 
 package akka.persistence.r2dbc.state
 
+import org.scalatest.concurrent.ScalaFutures.convertScalaFuture
+
 import akka.actor.testkit.typed.scaladsl.LogCapturing
 import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit
 import akka.actor.typed.ActorSystem
@@ -24,6 +26,10 @@ import akka.persistence.typed.PersistenceId
 import akka.stream.scaladsl.Sink
 import org.scalatest.wordspec.AnyWordSpecLike
 
+import akka.Done
+import akka.persistence.r2dbc.TestActors.Persister
+import akka.persistence.r2dbc.TestActors.Persister.PersistWithAck
+
 class DurableStateUpdateWithChangeEventStoreSpec
     extends ScalaTestWithActorTestKit(TestConfig.config)
     with AnyWordSpecLike
@@ -70,6 +76,24 @@
       env3.sequenceNr shouldBe 3L
     }
 
+    "roll back the state update when the change event insert fails" in {
+      // test rollback (same tx) when the journal insert fails via a simulated unique constraint violation in event_journal
+      val entityType = nextEntityType()
+      val persistenceId = PersistenceId(entityType, "my-persistenceId").id
+
+      val probe = testKit.createTestProbe[Done]()
+      val persister = testKit.spawn(Persister(persistenceId))
+      persister ! PersistWithAck("a", probe.ref)
+      probe.expectMessage(Done)
+      testKit.stop(persister)
+
+      val value1 = "Genuinely Collaborative"
+
+      store.upsertObject(persistenceId, 1L, value1, tag, s"Changed to $value1").failed.futureValue
+
+      store.getObject(persistenceId).futureValue.value shouldBe None
+    }
+
     "detect and reject concurrent inserts, and not store change event" in {
       val entityType = nextEntityType()
       val persistenceId = PersistenceId(entityType, "id-to-be-inserted-concurrently").id
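A note on why the upsert in the new test must fail, assuming a change event for revision N is stored under seq_nr N of the same persistence id:

// the Persister already wrote event seq_nr 1 via PersistWithAck("a", probe.ref)
val existingJournalKey = (persistenceId, 1L)
// upsertObject(persistenceId, 1L, ...) stores its change event under the same key
val changeEventKey = (persistenceId, 1L)
// identical (persistence_id, seq_nr) => unique constraint violation in event_journal,
// the shared transaction rolls back, and getObject afterwards sees no state row
assert(existingJournalKey == changeEventKey)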