Skip to content

Commit

Permalink
feat: Write change event of DurableState to event journal
Browse files Browse the repository at this point in the history
  • Loading branch information
patriknw committed Dec 12, 2023
1 parent a405616 commit e67c6fd
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import akka.Done
import akka.NotUsed
import akka.annotation.InternalApi
import akka.stream.scaladsl.Source

import java.time.Instant

import scala.concurrent.Future

import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow

/**
* INTERNAL API
*/
Expand Down Expand Up @@ -48,7 +50,7 @@ private[r2dbc] trait DurableStateDao extends BySliceQuery.Dao[DurableStateDao.Se

def readState(persistenceId: String): Future[Option[SerializedStateRow]]

def upsertState(state: SerializedStateRow, value: Any): Future[Done]
def upsertState(state: SerializedStateRow, value: Any, changeEvent: Option[SerializedJournalRow]): Future[Done]

def deleteState(persistenceId: String, revision: Long): Future[Done]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
package akka.persistence.r2dbc.internal

import akka.annotation.InternalApi

import java.time.Instant

import scala.concurrent.Future

import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
import io.r2dbc.spi.Connection

/**
* INTERNAL API
*/
Expand Down Expand Up @@ -56,6 +59,9 @@ private[r2dbc] trait JournalDao {
* a select (in same transaction).
*/
def writeEvents(events: Seq[JournalDao.SerializedJournalRow]): Future[Instant]

def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant]

def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long]

def readLowestSequenceNr(persistenceId: String): Future[Long]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ private[r2dbc] object H2Dialect extends Dialect {

override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): DurableStateDao =
new H2DurableStateDao(settings, connectionFactory)(ecForDaos(system, settings), system)
new H2DurableStateDao(settings, connectionFactory, createJournalDao(settings, connectionFactory))(
ecForDaos(system, settings),
system)

private def ecForDaos(system: ActorSystem[_], settings: R2dbcSettings): ExecutionContext = {
// H2 R2DBC driver blocks in surprising places (Mono.toFuture in stmt.execute().asFuture())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@ import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao
import io.r2dbc.spi.ConnectionFactory
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration

import akka.persistence.r2dbc.internal.JournalDao

/**
* INTERNAL API
*/
@InternalApi
private[r2dbc] final class H2DurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
ec: ExecutionContext,
system: ActorSystem[_])
extends PostgresDurableStateDao(settings, connectionFactory) {
private[r2dbc] final class H2DurableStateDao(
settings: R2dbcSettings,
connectionFactory: ConnectionFactory,
journalDao: JournalDao)(implicit ec: ExecutionContext, system: ActorSystem[_])
extends PostgresDurableStateDao(settings, connectionFactory, journalDao) {

override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,7 @@ private[r2dbc] object PostgresDialect extends Dialect {

override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): DurableStateDao =
new PostgresDurableStateDao(settings, connectionFactory)(system.executionContext, system)
new PostgresDurableStateDao(settings, connectionFactory, createJournalDao(settings, connectionFactory))(
system.executionContext,
system)
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ import io.r2dbc.spi.Row
import io.r2dbc.spi.Statement
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import java.lang
import java.time.Instant
import java.util

import scala.collection.immutable
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal

import akka.persistence.r2dbc.internal.JournalDao
import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow

/**
* INTERNAL API
*/
Expand All @@ -70,9 +73,10 @@ private[r2dbc] object PostgresDurableStateDao {
* INTERNAL API
*/
@InternalApi
private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
ec: ExecutionContext,
system: ActorSystem[_])
private[r2dbc] class PostgresDurableStateDao(
settings: R2dbcSettings,
connectionFactory: ConnectionFactory,
journalDao: JournalDao)(implicit ec: ExecutionContext, system: ActorSystem[_])
extends DurableStateDao {
import DurableStateDao._
import PostgresDurableStateDao._
Expand Down Expand Up @@ -264,7 +268,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
LIMIT ?"""
}

def readState(persistenceId: String): Future[Option[SerializedStateRow]] = {
override def readState(persistenceId: String): Future[Option[SerializedStateRow]] = {
val entityType = PersistenceId.extractEntityType(persistenceId)
r2dbcExecutor.selectOne(s"select [$persistenceId]")(
connection =>
Expand Down Expand Up @@ -293,7 +297,25 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
Option(rowPayload)
}

def upsertState(state: SerializedStateRow, value: Any): Future[Done] = {
private def writeChangeEventAndCallChangeHander(
connection: Connection,
updatedRows: Long,
entityType: String,
change: DurableStateChange[Any],
changeEvent: Option[SerializedJournalRow]): Future[Done] = {
if (updatedRows == 1)
for {
_ <- changeEvent.map(journalDao.writeEventInTx(_, connection)).getOrElse(FutureDone)
_ <- changeHandlers.get(entityType).map(processChange(_, connection, change)).getOrElse(FutureDone)
} yield Done
else
FutureDone
}

override def upsertState(
state: SerializedStateRow,
value: Any,
changeEvent: Option[SerializedJournalRow]): Future[Done] = {
require(state.revision > 0)

def bindTags(stmt: Statement, i: Int): Statement = {
Expand Down Expand Up @@ -360,17 +382,15 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
s"Insert failed: durable state for persistence id [${state.persistenceId}] already exists"))
}

changeHandlers.get(entityType) match {
case None =>
recoverDataIntegrityViolation(r2dbcExecutor.updateOne(s"insert [${state.persistenceId}]")(insertStatement))
case Some(handler) =>
r2dbcExecutor.withConnection(s"insert [${state.persistenceId}] with change handler") { connection =>
for {
updatedRows <- recoverDataIntegrityViolation(R2dbcExecutor.updateOneInTx(insertStatement(connection)))
_ <- if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
} yield updatedRows
}
}
if (!changeHandlers.contains(entityType) && changeEvent.isEmpty)
recoverDataIntegrityViolation(r2dbcExecutor.updateOne(s"insert [${state.persistenceId}]")(insertStatement))
else
r2dbcExecutor.withConnection(s"insert [${state.persistenceId}]") { connection =>
for {
updatedRows <- recoverDataIntegrityViolation(R2dbcExecutor.updateOneInTx(insertStatement(connection)))
_ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
} yield updatedRows
}
} else {
val previousRevision = state.revision - 1

Expand Down Expand Up @@ -405,17 +425,15 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
}
}

changeHandlers.get(entityType) match {
case None =>
r2dbcExecutor.updateOne(s"update [${state.persistenceId}]")(updateStatement)
case Some(handler) =>
r2dbcExecutor.withConnection(s"update [${state.persistenceId}] with change handler") { connection =>
for {
updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection))
_ <- if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
} yield updatedRows
}
}
if (!changeHandlers.contains(entityType) && changeEvent.isEmpty)
r2dbcExecutor.updateOne(s"update [${state.persistenceId}]")(updateStatement)
else
r2dbcExecutor.withConnection(s"update [${state.persistenceId}]") { connection =>
for {
updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection))
_ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
} yield updatedRows
}
}
}

Expand Down Expand Up @@ -451,7 +469,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
}
}

def deleteState(persistenceId: String, revision: Long): Future[Done] = {
override def deleteState(persistenceId: String, revision: Long): Future[Done] = {
if (revision == 0) {
hardDeleteState(persistenceId)
} else {
Expand Down Expand Up @@ -490,10 +508,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
for {
updatedRows <- recoverDataIntegrityViolation(
R2dbcExecutor.updateOneInTx(insertDeleteMarkerStatement(connection)))
_ <- changeHandler match {
case None => FutureDone
case Some(handler) => if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
}
_ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
} yield updatedRows
}

Expand Down Expand Up @@ -537,10 +552,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
r2dbcExecutor.withConnection(s"delete [$persistenceId]$changeHandlerHint") { connection =>
for {
updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection))
_ <- changeHandler match {
case None => FutureDone
case Some(handler) => if (updatedRows == 1) processChange(handler, connection, change) else FutureDone
}
_ <- writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
} yield updatedRows
}
}
Expand Down Expand Up @@ -572,14 +584,9 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
connection
.createStatement(hardDeleteStateSql(entityType))
.bind(0, persistenceId))
_ <- changeHandler match {
case None => FutureDone
case Some(handler) =>
if (updatedRows == 1) {
val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli)
processChange(handler, connection, change)
} else
FutureDone
_ <- {
val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli)
writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None)
}
} yield updatedRows
}
Expand Down Expand Up @@ -669,7 +676,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed)
}

def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = {
override def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = {
if (settings.durableStateTableByEntityTypeWithSchema.isEmpty)
persistenceIds(afterId, limit, settings.durableStateTableWithSchema)
else {
Expand Down Expand Up @@ -699,7 +706,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
}
}

def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed] = {
override def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed] = {
val result = readPersistenceIds(afterId, limit, table)

Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed)
Expand Down Expand Up @@ -729,7 +736,7 @@ private[r2dbc] class PostgresDurableStateDao(settings: R2dbcSettings, connection
result
}

def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = {
override def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = {
val table = settings.getDurableStateTableWithSchema(entityType)
val likeStmtPostfix = PersistenceId.DefaultSeparator + "%"
val result = r2dbcExecutor.select(s"select persistenceIds by entity type")(
Expand Down
Loading

0 comments on commit e67c6fd

Please sign in to comment.