Skip to content

Commit

Permalink
perf: Avoid asyncReadHighestSequenceNr query (#583)
Browse files Browse the repository at this point in the history
* perf: Avoid asyncReadHighestSequenceNr query
* AsyncReplay added in Akka akka/akka#32434
* Akka 2.9.4
  • Loading branch information
patriknw authored Jun 25, 2024
1 parent 12cfe90 commit cc4aa83
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 45 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ inThisBuild(
resolvers += "Akka library repository".at("https://repo.akka.io/maven"),
// add snapshot repo when Akka version overridden
resolvers ++=
// Add the snapshot repo only when the configured Akka version is a SNAPSHOT build.
// (Stripped diff left both the old System.getProperty condition and this new one;
// the commit replaced the former with the SNAPSHOT-suffix check.)
(if (Dependencies.AkkaVersion.endsWith("-SNAPSHOT"))
Seq("Akka library snapshot repository".at("https://repo.akka.io/snapshots"))
else Seq.empty)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ private[r2dbc] trait QueryDao extends BySliceQuery.Dao[SerializedJournalRow] {
/**
 * Events for one persistence id, in seq_nr order, within the given sequence number range.
 *
 * @param includeDeleted when true, rows marked as deleted are also emitted
 *                       (implementations return them without a payload) instead
 *                       of being filtered out
 */
def eventsByPersistenceId(
    persistenceId: String,
    fromSequenceNr: Long,
    toSequenceNr: Long,
    includeDeleted: Boolean): Source[SerializedJournalRow, NotUsed]

def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,24 @@ private[r2dbc] class PostgresQueryDao(executorProvider: R2dbcExecutorProvider) e
/**
 * SQL for non-deleted events of one persistence id, in seq_nr order, limited by a
 * bind parameter. Cached per slice.
 *
 * Note: `persistence_id` is intentionally not in the SELECT list; the row mapper
 * already knows the persistence id from its parameter.
 * (The stripped diff had left both the old SELECT line, which still included
 * persistence_id, and this new one inside the SQL string.)
 */
protected def selectEventsSql(slice: Int): String =
  sqlCache.get(slice, "selectEventsSql") {
    sql"""
    SELECT slice, entity_type, seq_nr, db_timestamp, CURRENT_TIMESTAMP AS read_db_timestamp, event_ser_id, event_ser_manifest, event_payload, writer, adapter_manifest, meta_ser_id, meta_ser_manifest, meta_payload, tags
    from ${journalTable(slice)}
    WHERE persistence_id = ? AND seq_nr >= ? AND seq_nr <= ?
    AND deleted = false
    ORDER BY seq_nr
    LIMIT ?"""
  }

/**
 * Same query as `selectEventsSql` but also returns rows marked as deleted:
 * it additionally selects the `deleted` column and omits the `deleted = false`
 * filter. Cached per slice.
 */
protected def selectEventsIncludeDeletedSql(slice: Int): String =
sqlCache.get(slice, "selectEventsIncludeDeletedSql") {
sql"""
SELECT slice, entity_type, seq_nr, db_timestamp, CURRENT_TIMESTAMP AS read_db_timestamp, event_ser_id, event_ser_manifest, event_payload, writer, adapter_manifest, meta_ser_id, meta_ser_manifest, meta_payload, tags, deleted
from ${journalTable(slice)}
WHERE persistence_id = ? AND seq_nr >= ? AND seq_nr <= ?
ORDER BY seq_nr
LIMIT ?"""
}

protected def bindSelectEventsSql(
stmt: Statement,
persistenceId: String,
Expand Down Expand Up @@ -378,28 +388,47 @@ private[r2dbc] class PostgresQueryDao(executorProvider: R2dbcExecutorProvider) e
override def eventsByPersistenceId(
persistenceId: String,
fromSequenceNr: Long,
toSequenceNr: Long): Source[SerializedJournalRow, NotUsed] = {
toSequenceNr: Long,
includeDeleted: Boolean): Source[SerializedJournalRow, NotUsed] = {
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
val executor = executorProvider.executorFor(slice)
val result = executor.select(s"select eventsByPersistenceId [$persistenceId]")(
connection => {
val stmt = connection.createStatement(selectEventsSql(slice))
val selectSql = if (includeDeleted) selectEventsIncludeDeletedSql(slice) else selectEventsSql(slice)
val stmt = connection.createStatement(selectSql)
bindSelectEventsSql(stmt, persistenceId, fromSequenceNr, toSequenceNr, settings.querySettings.bufferSize)
},
row =>
SerializedJournalRow(
slice = row.get[Integer]("slice", classOf[Integer]),
entityType = row.get("entity_type", classOf[String]),
persistenceId = row.get("persistence_id", classOf[String]),
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = Some(row.getPayload("event_payload")),
serId = row.get[Integer]("event_ser_id", classOf[Integer]),
serManifest = row.get("event_ser_manifest", classOf[String]),
writerUuid = row.get("writer", classOf[String]),
tags = row.getTags("tags"),
metadata = readMetadata(row)))
if (includeDeleted && row.get[java.lang.Boolean]("deleted", classOf[java.lang.Boolean])) {
// deleted row
SerializedJournalRow(
slice = row.get[Integer]("slice", classOf[Integer]),
entityType = row.get("entity_type", classOf[String]),
persistenceId = persistenceId,
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = None,
serId = 0,
serManifest = "",
writerUuid = "",
tags = Set.empty,
metadata = None)
} else {
SerializedJournalRow(
slice = row.get[Integer]("slice", classOf[Integer]),
entityType = row.get("entity_type", classOf[String]),
persistenceId = persistenceId,
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = Some(row.getPayload("event_payload")),
serId = row.get[Integer]("event_ser_id", classOf[Integer]),
serManifest = row.get("event_ser_manifest", classOf[String]),
writerUuid = row.get("writer", classOf[String]),
tags = row.getTags("tags"),
metadata = readMetadata(row))
})

if (log.isDebugEnabled)
result.foreach(rows => log.debug("Read [{}] events for persistenceId [{}]", rows.size, persistenceId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ private[r2dbc] class SqlServerQueryDao(executorProvider: R2dbcExecutorProvider)
ORDER BY seq_nr"""
}

/**
 * SQL Server dialect of the include-deleted events query: uses TOP(@limit)
 * instead of LIMIT, SYSUTCDATETIME() for the read timestamp, and named bind
 * parameters. Cached per slice.
 *
 * NOTE(review): unlike the Postgres variant, this SELECT list also includes
 * persistence_id — confirm the SQL Server row mapper expects that column.
 */
override protected def selectEventsIncludeDeletedSql(slice: Int): String =
sqlCache.get(slice, "selectEventsIncludeDeletedSql") {
sql"""
SELECT TOP(@limit) slice, entity_type, persistence_id, seq_nr, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, event_ser_id, event_ser_manifest, event_payload, writer, adapter_manifest, meta_ser_id, meta_ser_manifest, meta_payload, tags, deleted
from ${journalTable(slice)}
WHERE persistence_id = @persistenceId AND seq_nr >= @from AND seq_nr <= @to
ORDER BY seq_nr"""
}

/**
* custom binding because the first param in the query is @limit (or '0' when using positional binding)
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import akka.stream.scaladsl.Sink
import com.typesafe.config.Config
import org.slf4j.LoggerFactory

import akka.persistence.journal.AsyncReplay
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand Down Expand Up @@ -72,13 +73,16 @@ private[r2dbc] object R2dbcJournal {
}
reprWithMeta
}

val FutureDone: Future[Done] = Future.successful(Done)
}

/**
* INTERNAL API
*/
@InternalApi
private[r2dbc] final class R2dbcJournal(config: Config, cfgPath: String) extends AsyncWriteJournal {
private[r2dbc] final class R2dbcJournal(config: Config, cfgPath: String) extends AsyncWriteJournal with AsyncReplay {
import R2dbcJournal.FutureDone
import R2dbcJournal.WriteFinished
import R2dbcJournal.deserializeRow

Expand Down Expand Up @@ -215,30 +219,71 @@ private[r2dbc] final class R2dbcJournal(config: Config, cfgPath: String) extends
journalDao.deleteEventsTo(persistenceId, toSequenceNr, resetSequenceNumber = false)
}

// NOTE(review): this is the pre-commit implementation that the commit removes.
// It is superseded by replayMessages (AsyncReplay), which replays events and
// determines the highest sequence number in a single query.
override def asyncReplayMessages(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long)(
recoveryCallback: PersistentRepr => Unit): Future[Unit] = {
log.debug("asyncReplayMessages persistenceId [{}], fromSequenceNr [{}]", persistenceId, fromSequenceNr)
// cap the range by `max` unless unbounded (guard against Long overflow of fromSequenceNr + max)
val effectiveToSequenceNr =
if (max == Long.MaxValue) toSequenceNr
else math.min(toSequenceNr, fromSequenceNr + max - 1)
query
.internalCurrentEventsByPersistenceId(persistenceId, fromSequenceNr, effectiveToSequenceNr)
.runWith(Sink.foreach { row =>
val repr = deserializeRow(serialization, row)
recoveryCallback(repr)
})
.map(_ => ())
}

/**
 * `AsyncReplay` implementation: replays events for `persistenceId` and returns the
 * highest sequence number in the same pass. In the common full-recovery case
 * (`toSequenceNr == Long.MaxValue && max == Long.MaxValue`) this avoids the
 * separate highest-sequence-number query, reading deleted rows (empty payload)
 * only to pick up their seq_nr.
 *
 * Fix: the stripped diff had left three residual lines of the removed
 * `asyncReadHighestSequenceNr` (its header, its log line, and its trailing
 * `pendingWrite.flatMap(...)` expression). The trailing expression would have
 * become the method's return value, re-running the highest-seq-nr query and
 * defeating the optimization; all three residual lines are dropped here.
 */
override def replayMessages(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long)(
    recoveryCallback: PersistentRepr => Unit): Future[Long] = {
  log.debug("replayMessages [{}] [{}]", persistenceId, fromSequenceNr)
  // Defer replay until any in-flight write for this persistence id has completed.
  val pendingWrite = Option(writesInProgress.get(persistenceId)) match {
    case Some(f) =>
      log.debug("Write in progress for [{}], deferring replayMessages until write completed", persistenceId)
      // we only want to make write - replay sequential, not fail if previous write failed
      f.recover { case _ => Done }(ExecutionContexts.parasitic)
    case None => FutureDone
  }
  pendingWrite.flatMap { _ =>
    if (toSequenceNr == Long.MaxValue && max == Long.MaxValue) {
      // this is the normal case, highest sequence number from last event
      query
        .internalCurrentEventsByPersistenceId(
          persistenceId,
          fromSequenceNr,
          toSequenceNr,
          readHighestSequenceNr = false,
          includeDeleted = true)
        .runWith(Sink.fold(0L) { (_, item) =>
          // payload is empty for deleted item
          if (item.payload.isDefined) {
            val repr = deserializeRow(serialization, item)
            recoveryCallback(repr)
          }
          item.seqNr
        })
    } else if (toSequenceNr <= 0) {
      // no replay
      journalDao.readHighestSequenceNr(persistenceId, fromSequenceNr)
    } else {
      // replay to custom sequence number

      val highestSeqNr = journalDao.readHighestSequenceNr(persistenceId, fromSequenceNr)

      // cap the range by `max` unless unbounded (avoid Long overflow of fromSequenceNr + max)
      val effectiveToSequenceNr =
        if (max == Long.MaxValue) toSequenceNr
        else math.min(toSequenceNr, fromSequenceNr + max - 1)

      query
        .internalCurrentEventsByPersistenceId(
          persistenceId,
          fromSequenceNr,
          effectiveToSequenceNr,
          readHighestSequenceNr = false,
          includeDeleted = false)
        .runWith(Sink.foreach { item =>
          val repr = deserializeRow(serialization, item)
          recoveryCallback(repr)
        })
        .flatMap(_ => highestSeqNr)
    }
  }
}

// Never invoked when the journal implements AsyncReplay: replayMessages handles
// replay instead, so reaching this indicates a bug in the caller.
override def asyncReplayMessages(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long)(
recoveryCallback: PersistentRepr => Unit): Future[Unit] = {
throw new IllegalStateException(
"asyncReplayMessages is not supposed to be called when implementing AsyncReplay. This is a bug, please report.")
}

/**
 * Never invoked when the journal implements AsyncReplay: replayMessages returns
 * the highest sequence number itself, so reaching this indicates a bug in the caller.
 */
override def asyncReadHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = {
  // Fix: the message previously named asyncReplayMessages (copy-paste error);
  // it must reference this method so bug reports point at the right call.
  throw new IllegalStateException(
    "asyncReadHighestSequenceNr is not supposed to be called when implementing AsyncReplay. This is a bug, please report.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import com.typesafe.config.Config
import org.slf4j.LoggerFactory

import akka.persistence.r2dbc.internal.R2dbcExecutorProvider
import akka.util.OptionVal

object R2dbcReadJournal {
val Identifier = "akka.persistence.r2dbc.query"
Expand Down Expand Up @@ -569,7 +570,9 @@ final class R2dbcReadJournal(system: ExtendedActorSystem, config: Config, cfgPat
@InternalApi private[r2dbc] def internalCurrentEventsByPersistenceId(
persistenceId: String,
fromSequenceNr: Long,
toSequenceNr: Long): Source[SerializedJournalRow, NotUsed] = {
toSequenceNr: Long,
readHighestSequenceNr: Boolean = true,
includeDeleted: Boolean = false): Source[SerializedJournalRow, NotUsed] = {

def updateState(state: ByPersistenceIdState, row: SerializedJournalRow): ByPersistenceIdState =
state.copy(rowCount = state.rowCount + 1, latestSeqNr = row.seqNr)
Expand All @@ -591,7 +594,7 @@ final class R2dbcReadJournal(system: ExtendedActorSystem, config: Config, cfgPat

newState -> Some(
queryDao
.eventsByPersistenceId(persistenceId, state.latestSeqNr + 1, highestSeqNr))
.eventsByPersistenceId(persistenceId, state.latestSeqNr + 1, highestSeqNr, includeDeleted))
} else {
log.debugN(
"currentEventsByPersistenceId query [{}] for persistenceId [{}] completed. Found [{}] rows in previous query.",
Expand All @@ -611,7 +614,8 @@ final class R2dbcReadJournal(system: ExtendedActorSystem, config: Config, cfgPat
toSequenceNr)

val highestSeqNrFut =
if (toSequenceNr == Long.MaxValue) journalDao.readHighestSequenceNr(persistenceId, fromSequenceNr)
if (readHighestSequenceNr && toSequenceNr == Long.MaxValue)
journalDao.readHighestSequenceNr(persistenceId, fromSequenceNr)
else Future.successful(toSequenceNr)

Source
Expand Down Expand Up @@ -707,7 +711,7 @@ final class R2dbcReadJournal(system: ExtendedActorSystem, config: Config, cfgPat
newState ->
Some(
queryDao
.eventsByPersistenceId(persistenceId, state.latestSeqNr + 1, toSequenceNr))
.eventsByPersistenceId(persistenceId, state.latestSeqNr + 1, toSequenceNr, includeDeleted = false))
}
}

Expand Down
2 changes: 1 addition & 1 deletion native-image-tests/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ scalaVersion := "2.13.14"

resolvers += "Akka library repository".at("https://repo.akka.io/maven")

// Default Akka version for the native-image tests; override with -Dakka.version=...
// (Stripped diff left both the old 2.9.3 default and the new one; the commit bumps to 2.9.4.)
lazy val akkaVersion = sys.props.getOrElse("akka.version", "2.9.4")
lazy val akkaR2dbcVersion = sys.props.getOrElse("akka.r2dbc.version", "1.2.3")

fork := true
Expand Down
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ object Dependencies {
val Scala3 = "3.3.3"
val Scala2Versions = Seq(Scala213)
val ScalaVersions = Dependencies.Scala2Versions :+ Dependencies.Scala3
// Akka version used by the build; override with -Doverride.akka.version=...
// (Stripped diff left both the old 2.9.3 default and the new one; the commit bumps to 2.9.4.)
val AkkaVersion = System.getProperty("override.akka.version", "2.9.4")
val AkkaVersionInDocs = VersionNumber(AkkaVersion).numbers match { case Seq(major, minor, _*) => s"$major.$minor" }
val AkkaPersistenceJdbcVersion = "5.4.0" // only in migration tool tests
val AkkaProjectionVersionInDocs = "current"
Expand Down

0 comments on commit cc4aa83

Please sign in to comment.