feat: Write change event of DurableState to event journal (#485)
* store in event journal in same transaction
* impl DurableStateUpdateWithChangeEventStore trait
* test
* publish the event
* lazy journalDao, dialect as factory param
* Akka 2.9.1
* override in H2JournalDao
patriknw authored Dec 19, 2023
1 parent 55aaee3 commit b5c077e
Showing 17 changed files with 751 additions and 267 deletions.
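Before the per-file diffs, a sketch of what the feature looks like from the caller's side. The R2DBC durable state store now implements DurableStateUpdateWithChangeEventStore, so an upsert or delete can carry a change event that is written to the event journal in the same transaction as the state row and then published. The example below is illustrative only: the ShoppingCart/ItemAdded types are hypothetical, and the Akka 2.9 upsertObject signature and the "akka.persistence.r2dbc.state" plugin id are assumptions, not taken from this diff.

import scala.concurrent.Future

import akka.Done
import akka.actor.ActorSystem
import akka.persistence.state.DurableStateStoreRegistry
import akka.persistence.state.scaladsl.DurableStateUpdateWithChangeEventStore

object ChangeEventUsageSketch {

  // Hypothetical domain types, for illustration only.
  final case class ShoppingCart(items: Map[String, Int])
  final case class ItemAdded(cartId: String, itemId: String, quantity: Int)

  def addItem(
      system: ActorSystem,
      cartId: String,
      revision: Long,
      current: ShoppingCart,
      itemId: String,
      quantity: Int): Future[Done] = {
    // Assumed plugin id of the R2DBC durable state store.
    val store = DurableStateStoreRegistry(system)
      .durableStateStoreFor[DurableStateUpdateWithChangeEventStore[ShoppingCart]](
        "akka.persistence.r2dbc.state")

    val updated = current.copy(items = current.items.updated(itemId, quantity))
    // The change event (ItemAdded) is stored in the event journal in the same transaction
    // as the state row, and afterwards published so queries and projections see it quickly.
    store.upsertObject(cartId, revision, updated, "", ItemAdded(cartId, itemId, quantity))
  }
}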
DurableStateCleanup.scala
@@ -15,6 +15,7 @@ import akka.actor.typed.ActorSystem
import akka.actor.typed.scaladsl.LoggerOps
import akka.annotation.ApiMayChange
import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.persistence.r2dbc.ConnectionFactoryProvider
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.DurableStateDao
@@ -65,11 +66,17 @@ final class DurableStateCleanup(systemProvider: ClassicActorSystemProvider, conf
*/
def deleteState(persistenceId: String, resetRevisionNumber: Boolean): Future[Done] = {
if (resetRevisionNumber)
stateDao.deleteState(persistenceId, revision = 0L) // hard delete without revision check
stateDao
.deleteState(persistenceId, revision = 0L, changeEvent = None) // hard delete without revision check
.map(_ => Done)(ExecutionContexts.parasitic)
else {
stateDao.readState(persistenceId).flatMap {
case None => Future.successful(Done) // already deleted
case Some(s) => stateDao.deleteState(persistenceId, s.revision + 1)
case None =>
Future.successful(Done) // already deleted
case Some(s) =>
stateDao
.deleteState(persistenceId, s.revision + 1, changeEvent = None)
.map(_ => Done)(ExecutionContexts.parasitic)
}
}
}
DurableStateDao.scala
@@ -8,10 +8,12 @@ import akka.Done
import akka.NotUsed
import akka.annotation.InternalApi
import akka.stream.scaladsl.Source

import java.time.Instant

import scala.concurrent.Future

import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow

/**
* INTERNAL API
*/
@@ -48,9 +50,15 @@ private[r2dbc] trait DurableStateDao extends BySliceQuery.Dao[DurableStateDao.Se

def readState(persistenceId: String): Future[Option[SerializedStateRow]]

def upsertState(state: SerializedStateRow, value: Any): Future[Done]
def upsertState(
state: SerializedStateRow,
value: Any,
changeEvent: Option[SerializedJournalRow]): Future[Option[Instant]]

def deleteState(persistenceId: String, revision: Long): Future[Done]
def deleteState(
persistenceId: String,
revision: Long,
changeEvent: Option[SerializedJournalRow]): Future[Option[Instant]]

def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed]

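Note that upsertState and deleteState now return Future[Option[Instant]] instead of Future[Done]: when a change event was written, the returned value is its db_timestamp, which the caller needs to build the envelope that gets published. A minimal sketch of a caller adapting to that, with a hypothetical publish callback; only the upsertState signature comes from the trait above, and the DAO is internal API, hence the package declaration.

package akka.persistence.r2dbc.internal

import java.time.Instant

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

import akka.Done
import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow
import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow

object ChangeEventCallerSketch {

  // `publish` is a hypothetical callback standing in for the pub-sub publishing step.
  def upsertAndPublish(
      stateDao: DurableStateDao,
      state: SerializedStateRow,
      value: Any,
      changeEvent: Option[SerializedJournalRow],
      publish: (SerializedJournalRow, Instant) => Unit)(implicit ec: ExecutionContext): Future[Done] =
    stateDao.upsertState(state, value, changeEvent).map { changeEventTimestamp =>
      for {
        event <- changeEvent
        timestamp <- changeEventTimestamp // db_timestamp of the stored change event
      } publish(event, timestamp)
      Done
    }
}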
JournalDao.scala
@@ -5,10 +5,13 @@
package akka.persistence.r2dbc.internal

import akka.annotation.InternalApi

import java.time.Instant

import scala.concurrent.Future

import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow
import io.r2dbc.spi.Connection

/**
* INTERNAL API
*/
@@ -56,6 +59,9 @@ private[r2dbc] trait JournalDao {
* a select (in same transaction).
*/
def writeEvents(events: Seq[JournalDao.SerializedJournalRow]): Future[Instant]

def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant]

def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long]

def readLowestSequenceNr(persistenceId: String): Future[Long]
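The new writeEventInTx takes an already open Connection, which is what lets the durable state DAO put the change event into the same database transaction as the state upsert. A rough sketch of that calling pattern; the updateState statement factory is a placeholder, and R2dbcExecutor.withConnection plus the exact orchestration inside PostgresDurableStateDao are assumptions not shown in this diff.

package akka.persistence.r2dbc.internal

import java.time.Instant

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

import io.r2dbc.spi.Connection
import io.r2dbc.spi.Statement

import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow

object WriteChangeEventSketch {

  // Sketch only: `updateState` stands in for the statement that upserts the durable state row,
  // and withConnection is assumed to run the whole function in a single transaction.
  def upsertWithChangeEvent(
      r2dbcExecutor: R2dbcExecutor,
      journalDao: JournalDao,
      updateState: Connection => Statement,
      changeEvent: Option[SerializedJournalRow])(implicit ec: ExecutionContext): Future[Option[Instant]] =
    r2dbcExecutor.withConnection("upsert state") { connection =>
      for {
        _ <- R2dbcExecutor.updateOneInTx(updateState(connection))
        timestamp <- changeEvent match {
          case Some(event) => journalDao.writeEventInTx(event, connection).map(Some(_))
          case None        => Future.successful(None)
        }
      } yield timestamp // the state row and the change event commit or roll back together
    }
}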
66 changes: 42 additions & 24 deletions core/src/main/scala/akka/persistence/r2dbc/internal/PubSub.scala
@@ -91,29 +91,7 @@ import org.slf4j.LoggerFactory
}

def publish(pr: PersistentRepr, timestamp: Instant): Unit = {

val n = throughputCounter.incrementAndGet()
if (n % throughputSampler == 0) {
val ewma = throughput
val durationMillis = (System.nanoTime() - ewma.nanoTime) / 1000 / 1000
if (durationMillis >= throughputCollectIntervalMillis) {
// doesn't have to be exact so "missed" or duplicate concurrent calls don't matter
throughputCounter.set(0L)
val rps = n * 1000.0 / durationMillis
val newEwma = ewma :+ rps
throughput = newEwma
if (ewma.value < throughputThreshold && newEwma.value >= throughputThreshold) {
log.info("Disabled publishing of events. Throughput greater than [{}] events/s", throughputThreshold)
} else if (ewma.value >= throughputThreshold && newEwma.value < throughputThreshold) {
log.info("Enabled publishing of events. Throughput less than [{}] events/s", throughputThreshold)
} else {
log.debug(
"Publishing of events is {}. Throughput is [{}] events/s",
if (newEwma.value < throughputThreshold) "enabled" else "disabled",
newEwma.value)
}
}
}
updateThroughput()

if (throughput.value < throughputThreshold) {
val pid = pr.persistenceId
@@ -143,7 +121,47 @@
filtered,
source = EnvelopeOrigin.SourcePubSub,
tags)
eventTopic(entityType, slice) ! Topic.Publish(envelope)

publishToTopic(envelope)
}
}

def publish(envelope: EventEnvelope[Any]): Unit = {
updateThroughput()

if (throughput.value < throughputThreshold)
publishToTopic(envelope)
}

private def publishToTopic(envelope: EventEnvelope[Any]): Unit = {
val entityType = PersistenceId.extractEntityType(envelope.persistenceId)
val slice = persistenceExt.sliceForPersistenceId(envelope.persistenceId)

eventTopic(entityType, slice) ! Topic.Publish(envelope)
}

private def updateThroughput(): Unit = {
val n = throughputCounter.incrementAndGet()
if (n % throughputSampler == 0) {
val ewma = throughput
val durationMillis = (System.nanoTime() - ewma.nanoTime) / 1000 / 1000
if (durationMillis >= throughputCollectIntervalMillis) {
// doesn't have to be exact so "missed" or duplicate concurrent calls don't matter
throughputCounter.set(0L)
val rps = n * 1000.0 / durationMillis
val newEwma = ewma :+ rps
throughput = newEwma
if (ewma.value < throughputThreshold && newEwma.value >= throughputThreshold) {
log.info("Disabled publishing of events. Throughput greater than [{}] events/s", throughputThreshold)
} else if (ewma.value >= throughputThreshold && newEwma.value < throughputThreshold) {
log.info("Enabled publishing of events. Throughput less than [{}] events/s", throughputThreshold)
} else {
log.debug(
"Publishing of events is {}. Throughput is [{}] events/s",
if (newEwma.value < throughputThreshold) "enabled" else "disabled",
newEwma.value)
}
}
}
}
}
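The extracted updateThroughput keeps the existing behaviour, now shared by both publish overloads: publishing is only attempted while the exponentially weighted moving average of events per second stays below the configured threshold, so pub-sub switches itself off under sustained load and back on when load drops. Below is a self-contained simplification of that gate; it is not the plugin's actual EWMA implementation, just the same idea.

import java.util.concurrent.atomic.AtomicLong

// Simplified stand-in for the throughput gate: sample every `sampleEvery` calls, update an
// exponentially weighted moving average of events/second, and allow publishing only while
// the average stays below the threshold.
final class ThroughputGate(threshold: Double, sampleEvery: Int, alpha: Double = 0.5) {
  private val counter = new AtomicLong
  @volatile private var ewma = 0.0
  @volatile private var lastNanos = System.nanoTime()

  def shouldPublish(): Boolean = {
    val n = counter.incrementAndGet()
    if (n % sampleEvery == 0) {
      val now = System.nanoTime()
      val durationMillis = (now - lastNanos) / 1000 / 1000
      if (durationMillis > 0) {
        // doesn't have to be exact, concurrent samples may overlap slightly
        counter.set(0L)
        lastNanos = now
        val rps = n * 1000.0 / durationMillis
        ewma = alpha * rps + (1 - alpha) * ewma
      }
    }
    ewma < threshold
  }
}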
R2dbcExecutor.scala
@@ -53,6 +53,13 @@ import reactor.core.publisher.Mono
result.getRowsUpdated.asFuture().map(_.longValue())(ExecutionContexts.parasitic)
}

def updateOneReturningInTx[A](stmt: Statement, mapRow: Row => A)(implicit ec: ExecutionContext): Future[A] =
stmt.execute().asFuture().flatMap { result =>
Mono
.from[A](result.map((row, _) => mapRow(row)))
.asFuture()
}

def updateBatchInTx(stmt: Statement)(implicit ec: ExecutionContext): Future[Long] = {
val consumer: BiConsumer[Long, java.lang.Long] = (acc, elem) => acc + elem.longValue()
Flux
@@ -195,12 +202,7 @@ class R2dbcExecutor(
def updateOneReturning[A](
logPrefix: String)(statementFactory: Connection => Statement, mapRow: Row => A): Future[A] = {
withAutoCommitConnection(logPrefix) { connection =>
val stmt = statementFactory(connection)
stmt.execute().asFuture().flatMap { result =>
Mono
.from[A](result.map((row, _) => mapRow(row)))
.asFuture()
}
updateOneReturningInTx(statementFactory(connection), mapRow)
}
}

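updateOneReturningInTx is the body that was previously inlined in updateOneReturning, extracted so it can also run on a connection that is already part of a larger transaction, for example to read back the db_timestamp generated for the change event. A usage sketch against an open connection; the table, columns and SQL are illustrative only, not the plugin's schema.

package akka.persistence.r2dbc.internal

import java.time.Instant

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

import io.r2dbc.spi.Connection

object ReturningInTxSketch {

  // Insert a row inside an existing transaction and read back the database-generated timestamp
  // in one round trip via the new updateOneReturningInTx helper.
  def insertReturningTimestamp(connection: Connection, persistenceId: String)(
      implicit ec: ExecutionContext): Future[Instant] = {
    val stmt = connection
      .createStatement(
        "INSERT INTO example_state (persistence_id, db_timestamp) " +
          "VALUES ($1, CURRENT_TIMESTAMP) RETURNING db_timestamp")
      .bind(0, persistenceId)
    R2dbcExecutor.updateOneReturningInTx(stmt, row => row.get("db_timestamp", classOf[Instant]))
  }
}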
H2Dialect.scala
@@ -94,7 +94,7 @@ private[r2dbc] object H2Dialect extends Dialect {

override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): DurableStateDao =
new H2DurableStateDao(settings, connectionFactory)(ecForDaos(system, settings), system)
new H2DurableStateDao(settings, connectionFactory, this)(ecForDaos(system, settings), system)

private def ecForDaos(system: ActorSystem[_], settings: R2dbcSettings): ExecutionContext = {
// H2 R2DBC driver blocks in surprising places (Mono.toFuture in stmt.execute().asFuture())
H2DurableStateDao.scala
@@ -4,26 +4,29 @@

package akka.persistence.r2dbc.internal.h2

import akka.actor.typed.ActorSystem
import akka.annotation.InternalApi
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration

import io.r2dbc.spi.ConnectionFactory
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration
import akka.actor.typed.ActorSystem
import akka.annotation.InternalApi
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.Dialect
import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao

/**
* INTERNAL API
*/
@InternalApi
private[r2dbc] final class H2DurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
ec: ExecutionContext,
system: ActorSystem[_])
extends PostgresDurableStateDao(settings, connectionFactory) {
private[r2dbc] final class H2DurableStateDao(
settings: R2dbcSettings,
connectionFactory: ConnectionFactory,
dialect: Dialect)(implicit ec: ExecutionContext, system: ActorSystem[_])
extends PostgresDurableStateDao(settings, connectionFactory, dialect) {

override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao])

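Passing this as the dialect constructor parameter matches the "lazy journalDao, dialect as factory param" bullet in the commit message: the base PostgresDurableStateDao can create the journal DAO through the dialect, so H2 gets H2JournalDao with its writeEventInTx override rather than the Postgres one. A hedged sketch of that pattern; the createJournalDao factory is an assumption, mirroring the createDurableStateDao factory shown in the H2Dialect diff above.

package akka.persistence.r2dbc.internal

import akka.actor.typed.ActorSystem
import io.r2dbc.spi.ConnectionFactory

import akka.persistence.r2dbc.R2dbcSettings

// The state DAO keeps the dialect and only creates the matching journal DAO lazily, when a
// change event actually has to be written. `createJournalDao` is assumed, not shown in this diff.
final class LazyJournalDaoSketch(
    settings: R2dbcSettings,
    connectionFactory: ConnectionFactory,
    dialect: Dialect)(implicit system: ActorSystem[_]) {

  lazy val changeEventJournalDao: JournalDao =
    dialect.createJournalDao(settings, connectionFactory)
}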
H2JournalDao.scala
@@ -16,11 +16,15 @@ import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.Statement
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import java.time.Instant

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

import io.r2dbc.spi.Connection

import akka.persistence.r2dbc.internal.R2dbcExecutor

/**
* INTERNAL API
*/
@@ -35,7 +39,7 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact
require(journalSettings.useAppTimestamp)
require(journalSettings.dbTimestampMonotonicIncreasing)

val insertSql = sql"INSERT INTO $journalTable " +
private val insertSql = sql"INSERT INTO $journalTable " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"

@@ -54,66 +58,74 @@

// it's always the same persistenceId for all events
val persistenceId = events.head.persistenceId
val previousSeqNr = events.head.seqNr - 1

def bind(stmt: Statement, write: SerializedJournalRow): Statement = {
stmt
.bind(0, write.slice)
.bind(1, write.entityType)
.bind(2, write.persistenceId)
.bind(3, write.seqNr)
.bind(4, write.writerUuid)
.bind(5, "") // FIXME event adapter
.bind(6, write.serId)
.bind(7, write.serManifest)
.bindPayload(8, write.payload.get)

if (write.tags.isEmpty)
stmt.bindNull(9, classOf[Array[String]])
else
stmt.bind(9, write.tags.toArray)

// optional metadata
write.metadata match {
case Some(m) =>
stmt
.bind(10, m.serId)
.bind(11, m.serManifest)
.bind(12, m.payload)
case None =>
stmt
.bindNull(10, classOf[Integer])
.bindNull(11, classOf[String])
.bindNull(12, classOf[Array[Byte]])

val totalEvents = events.size
val result =
if (totalEvents == 1) {
r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
bindInsertStatement(connection.createStatement(insertSql), events.head))
} else {
r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
stmt.add()
bindInsertStatement(stmt, write)
})
}

stmt.bind(13, write.dbTimestamp)
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId)
}
result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
}

stmt
}
override def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant] = {
val persistenceId = event.persistenceId

val totalEvents = events.size
if (totalEvents == 1) {
val result = r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
bind(connection.createStatement(insertSql), events.head))
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId)
}
result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
} else {
val result = r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
stmt.add()
bind(stmt, write)
})
if (log.isDebugEnabled()) {
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId)
}
val stmt = bindInsertStatement(connection.createStatement(insertSql), event)
val result = R2dbcExecutor.updateOneInTx(stmt)

if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
}
result.map(_ => events.head.dbTimestamp)(ExecutionContexts.parasitic)
result.map(_ => event.dbTimestamp)(ExecutionContexts.parasitic)
}

private def bindInsertStatement(stmt: Statement, write: SerializedJournalRow): Statement = {
stmt
.bind(0, write.slice)
.bind(1, write.entityType)
.bind(2, write.persistenceId)
.bind(3, write.seqNr)
.bind(4, write.writerUuid)
.bind(5, "") // FIXME event adapter
.bind(6, write.serId)
.bind(7, write.serManifest)
.bindPayload(8, write.payload.get)

if (write.tags.isEmpty)
stmt.bindNull(9, classOf[Array[String]])
else
stmt.bind(9, write.tags.toArray)

// optional metadata
write.metadata match {
case Some(m) =>
stmt
.bind(10, m.serId)
.bind(11, m.serManifest)
.bind(12, m.payload)
case None =>
stmt
.bindNull(10, classOf[Integer])
.bindNull(11, classOf[String])
.bindNull(12, classOf[Array[Byte]])
}

stmt.bind(13, write.dbTimestamp)

stmt
}

}