Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Cache sql construction #522

Merged
merged 4 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions core/src/main/scala/akka/persistence/r2dbc/internal/Sql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package akka.persistence.r2dbc.internal

import scala.annotation.varargs
import scala.collection.immutable.IntMap

import akka.annotation.InternalApi
import akka.annotation.InternalStableApi
Expand Down Expand Up @@ -81,4 +82,54 @@ object Sql {
}
}

object Cache {
def apply(dataPartitionsEnabled: Boolean): Cache =
if (dataPartitionsEnabled) new CacheBySlice
else new CacheIgnoringSlice
}

sealed trait Cache {
def get(slice: Int, key: Any)(orCreate: => String): String
}

private final class CacheIgnoringSlice extends Cache {
private var entries: Map[Any, String] = Map.empty

def get(slice: Int, key: Any)(orCreate: => String): String = {
entries.get(key) match {
case Some(value) => value
case None =>
// it's just a cache so no need for guarding concurrent updates
val entry = orCreate
entries = entries.updated(key, entry)
entry
}
}
}

private final class CacheBySlice extends Cache {
private var entriesPerSlice: IntMap[Map[Any, String]] = IntMap.empty

def get(slice: Int, key: Any)(orCreate: => String): String = {

def createEntry(entries: Map[Any, String]): String = {
// it's just a cache so no need for guarding concurrent updates
val entry = orCreate
val newEntries = entries.updated(key, entry)
entriesPerSlice = entriesPerSlice.updated(slice, newEntries)
entry
}

entriesPerSlice.get(slice) match {
case Some(entries) =>
entries.get(key) match {
case Some(value) => value
case None => createEntry(entries)
}
case None =>
createEntry(Map.empty)
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import akka.dispatch.ExecutionContexts
import akka.persistence.r2dbc.internal.JournalDao
import akka.persistence.r2dbc.internal.R2dbcExecutor
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider
import akka.persistence.r2dbc.internal.Sql
import akka.persistence.r2dbc.internal.Sql.InterpolationWithAdapter
import akka.persistence.r2dbc.internal.codec.PayloadCodec.RichStatement
import akka.persistence.r2dbc.internal.postgres.PostgresJournalDao
Expand All @@ -36,9 +37,14 @@ private[r2dbc] class H2JournalDao(executorProvider: R2dbcExecutorProvider)
require(settings.useAppTimestamp)
require(settings.dbTimestampMonotonicIncreasing)

private def insertSql(slice: Int) = sql"INSERT INTO ${journalTable(slice)} " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
private val sqlCache = Sql.Cache(settings.numberOfDataPartitions > 1)

private def insertSql(slice: Int) =
sqlCache.get(slice, "insertSql") {
sql"INSERT INTO ${journalTable(slice)} " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
}

/**
* All events must be for the same persistenceId.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ private[r2dbc] class H2QueryDao(executorProvider: R2dbcExecutorProvider) extends
backtracking: Boolean,
minSlice: Int,
maxSlice: Int): String = {
// not caching, too many combinations

def toDbTimestampParamCondition =
if (toDbTimestampParam) "AND db_timestamp <= ?" else ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory

import akka.annotation.InternalApi
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider
import akka.persistence.r2dbc.internal.Sql
import akka.persistence.r2dbc.internal.Sql.InterpolationWithAdapter
import akka.persistence.r2dbc.internal.postgres.PostgresSnapshotDao

Expand All @@ -22,21 +23,24 @@ private[r2dbc] final class H2SnapshotDao(executorProvider: R2dbcExecutorProvider

override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2SnapshotDao])

override protected def upsertSql(slice: Int): String = {
// db_timestamp and tags columns were added in 1.2.0
if (settings.querySettings.startFromSnapshotEnabled)
sql"""
MERGE INTO ${snapshotTable(slice)}
(slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest, db_timestamp, tags)
KEY (persistence_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
else
sql"""
MERGE INTO ${snapshotTable(slice)}
(slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest)
KEY (persistence_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
}
private val sqlCache = Sql.Cache(settings.numberOfDataPartitions > 1)

override protected def upsertSql(slice: Int): String =
sqlCache.get(slice, "upsertSql") {
// db_timestamp and tags columns were added in 1.2.0
if (settings.querySettings.startFromSnapshotEnabled)
sql"""
MERGE INTO ${snapshotTable(slice)}
(slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest, db_timestamp, tags)
KEY (persistence_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
else
sql"""
MERGE INTO ${snapshotTable(slice)}
(slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest)
KEY (persistence_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import akka.persistence.r2dbc.internal.codec.PayloadCodec.RichRow
import akka.persistence.r2dbc.internal.codec.PayloadCodec.RichStatement
import akka.persistence.r2dbc.internal.R2dbcExecutor
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider
import akka.persistence.r2dbc.internal.Sql
import akka.persistence.r2dbc.internal.Sql.InterpolationWithAdapter
import akka.persistence.r2dbc.internal.codec.TagsCodec.TagsCodecRichStatement
import akka.persistence.r2dbc.internal.codec.TimestampCodec.TimestampCodecRichRow
Expand Down Expand Up @@ -90,6 +91,8 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv

private val persistenceExt = Persistence(system)

private val sqlCache = Sql.Cache(settings.numberOfDataPartitions > 1)

// used for change events
private lazy val journalDao: JournalDao = dialect.createJournalDao(executorProvider)

Expand All @@ -107,24 +110,26 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv
}
}

protected def selectStateSql(slice: Int, entityType: String): String = {
val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
sql"""
SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp
FROM $stateTable WHERE persistence_id = ?"""
}
protected def selectStateSql(slice: Int, entityType: String): String =
sqlCache.get(slice, s"selectStateSql-$entityType") {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when reviewing this, the highest risk of mistake (by me) is using the wrong cache key, such as accidentally using the same in several places

val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
sql"""
SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp
FROM $stateTable WHERE persistence_id = ?"""
}

protected def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = {
val stateTable = settings.getDurableStateTableWithSchema(entityType, minSlice)
sql"""
SELECT extract(EPOCH from db_timestamp)::BIGINT / 10 AS bucket, count(*) AS count
FROM $stateTable
WHERE entity_type = ?
AND ${sliceCondition(minSlice, maxSlice)}
AND db_timestamp >= ? AND db_timestamp <= ?
GROUP BY bucket ORDER BY bucket LIMIT ?
"""
}
protected def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String =
sqlCache.get(minSlice, s"selectBucketsSql-$entityType-$minSlice-$maxSlice") {
val stateTable = settings.getDurableStateTableWithSchema(entityType, minSlice)
sql"""
SELECT extract(EPOCH from db_timestamp)::BIGINT / 10 AS bucket, count(*) AS count
FROM $stateTable
WHERE entity_type = ?
AND ${sliceCondition(minSlice, maxSlice)}
AND db_timestamp >= ? AND db_timestamp <= ?
GROUP BY bucket ORDER BY bucket LIMIT ?
"""
}

protected def sliceCondition(minSlice: Int, maxSlice: Int): String =
s"slice in (${(minSlice to maxSlice).mkString(",")})"
Expand All @@ -133,13 +138,20 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv
slice: Int,
entityType: String,
additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = {
val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
val additionalCols = additionalInsertColumns(additionalBindings)
val additionalParams = additionalInsertParameters(additionalBindings)
sql"""
INSERT INTO $stateTable
(slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?$additionalParams, CURRENT_TIMESTAMP)"""
def createSql = {
val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
val additionalCols = additionalInsertColumns(additionalBindings)
val additionalParams = additionalInsertParameters(additionalBindings)
sql"""
INSERT INTO $stateTable
(slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?$additionalParams, CURRENT_TIMESTAMP)"""
}

if (additionalBindings.isEmpty)
sqlCache.get(slice, s"insertStateSql-$entityType")(createSql)
else
createSql // no cache
}

protected def additionalInsertColumns(
Expand Down Expand Up @@ -179,28 +191,35 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv
updateTags: Boolean,
additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings],
currentTimestamp: String = "CURRENT_TIMESTAMP"): String = {
def createSql = {
val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)

val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
val timestamp =
if (settings.dbTimestampMonotonicIncreasing)
currentTimestamp
else
"GREATEST(CURRENT_TIMESTAMP, " +
s"(SELECT db_timestamp + '1 microsecond'::interval FROM $stateTable WHERE persistence_id = ? AND revision = ?))"

val timestamp =
if (settings.dbTimestampMonotonicIncreasing)
currentTimestamp
else
"GREATEST(CURRENT_TIMESTAMP, " +
s"(SELECT db_timestamp + '1 microsecond'::interval FROM $stateTable WHERE persistence_id = ? AND revision = ?))"

val revisionCondition =
if (settings.durableStateAssertSingleWriter) " AND revision = ?"
else ""
val revisionCondition =
if (settings.durableStateAssertSingleWriter) " AND revision = ?"
else ""

val tags = if (updateTags) ", tags = ?" else ""
val tags = if (updateTags) ", tags = ?" else ""

val additionalParams = additionalUpdateParameters(additionalBindings)
sql"""
val additionalParams = additionalUpdateParameters(additionalBindings)
sql"""
UPDATE $stateTable
SET revision = ?, state_ser_id = ?, state_ser_manifest = ?, state_payload = ?$tags$additionalParams, db_timestamp = $timestamp
WHERE persistence_id = ?
$revisionCondition"""
}

if (additionalBindings.isEmpty)
// timestamp param doesn't have to be part of cache key because it's just different for different dialects
sqlCache.get(slice, s"updateStateSql-$entityType-$updateTags")(createSql)
else
createSql // no cache
}

protected def additionalUpdateParameters(
Expand All @@ -220,21 +239,29 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv
}

protected def hardDeleteStateSql(entityType: String, slice: Int): String = {
val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
sql"DELETE from $stateTable WHERE persistence_id = ?"
sqlCache.get(slice, s"hardDeleteStateSql-$entityType") {
val stateTable = settings.getDurableStateTableWithSchema(entityType, slice)
sql"DELETE from $stateTable WHERE persistence_id = ?"
}
}

private val currentDbTimestampSql =
sql"SELECT CURRENT_TIMESTAMP AS db_timestamp"

protected def allPersistenceIdsSql(table: String): String =
protected def allPersistenceIdsSql(table: String): String = {
// not worth caching
sql"SELECT persistence_id from $table ORDER BY persistence_id LIMIT ?"
}

protected def persistenceIdsForEntityTypeSql(table: String): String =
protected def persistenceIdsForEntityTypeSql(table: String): String = {
// not worth caching
sql"SELECT persistence_id from $table WHERE persistence_id LIKE ? ORDER BY persistence_id LIMIT ?"
}

protected def allPersistenceIdsAfterSql(table: String): String =
protected def allPersistenceIdsAfterSql(table: String): String = {
// not worth caching
sql"SELECT persistence_id from $table WHERE persistence_id > ? ORDER BY persistence_id LIMIT ?"
}

protected def bindPersistenceIdsForEntityTypeAfterSql(
stmt: Statement,
Expand All @@ -248,8 +275,10 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv
.bind(2, limit)
}

protected def persistenceIdsForEntityTypeAfterSql(table: String): String =
protected def persistenceIdsForEntityTypeAfterSql(table: String): String = {
// not worth caching
sql"SELECT persistence_id from $table WHERE persistence_id LIKE ? AND persistence_id > ? ORDER BY persistence_id LIMIT ?"
}

protected def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String =
if (behindCurrentTime > Duration.Zero)
Expand All @@ -263,6 +292,7 @@ private[r2dbc] class PostgresDurableStateDao(executorProvider: R2dbcExecutorProv
backtracking: Boolean,
minSlice: Int,
maxSlice: Int): String = {
// not caching, too many combinations

val stateTable = settings.getDurableStateTableWithSchema(entityType, minSlice)

Expand Down
Loading
Loading