Skip to content

Commit

Permalink
override in H2JournalDao, and actually same tx
Browse files Browse the repository at this point in the history
  • Loading branch information
patriknw committed Dec 19, 2023
1 parent d1c927b commit fbe67e0
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ import reactor.core.publisher.Mono
result.getRowsUpdated.asFuture().map(_.longValue())(ExecutionContexts.parasitic)
}

def updateOneReturningInTx[A](stmt: Statement, mapRow: Row => A)(implicit ec: ExecutionContext): Future[A] =
stmt.execute().asFuture().flatMap { result =>
Mono
.from[A](result.map((row, _) => mapRow(row)))
.asFuture()
}

def updateBatchInTx(stmt: Statement)(implicit ec: ExecutionContext): Future[Long] = {
val consumer: BiConsumer[Long, java.lang.Long] = (acc, elem) => acc + elem.longValue()
Flux
Expand Down Expand Up @@ -195,12 +202,7 @@ class R2dbcExecutor(
def updateOneReturning[A](
logPrefix: String)(statementFactory: Connection => Statement, mapRow: Row => A): Future[A] = {
withAutoCommitConnection(logPrefix) { connection =>
val stmt = statementFactory(connection)
stmt.execute().asFuture().flatMap { result =>
Mono
.from[A](result.map((row, _) => mapRow(row)))
.asFuture()
}
updateOneReturningInTx(statementFactory(connection), mapRow)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.Statement
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import java.time.Instant

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

import io.r2dbc.spi.Connection

import akka.persistence.r2dbc.internal.R2dbcExecutor

/**
* INTERNAL API
*/
Expand All @@ -35,7 +39,7 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact
require(journalSettings.useAppTimestamp)
require(journalSettings.dbTimestampMonotonicIncreasing)

val insertSql = sql"INSERT INTO $journalTable " +
private val insertSql = sql"INSERT INTO $journalTable " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"

Expand All @@ -54,66 +58,74 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact

// it's always the same persistenceId for all events
val persistenceId = events.head.persistenceId
val previousSeqNr = events.head.seqNr - 1

def bind(stmt: Statement, write: SerializedJournalRow): Statement = {
stmt
.bind(0, write.slice)
.bind(1, write.entityType)
.bind(2, write.persistenceId)
.bind(3, write.seqNr)
.bind(4, write.writerUuid)
.bind(5, "") // FIXME event adapter
.bind(6, write.serId)
.bind(7, write.serManifest)
.bindPayload(8, write.payload.get)

if (write.tags.isEmpty)
stmt.bindNull(9, classOf[Array[String]])
else
stmt.bind(9, write.tags.toArray)

// optional metadata
write.metadata match {
case Some(m) =>
stmt
.bind(10, m.serId)
.bind(11, m.serManifest)
.bind(12, m.payload)
case None =>
stmt
.bindNull(10, classOf[Integer])
.bindNull(11, classOf[String])
.bindNull(12, classOf[Array[Byte]])

val totalEvents = events.size
val result =
if (totalEvents == 1) {
r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
bindInsertStatement(connection.createStatement(insertSql), events.head))
} else {
r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
stmt.add()
bindInsertStatement(stmt, write)
})
}

stmt.bind(13, write.dbTimestamp)
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId)
}
result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
}

stmt
}
override def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant] = {
val persistenceId = event.persistenceId

val totalEvents = events.size
if (totalEvents == 1) {
val result = r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
bind(connection.createStatement(insertSql), events.head))
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId)
}
result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
} else {
val result = r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
stmt.add()
bind(stmt, write)
})
if (log.isDebugEnabled()) {
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId)
}
val stmt = bindInsertStatement(connection.createStatement(insertSql), event)
val result = R2dbcExecutor.updateOneInTx(stmt)

if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
}
result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
result.map(_ => event.dbTimestamp)(ExecutionContexts.parasitic)
}

private def bindInsertStatement(stmt: Statement, write: SerializedJournalRow): Statement = {
stmt
.bind(0, write.slice)
.bind(1, write.entityType)
.bind(2, write.persistenceId)
.bind(3, write.seqNr)
.bind(4, write.writerUuid)
.bind(5, "") // FIXME event adapter
.bind(6, write.serId)
.bind(7, write.serManifest)
.bindPayload(8, write.payload.get)

if (write.tags.isEmpty)
stmt.bindNull(9, classOf[Array[String]])
else
stmt.bind(9, write.tags.toArray)

// optional metadata
write.metadata match {
case Some(m) =>
stmt
.bind(10, m.serId)
.bind(11, m.serManifest)
.bind(12, m.payload)
case None =>
stmt
.bindNull(10, classOf[Integer])
.bindNull(11, classOf[String])
.bindNull(12, classOf[Array[Byte]])
}

stmt.bind(13, write.dbTimestamp)

stmt
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.persistence.Persistence
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.BySliceQuery
import akka.persistence.r2dbc.internal.JournalDao
import akka.persistence.r2dbc.internal.PayloadCodec
import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement
Expand Down Expand Up @@ -185,10 +184,8 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
if (useTimestampFromDb) insertEventWithTransactionTimestampSql
else insertEventWithParameterTimestampSql

val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")(
connection =>
bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr),
row => row.get(0, classOf[Instant]))
val stmt = bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr)
val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.get(0, classOf[Instant]))
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package akka.persistence.r2dbc.state

import org.scalatest.concurrent.ScalaFutures.convertScalaFuture

import akka.actor.testkit.typed.scaladsl.LogCapturing
import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit
import akka.actor.typed.ActorSystem
Expand All @@ -24,6 +26,10 @@ import akka.persistence.typed.PersistenceId
import akka.stream.scaladsl.Sink
import org.scalatest.wordspec.AnyWordSpecLike

import akka.Done
import akka.persistence.r2dbc.TestActors.Persister
import akka.persistence.r2dbc.TestActors.Persister.PersistWithAck

class DurableStateUpdateWithChangeEventStoreSpec
extends ScalaTestWithActorTestKit(TestConfig.config)
with AnyWordSpecLike
Expand Down Expand Up @@ -70,6 +76,24 @@ class DurableStateUpdateWithChangeEventStoreSpec
env3.sequenceNr shouldBe 3L
}

"save additional change event in same transaction" in {
// test rollback (same tx) if the journal insert fails via simulated unique constraint violation in event_journal
val entityType = nextEntityType()
val persistenceId = PersistenceId(entityType, "my-persistenceId").id

val probe = testKit.createTestProbe[Done]()
val persister = testKit.spawn(Persister(persistenceId))
persister ! PersistWithAck("a", probe.ref)
probe.expectMessage(Done)
testKit.stop(persister)

val value1 = "Genuinely Collaborative"

store.upsertObject(persistenceId, 1L, value1, tag, s"Changed to $value1").failed.futureValue

store.getObject(persistenceId).futureValue.value shouldBe None
}

"detect and reject concurrent inserts, and not store change event" in {
val entityType = nextEntityType()
val persistenceId = PersistenceId(entityType, "id-to-be-inserted-concurrently").id
Expand Down

0 comments on commit fbe67e0

Please sign in to comment.