diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/DurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/DurableStateDao.scala
index ba3cbdd2..63427e8c 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/DurableStateDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/DurableStateDao.scala
@@ -8,10 +8,12 @@
 import akka.Done
 import akka.NotUsed
 import akka.annotation.InternalApi
 import akka.stream.scaladsl.Source
-
 import java.time.Instant
+
 import scala.concurrent.Future
+import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
+
 /**
  * INTERNAL API
  */
@@ -48,7 +50,7 @@ private[r2dbc] trait DurableStateDao extends BySliceQuery.Dao[DurableStateDao.Se
 
   def readState(persistenceId: String): Future[Option[SerializedStateRow]]
 
-  def upsertState(state: SerializedStateRow, value: Any): Future[Done]
+  def upsertState(state: SerializedStateRow, value: Any, changeEvent: Option[SerializedJournalRow]): Future[Done]
 
   def deleteState(persistenceId: String, revision: Long): Future[Done]
 
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/JournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/JournalDao.scala
index 22b046ee..ac64692f 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/JournalDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/JournalDao.scala
@@ -5,10 +5,13 @@
 package akka.persistence.r2dbc.internal
 
 import akka.annotation.InternalApi
-
 import java.time.Instant
+
 import scala.concurrent.Future
+import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
+import io.r2dbc.spi.Connection
+
 /**
  * INTERNAL API
 */
@@ -56,6 +59,9 @@ private[r2dbc] trait JournalDao {
    * a select (in same transaction).
    */
   def writeEvents(events: Seq[JournalDao.SerializedJournalRow]): Future[Instant]
+
+  def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant]
+
   def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long]
 
   def readLowestSequenceNr(persistenceId: String): Future[Long]
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala
index d13e6e4a..fbe219cf 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala
@@ -94,7 +94,9 @@ private[r2dbc] object H2Dialect extends Dialect {
 
   override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
       system: ActorSystem[_]): DurableStateDao =
-    new H2DurableStateDao(settings, connectionFactory)(ecForDaos(system, settings), system)
+    new H2DurableStateDao(settings, connectionFactory, createJournalDao(settings, connectionFactory))(
+      ecForDaos(system, settings),
+      system)
 
   private def ecForDaos(system: ActorSystem[_], settings: R2dbcSettings): ExecutionContext = {
     // H2 R2DBC driver blocks in surprising places (Mono.toFuture in stmt.execute().asFuture())
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala
index c3f42d5a..180ef355 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala
@@ -11,19 +11,21 @@
 import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao
 import io.r2dbc.spi.ConnectionFactory
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-
 import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration
 import scala.concurrent.duration.FiniteDuration
+import akka.persistence.r2dbc.internal.JournalDao
+
 /**
  * INTERNAL API
 */
 @InternalApi
-private[r2dbc] final class H2DurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
-    ec: ExecutionContext,
-    system: ActorSystem[_])
-    extends PostgresDurableStateDao(settings, connectionFactory) {
+private[r2dbc] final class H2DurableStateDao(
+    settings: R2dbcSettings,
+    connectionFactory: ConnectionFactory,
+    journalDao: JournalDao)(implicit ec: ExecutionContext, system: ActorSystem[_])
+    extends PostgresDurableStateDao(settings, connectionFactory, journalDao) {
 
   override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao])
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala
index 87ddc81a..9c9362dd 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala
@@ -129,5 +129,7 @@ private[r2dbc] object PostgresDialect extends Dialect {
 
   override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
      system: ActorSystem[_]): DurableStateDao =
-    new PostgresDurableStateDao(settings, connectionFactory)(system.executionContext, system)
+    new PostgresDurableStateDao(settings, connectionFactory, createJournalDao(settings, connectionFactory))(
+      system.executionContext,
+      system)
 }
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala
index 7cca3f15..7a18ecb4 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala
@@ -40,10 +40,10 @@ import io.r2dbc.spi.Row
 import io.r2dbc.spi.Statement
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-
 import java.lang
 import java.time.Instant
 import java.util
+
 import scala.collection.immutable
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
@@ -51,6 +51,9 @@ import scala.concurrent.duration.Duration
 import scala.concurrent.duration.FiniteDuration
 import scala.util.control.NonFatal
 
+import akka.persistence.r2dbc.internal.JournalDao
+import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
+
 /**
  * INTERNAL API
 */
@@ -70,9 +73,10 @@ private[r2dbc] object PostgresDurableStateDao {
  * INTERNAL API
 */
 @InternalApi
-private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
-    ec: ExecutionContext,
-    system: ActorSystem[_])
+private[r2dbc] class PostgresDurableStateDao(
+    settings: R2dbcSettings,
+    connectionFactory: ConnectionFactory,
+    journalDao: JournalDao)(implicit ec: ExecutionContext, system: ActorSystem[_])
     extends DurableStateDao {
   import DurableStateDao._
   import PostgresDurableStateDao._
@@ -264,7 +268,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
         LIMIT ?"""
   }
 
-  def readState(persistenceId: String): Future[Option[SerializedStateRow]] = {
+  override def readState(persistenceId: String): Future[Option[SerializedStateRow]] = {
     val entityType = PersistenceId.extractEntityType(persistenceId)
     r2dbcExecutor.selectOne(s"select [$persistenceId]")(
       connection =>
@@ -293,7 +297,25 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
         Option(rowPayload)
   }
 
-  def upsertState(state: SerializedStateRow, value: Any): Future[Done] = {
+  private def writeChangeEventAndCallChangeHander(
+      connection: Connection,
+      updatedRows: Long,
+      entityType: String,
+      change: DurableStateChange[Any],
+      changeEvent: Option[SerializedJournalRow]): Future[Done] = {
+    if (updatedRows == 1)
+      for {
+        _ <- changeEvent.map(journalDao.writeEventInTx(_, connection)).getOrElse(FutureDone)
+        _ <- changeHandlers.get(entityType).map(processChange(_, connection, change)).getOrElse(FutureDone)
+      } yield Done
+    else
+      FutureDone
+  }
+
+  override def upsertState(
+      state: SerializedStateRow,
+      value: Any,
+      changeEvent: Option[SerializedJournalRow]): Future[Done] = {
     require(state.revision > 0)
 
     def bindTags(stmt: Statement, i: Int): Statement = {
@@ -360,17 +382,15 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
               s"Insert failed: durable state for persistence id [${state.persistenceId}] already exists"))
       }
 
-      changeHandlers.get(entityType) match {
-        case None =>
-          recoverDataIntegrityViolation(r2dbcExecutor.updateOne(s"insert [${state.persistenceId}]")(insertStatement))
-        case Some(handler) =>
-          r2dbcExecutor.withConnection(s"insert [${state.persistenceId}] with change handler") { connection =>
-            for {
-              updatedRows <- recoverDataIntegrityViolation(R2dbcExecutor.updateOneInTx(insertStatement(connection)))
-              _ <- if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
-            } yield updatedRows
-          }
-      }
+      if (!changeHandlers.contains(entityType) && changeEvent.isEmpty)
+        recoverDataIntegrityViolation(r2dbcExecutor.updateOne(s"insert [${state.persistenceId}]")(insertStatement))
+      else
+        r2dbcExecutor.withConnection(s"insert [${state.persistenceId}]") { connection =>
+          for {
+            updatedRows <- recoverDataIntegrityViolation(R2dbcExecutor.updateOneInTx(insertStatement(connection)))
+            _ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent)
+          } yield updatedRows
+        }
     } else {
       val previousRevision = state.revision - 1
@@ -405,17 +425,15 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
         }
       }
 
-      changeHandlers.get(entityType) match {
-        case None =>
-          r2dbcExecutor.updateOne(s"update [${state.persistenceId}]")(updateStatement)
-        case Some(handler) =>
-          r2dbcExecutor.withConnection(s"update [${state.persistenceId}] with change handler") { connection =>
-            for {
-              updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection))
-              _ <- if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
-            } yield updatedRows
-          }
-      }
+      if (!changeHandlers.contains(entityType) && changeEvent.isEmpty)
+        r2dbcExecutor.updateOne(s"update [${state.persistenceId}]")(updateStatement)
+      else
+        r2dbcExecutor.withConnection(s"update [${state.persistenceId}]") { connection =>
+          for {
+            updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection))
+            _ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent)
+          } yield updatedRows
+        }
     }
   }
@@ -451,7 +469,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
     }
   }
 
-  def deleteState(persistenceId: String, revision: Long): Future[Done] = {
+  override def deleteState(persistenceId: String, revision: Long): Future[Done] = {
     if (revision == 0) {
       hardDeleteState(persistenceId)
     } else {
@@ -490,10 +508,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
           for {
             updatedRows <- recoverDataIntegrityViolation(
               R2dbcExecutor.updateOneInTx(insertDeleteMarkerStatement(connection)))
-            _ <- changeHandler match {
-              case None => FutureDone
-              case Some(handler) => if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
-            }
+            _ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
           } yield updatedRows
         }
 
@@ -537,10 +552,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
       r2dbcExecutor.withConnection(s"delete [$persistenceId]$changeHandlerHint") { connection =>
         for {
           updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection))
-          _ <- changeHandler match {
-            case None => FutureDone
-            case Some(handler) => if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
-          }
+          _ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
         } yield updatedRows
       }
     }
@@ -572,14 +584,9 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
             connection
               .createStatement(hardDeleteStateSql(entityType))
               .bind(0, persistenceId))
-          _ <- changeHandler match {
-            case None => FutureDone
-            case Some(handler) =>
-              if (updatedRows == 1) {
-                val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli)
-                processChange(handler, connection, change)
-              } else
-                FutureDone
+          _ <- {
+            val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli)
+            writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
           }
         } yield updatedRows
       }
@@ -669,7 +676,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
     Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed)
   }
 
-  def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = {
+  override def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = {
     if (settings.durableStateTableByEntityTypeWithSchema.isEmpty)
       persistenceIds(afterId, limit, settings.durableStateTableWithSchema)
     else {
@@ -699,7 +706,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
     }
   }
 
-  def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed] = {
+  override def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed] = {
     val result = readPersistenceIds(afterId, limit, table)
 
     Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed)
@@ -729,7 +736,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
     result
   }
 
-  def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = {
+  override def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = {
     val table = settings.getDurableStateTableWithSchema(entityType)
     val likeStmtPostfix = PersistenceId.DefaultSeparator + "%"
     val result = r2dbcExecutor.select(s"select persistenceIds by entity type")(
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala
index f3f0a7ad..6af05f85 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala
@@ -24,8 +24,8 @@
 import io.r2dbc.spi.Row
 import io.r2dbc.spi.Statement
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-
 import java.time.Instant
+
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
@@ -133,7 +133,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
    * it can return `JournalDao.EmptyDbTimestamp` when the pub-sub feature is disabled. When enabled it would have to use
    * a select (in same transaction).
    */
-  def writeEvents(events: Seq[SerializedJournalRow]): Future[Instant] = {
+  override def writeEvents(events: Seq[SerializedJournalRow]): Future[Instant] = {
     require(events.nonEmpty)
 
     // it's always the same persistenceId for all events
@@ -143,56 +143,6 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
     // The MigrationTool defines the dbTimestamp to preserve the original event timestamp
     val useTimestampFromDb = events.head.dbTimestamp == Instant.EPOCH
 
-    def bind(stmt: Statement, write: SerializedJournalRow): Statement = {
-      stmt
-        .bind(0, write.slice)
-        .bind(1, write.entityType)
-        .bind(2, write.persistenceId)
-        .bind(3, write.seqNr)
-        .bind(4, write.writerUuid)
-        .bind(5, "") // FIXME event adapter
-        .bind(6, write.serId)
-        .bind(7, write.serManifest)
-        .bindPayload(8, write.payload.get)
-
-      if (write.tags.isEmpty)
-        stmt.bindNull(9, classOf[Array[String]])
-      else
-        stmt.bind(9, write.tags.toArray)
-
-      // optional metadata
-      write.metadata match {
-        case Some(m) =>
-          stmt
-            .bind(10, m.serId)
-            .bind(11, m.serManifest)
-            .bind(12, m.payload)
-        case None =>
-          stmt
-            .bindNull(10, classOf[Integer])
-            .bindNull(11, classOf[String])
-            .bindNull(12, classOf[Array[Byte]])
-      }
-
-      if (useTimestampFromDb) {
-        if (!journalSettings.dbTimestampMonotonicIncreasing)
-          stmt
-            .bind(13, write.persistenceId)
-            .bind(14, previousSeqNr)
-      } else {
-        if (journalSettings.dbTimestampMonotonicIncreasing)
-          stmt
-            .bind(13, write.dbTimestamp)
-        else
-          stmt
-            .bind(13, write.dbTimestamp)
-            .bind(14, write.persistenceId)
-            .bind(15, previousSeqNr)
-      }
-
-      stmt
-    }
-
     val insertSql =
       if (useTimestampFromDb) insertEventWithTransactionTimestampSql
       else insertEventWithParameterTimestampSql
@@ -200,11 +150,12 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
     val totalEvents = events.size
     if (totalEvents == 1) {
       val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")(
-        connection => bind(connection.createStatement(insertSql), events.head),
+        connection =>
+          bindInsertStatement(connection.createStatement(insertSql), events.head, useTimestampFromDb, previousSeqNr),
         row => row.get(0, classOf[Instant]))
       if (log.isDebugEnabled())
         result.foreach { _ =>
-          log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId)
+          log.debug("Wrote [{}] events for persistenceId [{}]", 1, persistenceId)
         }
       result
     } else {
@@ -212,18 +163,94 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
         connection =>
           events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
             stmt.add()
-            bind(stmt, write)
+            bindInsertStatement(stmt, write, useTimestampFromDb, previousSeqNr)
           },
         row => row.get(0, classOf[Instant]))
       if (log.isDebugEnabled())
        result.foreach { _ =>
-          log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId)
+          log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, persistenceId)
        }
      result.map(_.head)(ExecutionContexts.parasitic)
    }
  }
 
-  def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = {
+  override def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant] = {
+    val persistenceId = event.persistenceId
+    val previousSeqNr = event.seqNr - 1
+
+    // The MigrationTool defines the dbTimestamp to preserve the original event timestamp
+    val useTimestampFromDb = event.dbTimestamp == Instant.EPOCH
+
+    val insertSql =
+      if (useTimestampFromDb) insertEventWithTransactionTimestampSql
+      else insertEventWithParameterTimestampSql
+
+    val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")(
+      connection =>
+        bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr),
+      row => row.get(0, classOf[Instant]))
+    if (log.isDebugEnabled())
+      result.foreach { _ =>
+        log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
+      }
+    result
+  }
+
+  private def bindInsertStatement(
+      stmt: Statement,
+      write: SerializedJournalRow,
+      useTimestampFromDb: Boolean,
+      previousSeqNr: Long): Statement = {
+    stmt
+      .bind(0, write.slice)
+      .bind(1, write.entityType)
+      .bind(2, write.persistenceId)
+      .bind(3, write.seqNr)
+      .bind(4, write.writerUuid)
+      .bind(5, "") // FIXME event adapter
+      .bind(6, write.serId)
+      .bind(7, write.serManifest)
+      .bindPayload(8, write.payload.get)
+
+    if (write.tags.isEmpty)
+      stmt.bindNull(9, classOf[Array[String]])
+    else
+      stmt.bind(9, write.tags.toArray)
+
+    // optional metadata
+    write.metadata match {
+      case Some(m) =>
+        stmt
+          .bind(10, m.serId)
+          .bind(11, m.serManifest)
+          .bind(12, m.payload)
+      case None =>
+        stmt
+          .bindNull(10, classOf[Integer])
+          .bindNull(11, classOf[String])
+          .bindNull(12, classOf[Array[Byte]])
+    }
+
+    if (useTimestampFromDb) {
+      if (!journalSettings.dbTimestampMonotonicIncreasing)
+        stmt
+          .bind(13, write.persistenceId)
+          .bind(14, previousSeqNr)
+    } else {
+      if (journalSettings.dbTimestampMonotonicIncreasing)
+        stmt
+          .bind(13, write.dbTimestamp)
+      else
+        stmt
+          .bind(13, write.dbTimestamp)
+          .bind(14, write.persistenceId)
+          .bind(15, previousSeqNr)
+    }
+
+    stmt
+  }
+
+  override def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = {
     val result = r2dbcExecutor
       .select(s"select highest seqNr [$persistenceId]")(
         connection =>
@@ -243,7 +270,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
     result
   }
 
-  def readLowestSequenceNr(persistenceId: String): Future[Long] = {
+  override def readLowestSequenceNr(persistenceId: String): Future[Long] = {
     val result = r2dbcExecutor
       .select(s"select lowest seqNr [$persistenceId]")(
         connection =>
@@ -277,7 +304,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
     }
   }
 
-  def deleteEventsTo(persistenceId: String, toSequenceNr: Long, resetSequenceNumber: Boolean): Future[Unit] = {
+  override def deleteEventsTo(persistenceId: String, toSequenceNr: Long, resetSequenceNumber: Boolean): Future[Unit] = {
 
     def insertDeleteMarkerStmt(deleteMarkerSeqNr: Long, connection: Connection): Statement = {
       val entityType = PersistenceId.extractEntityType(persistenceId)
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala
index 2d7a6a7d..27b273dd 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala
@@ -42,5 +42,7 @@ private[r2dbc] object YugabyteDialect extends Dialect {
 
   override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
      system: ActorSystem[_]): DurableStateDao =
-    new YugabyteDurableStateDao(settings, connectionFactory)(system.executionContext, system)
+    new YugabyteDurableStateDao(settings, connectionFactory, createJournalDao(settings, connectionFactory))(
+      system.executionContext,
+      system)
 }
diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDurableStateDao.scala
index 9f7f4051..eabfaf9f 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDurableStateDao.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDurableStateDao.scala
@@ -10,18 +10,19 @@
 import akka.persistence.r2dbc.R2dbcSettings
 import io.r2dbc.spi._
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-
 import scala.concurrent.ExecutionContext
+import akka.persistence.r2dbc.internal.JournalDao
+
 /**
  * INTERNAL API
 */
 @InternalApi
-private[r2dbc] final class YugabyteDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(
-    implicit
-    ec: ExecutionContext,
-    system: ActorSystem[_])
-    extends PostgresDurableStateDao(settings, connectionFactory) {
+private[r2dbc] final class YugabyteDurableStateDao(
+    settings: R2dbcSettings,
+    connectionFactory: ConnectionFactory,
+    journalDao: JournalDao)(implicit ec: ExecutionContext, system: ActorSystem[_])
+    extends PostgresDurableStateDao(settings, connectionFactory, journalDao) {
 
   override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[YugabyteDurableStateDao])
diff --git a/core/src/main/scala/akka/persistence/r2dbc/state/scaladsl/R2dbcDurableStateStore.scala b/core/src/main/scala/akka/persistence/r2dbc/state/scaladsl/R2dbcDurableStateStore.scala
index 4e42eae8..e70b2d12 100644
--- a/core/src/main/scala/akka/persistence/r2dbc/state/scaladsl/R2dbcDurableStateStore.scala
+++ b/core/src/main/scala/akka/persistence/r2dbc/state/scaladsl/R2dbcDurableStateStore.scala
@@ -7,12 +7,14 @@ package akka.persistence.r2dbc.state.scaladsl
 import scala.collection.immutable
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
+
 import akka.Done
 import akka.NotUsed
 import akka.actor.ExtendedActorSystem
 import akka.actor.typed.scaladsl.LoggerOps
 import akka.actor.typed.scaladsl.adapter._
 import akka.persistence.Persistence
+import akka.persistence.SerializedEvent
 import akka.persistence.query.DeletedDurableState
 import akka.persistence.query.DurableStateChange
 import akka.persistence.query.Offset
@@ -26,8 +28,12 @@ import akka.persistence.r2dbc.internal.BySliceQuery
 import akka.persistence.r2dbc.internal.ContinuousQuery
 import akka.persistence.r2dbc.internal.DurableStateDao
 import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow
+import akka.persistence.r2dbc.internal.InstantFactory
+import akka.persistence.r2dbc.internal.JournalDao
+import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
 import akka.persistence.state.scaladsl.DurableStateUpdateStore
 import akka.persistence.state.scaladsl.GetObjectResult
+import akka.persistence.typed.PersistenceId
 import akka.serialization.SerializationExtension
 import akka.serialization.Serializers
 import akka.stream.scaladsl.Source
@@ -53,6 +59,7 @@ class R2dbcDurableStateStore[A](system: ExtendedActorSystem, config: Config, cfg
   private val log = LoggerFactory.getLogger(getClass)
   private val sharedConfigPath = cfgPath.replaceAll("""\.state$""", "")
   private val settings = R2dbcSettings(system.settings.config.getConfig(sharedConfigPath))
+  private val journalSettings = R2dbcSettings(system.settings.config.getConfig(sharedConfigPath))
   log.debug("R2DBC journal starting up with dialect [{}]", settings.dialectName)
 
   private val typedSystem = system.toTyped
@@ -109,7 +116,26 @@ class R2dbcDurableStateStore[A](system: ExtendedActorSystem, config: Config, cfg
    * the existing stored `revision` + 1 isn't equal to the given `revision`. This optimistic locking check can be
    * disabled with configuration `assert-single-writer`.
    */
-  override def upsertObject(persistenceId: String, revision: Long, value: A, tag: String): Future[Done] = {
+  override def upsertObject(persistenceId: String, revision: Long, value: A, tag: String): Future[Done] =
+    upsertObject(persistenceId, revision, value, tag, changeEvent = None)
+
+  /**
+   * Insert the value if `revision` is 1, which will fail with `IllegalStateException` if there is already a stored
+   * value for the given `persistenceId`. Otherwise update the value, which will fail with `IllegalStateException` if
+   * the existing stored `revision` + 1 isn't equal to the given `revision`. This optimistic locking check can be
+   * disabled with configuration `assert-single-writer`.
+   *
+   * The `changeEvent`, if defined, is written to the event journal in the same transaction as the DurableState upsert.
+   * Same `persistenceId` is used in the journal and the `revision` is used as `sequenceNr`.
+   */
+  def upsertObject(
+      persistenceId: String,
+      revision: Long,
+      value: A,
+      tag: String,
+      changeEvent: Option[Any]): Future[Done] = {
+    // FIXME add new trait in Akka for this method. Maybe we need it for the deletes too.
+
     val valueAnyRef = value.asInstanceOf[AnyRef]
     val serialized = serialization.serialize(valueAnyRef).get
     val serializer = serialization.findSerializerFor(valueAnyRef)
@@ -125,7 +151,41 @@ class R2dbcDurableStateStore[A](system: ExtendedActorSystem, config: Config, cfg
         manifest,
         if (tag.isEmpty) Set.empty else Set(tag))
 
-    stateDao.upsertState(serializedRow, value)
+    val serializedChangedEvent: Option[SerializedJournalRow] = {
+      changeEvent.map { event =>
+        val eventAnyRef = event.asInstanceOf[AnyRef]
+        val serializedEvent = eventAnyRef match {
+          case s: SerializedEvent => s // already serialized
+          case _ =>
+            val bytes = serialization.serialize(eventAnyRef).get
+            val serializer = serialization.findSerializerFor(eventAnyRef)
+            val manifest = Serializers.manifestFor(serializer, eventAnyRef)
+            new SerializedEvent(bytes, serializer.identifier, manifest)
+        }
+
+        val entityType = PersistenceId.extractEntityType(persistenceId)
+        val slice = persistenceExt.sliceForPersistenceId(persistenceId)
+        val timestamp = if (journalSettings.useAppTimestamp) InstantFactory.now() else JournalDao.EmptyDbTimestamp
+
+        SerializedJournalRow(
+          slice,
+          entityType,
+          persistenceId,
+          revision,
+          timestamp,
+          JournalDao.EmptyDbTimestamp,
+          Some(serializedEvent.bytes),
+          serializedEvent.serializerId,
+          serializedEvent.serializerManifest,
+          "", // FIXME writerUuid, or shall we make one?
+          if (tag.isEmpty) Set.empty else Set(tag),
+          metadata = None)
+      }
+    }
+
+    stateDao.upsertState(serializedRow, value, serializedChangedEvent)
+
+    // FIXME PubSub, but not via PersistentRepr
   }
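
Usage sketch (not part of the patch above): the new `upsertObject` overload is only reachable through the R2DBC-specific `R2dbcDurableStateStore`, since the generic `DurableStateUpdateStore` trait does not have it yet (see the FIXME in the diff). `MyState`, `MyCounterChanged` and `saveWithChangeEvent` below are hypothetical names, and the store instance is assumed to have been looked up already.

// Hedged Scala sketch of calling the overload added in this patch; example types are assumptions.
import scala.concurrent.Future
import akka.Done
import akka.persistence.r2dbc.state.scaladsl.R2dbcDurableStateStore

final case class MyState(counter: Int)
final case class MyCounterChanged(delta: Int)

def saveWithChangeEvent(
    store: R2dbcDurableStateStore[MyState],
    persistenceId: String,
    revision: Long,
    state: MyState): Future[Done] =
  // The change event is serialized like the state value and handed to the DAO, which per this
  // patch writes it to the event journal together with the durable state upsert, reusing the
  // same persistenceId and using the revision as sequenceNr.
  store.upsertObject(persistenceId, revision, state, tag = "", changeEvent = Some(MyCounterChanged(1)))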