From c194e4f0bfae4bc30defe9ef6c4962fa582bd44b Mon Sep 17 00:00:00 2001 From: sebastian-alfers Date: Tue, 9 Jan 2024 11:49:42 +0100 Subject: [PATCH] Separate sql query definition+binding from the actual locic+error handling to reuse the latter one while adding new dialects --- .../persistence/r2dbc/R2dbcSettings.scala | 16 +- .../r2dbc/internal/TimestampCodec.scala | 66 ++ .../r2dbc/internal/h2/H2DurableStateDao.scala | 11 +- .../internal/h2/sql/H2DurableStateSql.scala | 32 + .../postgres/PostgresDurableStateDao.scala | 379 ++-------- .../postgres/PostgresJournalDao.scala | 214 ++---- .../postgres/sql/BaseDurableStateSql.scala | 103 +++ .../postgres/sql/BaseJournalSql.scala | 55 ++ .../sql/PostgresDurableStateSql.scala | 379 ++++++++++ .../postgres/sql/PostgresJournalSql.scala | 193 +++++ .../internal/sqlserver/SqlServerDialect.scala | 2 +- .../sqlserver/SqlServerDialectHelper.scala | 18 +- .../sqlserver/SqlServerDurableStateDao.scala | 711 +----------------- .../sqlserver/SqlServerJournalDao.scala | 275 +------ .../sqlserver/SqlServerQueryDao.scala | 46 +- .../sqlserver/SqlServerSnapshotDao.scala | 26 +- .../sql/SqlServerDurableStateSql.scala | 350 +++++++++ .../sqlserver/sql/SqlServerJournalSql.scala | 161 ++++ 18 files changed, 1571 insertions(+), 1466 deletions(-) create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/TimestampCodec.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/h2/sql/H2DurableStateSql.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseDurableStateSql.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseJournalSql.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresDurableStateSql.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresJournalSql.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerDurableStateSql.scala create mode 100644 core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerJournalSql.scala diff --git a/core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala b/core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala index 80d1185c..cd6e9a48 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala @@ -6,8 +6,7 @@ package akka.persistence.r2dbc import akka.annotation.InternalApi import akka.annotation.InternalStableApi -import akka.persistence.r2dbc.internal.ConnectionFactorySettings -import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.{ ConnectionFactorySettings, PayloadCodec, TimestampCodec } import akka.util.JavaDurationConverters._ import com.typesafe.config.Config @@ -83,6 +82,13 @@ object R2dbcSettings { val connectionFactorySettings = ConnectionFactorySettings(config.getConfig("connection-factory")) + val timestampCodec: TimestampCodec = { + connectionFactorySettings.dialect.name match { + case "sqlserver" => TimestampCodec.SqlServerCodec + case _ => TimestampCodec.PostgresTimestampCodec + } + } + val querySettings = new QuerySettings(config.getConfig("query")) val dbTimestampMonotonicIncreasing: Boolean = config.getBoolean("db-timestamp-monotonic-increasing") @@ -105,6 +111,7 @@ object R2dbcSettings { snapshotPayloadCodec, durableStateTable, durableStatePayloadCodec, + timestampCodec, durableStateAssertSingleWriter, logDbCallsExceeding, querySettings, @@ -139,6 +146,7 @@ final class R2dbcSettings private ( val snapshotPayloadCodec: PayloadCodec, val durableStateTable: String, val durableStatePayloadCodec: PayloadCodec, + val timestampCodec: TimestampCodec, val durableStateAssertSingleWriter: Boolean, val logDbCallsExceeding: FiniteDuration, val querySettings: QuerySettings, @@ -155,7 +163,7 @@ final class R2dbcSettings private ( val durableStateTableWithSchema: String = schema.map(_ + ".").getOrElse("") + durableStateTable /** - * One of the supported dialects 'postgres', 'yugabyte' or 'h2' + * One of the supported dialects 'postgres', 'yugabyte', 'sqlserver' or 'h2' */ def dialectName: String = _connectionFactorySettings.dialect.name @@ -217,6 +225,7 @@ final class R2dbcSettings private ( snapshotPayloadCodec: PayloadCodec = snapshotPayloadCodec, durableStateTable: String = durableStateTable, durableStatePayloadCodec: PayloadCodec = durableStatePayloadCodec, + timestampCodec: TimestampCodec = timestampCodec, durableStateAssertSingleWriter: Boolean = durableStateAssertSingleWriter, logDbCallsExceeding: FiniteDuration = logDbCallsExceeding, querySettings: QuerySettings = querySettings, @@ -237,6 +246,7 @@ final class R2dbcSettings private ( snapshotPayloadCodec, durableStateTable, durableStatePayloadCodec, + timestampCodec, durableStateAssertSingleWriter, logDbCallsExceeding, querySettings, diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/TimestampCodec.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/TimestampCodec.scala new file mode 100644 index 00000000..86105a8d --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/TimestampCodec.scala @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal + +import java.nio.charset.StandardCharsets.UTF_8 +import akka.annotation.InternalApi +import io.r2dbc.postgresql.codec.Json +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement + +import java.time.{ Instant, LocalDateTime } +import java.util.TimeZone + +/** + * INTERNAL API + */ +@InternalApi private[akka] sealed trait TimestampCodec { + + def encode(timestamp: Instant): Any + def decode(row: Row, name: String): Instant + + protected def instantNow() = InstantFactory.now() + + def now[T](): T +} + +/** + * INTERNAL API + */ +@InternalApi private[akka] object TimestampCodec { + case object PostgresTimestampCodec extends TimestampCodec { + override def decode(row: Row, name: String): Instant = row.get(name, classOf[Instant]) + + override def encode(timestamp: Instant): Any = timestamp + + override def now[T](): T = instantNow().asInstanceOf[T] + } + + case object SqlServerCodec extends TimestampCodec { + + // should this come from config? + private val zone = TimeZone.getTimeZone("UTC").toZoneId + + override def decode(row: Row, name: String): Instant = { + row + .get(name, classOf[LocalDateTime]) + .atZone(zone) + .toInstant + } + + override def encode(timestamp: Instant): Any = LocalDateTime.ofInstant(timestamp, zone) + + override def now[T](): T = LocalDateTime.ofInstant(instantNow(), zone).asInstanceOf[T] + } + + + implicit class RichStatement[T](val statement: Statement)(implicit codec: TimestampCodec) extends AnyRef { + def bindTimestamp(name: String, timestamp: Instant): Statement = statement.bind(name, codec.encode(timestamp)) + def bindTimestamp(index: Int, timestamp: Instant): Statement = statement.bind(index, codec.encode(timestamp)) + } + implicit class RichRow[T](val row: Row)(implicit codec: TimestampCodec) extends AnyRef { + def getTimestamp(rowName: String = "db_timestamp"): Instant = codec.decode(row, rowName) + } +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala index 27393ed0..3d6e312f 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2DurableStateDao.scala @@ -7,16 +7,16 @@ package akka.persistence.r2dbc.internal.h2 import scala.concurrent.ExecutionContext import scala.concurrent.duration.Duration import scala.concurrent.duration.FiniteDuration - import io.r2dbc.spi.ConnectionFactory import org.slf4j.Logger import org.slf4j.LoggerFactory - import akka.actor.typed.ActorSystem import akka.annotation.InternalApi import akka.persistence.r2dbc.R2dbcSettings import akka.persistence.r2dbc.internal.Dialect +import akka.persistence.r2dbc.internal.h2.sql.H2DurableStateSql import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao +import akka.persistence.r2dbc.internal.postgres.sql.BaseDurableStateSql /** * INTERNAL API @@ -28,11 +28,8 @@ private[r2dbc] final class H2DurableStateDao( dialect: Dialect)(implicit ec: ExecutionContext, system: ActorSystem[_]) extends PostgresDurableStateDao(settings, connectionFactory, dialect) { - override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao]) + override val durableStateSql: BaseDurableStateSql = new H2DurableStateSql(settings) - protected override def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = - if (behindCurrentTime > Duration.Zero) - s"AND db_timestamp < CURRENT_TIMESTAMP - interval '${behindCurrentTime.toMillis.toDouble / 1000}' second" - else "" + override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao]) } diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/sql/H2DurableStateSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/sql/H2DurableStateSql.scala new file mode 100644 index 00000000..5286bda0 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/sql/H2DurableStateSql.scala @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.h2.sql + +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichStatement => TimestampRichStatement } +import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao.EvaluatedAdditionalColumnBindings +import akka.persistence.r2dbc.internal.postgres.sql.PostgresDurableStateSql +import akka.persistence.r2dbc.internal.{ PayloadCodec, TimestampCodec } +import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn +import io.r2dbc.spi.Statement + +import java.lang +import java.time.Instant +import scala.collection.immutable +import scala.concurrent.duration.{ Duration, FiniteDuration } + +class H2DurableStateSql(settings: R2dbcSettings)(implicit + statePayloadCodec: PayloadCodec, + timestampCodec: TimestampCodec) + extends PostgresDurableStateSql(settings) { + protected override def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = + if (behindCurrentTime > Duration.Zero) + s"AND db_timestamp < CURRENT_TIMESTAMP - interval '${behindCurrentTime.toMillis.toDouble / 1000}' second" + else "" + +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala index 4037ae7f..6b5735a9 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDurableStateDao.scala @@ -7,14 +7,12 @@ package akka.persistence.r2dbc.internal.postgres import java.lang import java.time.Instant import java.util - import scala.collection.immutable import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.duration.Duration import scala.concurrent.duration.FiniteDuration import scala.util.control.NonFatal - import io.r2dbc.spi.Connection import io.r2dbc.spi.ConnectionFactory import io.r2dbc.spi.R2dbcDataIntegrityViolationException @@ -22,7 +20,6 @@ import io.r2dbc.spi.Row import io.r2dbc.spi.Statement import org.slf4j.Logger import org.slf4j.LoggerFactory - import akka.Done import akka.NotUsed import akka.actor.typed.ActorSystem @@ -35,20 +32,27 @@ import akka.persistence.query.DurableStateChange import akka.persistence.query.NoOffset import akka.persistence.query.UpdatedDurableState import akka.persistence.r2dbc.R2dbcSettings -import akka.persistence.r2dbc.internal.AdditionalColumnFactory +import akka.persistence.r2dbc.internal.{ + AdditionalColumnFactory, + ChangeHandlerFactory, + Dialect, + DurableStateDao, + InstantFactory, + JournalDao, + PayloadCodec, + R2dbcExecutor, + TimestampCodec +} import akka.persistence.r2dbc.internal.BySliceQuery.Buckets import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket -import akka.persistence.r2dbc.internal.ChangeHandlerFactory -import akka.persistence.r2dbc.internal.Dialect -import akka.persistence.r2dbc.internal.DurableStateDao -import akka.persistence.r2dbc.internal.InstantFactory -import akka.persistence.r2dbc.internal.JournalDao import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow -import akka.persistence.r2dbc.internal.PayloadCodec import akka.persistence.r2dbc.internal.PayloadCodec.RichRow -import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement -import akka.persistence.r2dbc.internal.R2dbcExecutor +import akka.persistence.r2dbc.internal.TimestampCodec.{ + RichRow => TimestampRichRow, + RichStatement => TimestampRichStatement +} import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.postgres.sql.{ BaseDurableStateSql, PostgresDurableStateSql } import akka.persistence.r2dbc.session.scaladsl.R2dbcSession import akka.persistence.r2dbc.state.ChangeHandlerException import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn @@ -64,7 +68,8 @@ private[r2dbc] object PostgresDurableStateDao { private val log: Logger = LoggerFactory.getLogger(classOf[PostgresDurableStateDao]) - private final case class EvaluatedAdditionalColumnBindings( + // move this to a dialect independent place? + final case class EvaluatedAdditionalColumnBindings( additionalColumn: AdditionalColumn[_, _], binding: AdditionalColumn.Binding[_]) @@ -92,7 +97,10 @@ private[r2dbc] class PostgresDurableStateDao( settings.logDbCallsExceeding, settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) - private implicit val statePayloadCodec: PayloadCodec = settings.durableStatePayloadCodec + protected implicit val statePayloadCodec: PayloadCodec = settings.durableStatePayloadCodec + protected implicit val timestampCodec: TimestampCodec = settings.timestampCodec + + val durableStateSql: BaseDurableStateSql = new PostgresDurableStateSql(settings) // used for change events private lazy val journalDao: JournalDao = dialect.createJournalDao(settings, connectionFactory) @@ -111,181 +119,25 @@ private[r2dbc] class PostgresDurableStateDao( } } - private def selectStateSql(entityType: String): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - sql""" - SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp - FROM $stateTable WHERE persistence_id = ?""" - } - - private def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - sql""" - SELECT extract(EPOCH from db_timestamp)::BIGINT / 10 AS bucket, count(*) AS count - FROM $stateTable - WHERE entity_type = ? - AND ${sliceCondition(minSlice, maxSlice)} - AND db_timestamp >= ? AND db_timestamp <= ? - GROUP BY bucket ORDER BY bucket LIMIT ? - """ - } - - protected def sliceCondition(minSlice: Int, maxSlice: Int): String = - s"slice in (${(minSlice to maxSlice).mkString(",")})" - - private def insertStateSql( - entityType: String, - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - val additionalCols = additionalInsertColumns(additionalBindings) - val additionalParams = additionalInsertParameters(additionalBindings) - sql""" - INSERT INTO $stateTable - (slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?$additionalParams, CURRENT_TIMESTAMP)""" - } - - private def additionalInsertColumns( - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - if (additionalBindings.isEmpty) "" - else { - val strB = new lang.StringBuilder() - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(c, _: AdditionalColumn.BindValue[_]) => - strB.append(", ").append(c.columnName) - case EvaluatedAdditionalColumnBindings(c, AdditionalColumn.BindNull) => - strB.append(", ").append(c.columnName) - case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => - } - strB.toString - } - } - - private def additionalInsertParameters( - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - if (additionalBindings.isEmpty) "" - else { - val strB = new lang.StringBuilder() - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(_, _: AdditionalColumn.BindValue[_]) | - EvaluatedAdditionalColumnBindings(_, AdditionalColumn.BindNull) => - strB.append(", ?") - case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => - } - strB.toString - } - } - - private def updateStateSql( - entityType: String, - updateTags: Boolean, - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - - val timestamp = - if (settings.dbTimestampMonotonicIncreasing) - "CURRENT_TIMESTAMP" - else - "GREATEST(CURRENT_TIMESTAMP, " + - s"(SELECT db_timestamp + '1 microsecond'::interval FROM $stateTable WHERE persistence_id = ? AND revision = ?))" - - val revisionCondition = - if (settings.durableStateAssertSingleWriter) " AND revision = ?" - else "" - - val tags = if (updateTags) ", tags = ?" else "" - - val additionalParams = additionalUpdateParameters(additionalBindings) - sql""" - UPDATE $stateTable - SET revision = ?, state_ser_id = ?, state_ser_manifest = ?, state_payload = ?$tags$additionalParams, db_timestamp = $timestamp - WHERE persistence_id = ? - $revisionCondition""" - } - - private def additionalUpdateParameters( - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - if (additionalBindings.isEmpty) "" - else { - val strB = new lang.StringBuilder() - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(col, _: AdditionalColumn.BindValue[_]) => - strB.append(", ").append(col.columnName).append(" = ?") - case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindNull) => - strB.append(", ").append(col.columnName).append(" = ?") - case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => - } - strB.toString - } - } - - private def hardDeleteStateSql(entityType: String): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - sql"DELETE from $stateTable WHERE persistence_id = ?" - } - private val currentDbTimestampSql = sql"SELECT CURRENT_TIMESTAMP AS db_timestamp" - private def allPersistenceIdsSql(table: String): String = - sql"SELECT persistence_id from $table ORDER BY persistence_id LIMIT ?" - - private def persistenceIdsForEntityTypeSql(table: String): String = - sql"SELECT persistence_id from $table WHERE persistence_id LIKE ? ORDER BY persistence_id LIMIT ?" - - private def allPersistenceIdsAfterSql(table: String): String = - sql"SELECT persistence_id from $table WHERE persistence_id > ? ORDER BY persistence_id LIMIT ?" - - private def persistenceIdsForEntityTypeAfterSql(table: String): String = - sql"SELECT persistence_id from $table WHERE persistence_id LIKE ? AND persistence_id > ? ORDER BY persistence_id LIMIT ?" - - protected def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = - if (behindCurrentTime > Duration.Zero) - s"AND db_timestamp < CURRENT_TIMESTAMP - interval '${behindCurrentTime.toMillis} milliseconds'" - else "" - - protected def stateBySlicesRangeSql( - entityType: String, - maxDbTimestampParam: Boolean, - behindCurrentTime: FiniteDuration, - backtracking: Boolean, - minSlice: Int, - maxSlice: Int): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - - def maxDbTimestampParamCondition = - if (maxDbTimestampParam) s"AND db_timestamp < ?" else "" - - val behindCurrentTimeIntervalCondition = behindCurrentTimeIntervalConditionFor(behindCurrentTime) - - val selectColumns = - if (backtracking) - "SELECT persistence_id, revision, db_timestamp, CURRENT_TIMESTAMP AS read_db_timestamp, state_ser_id " - else - "SELECT persistence_id, revision, db_timestamp, CURRENT_TIMESTAMP AS read_db_timestamp, state_ser_id, state_ser_manifest, state_payload " - - sql""" - $selectColumns - FROM $stateTable - WHERE entity_type = ? - AND ${sliceCondition(minSlice, maxSlice)} - AND db_timestamp >= ? $maxDbTimestampParamCondition $behindCurrentTimeIntervalCondition - ORDER BY db_timestamp, revision - LIMIT ?""" - } + protected def sliceCondition(minSlice: Int, maxSlice: Int): String = + s"slice in (${(minSlice to maxSlice).mkString(",")})" override def readState(persistenceId: String): Future[Option[SerializedStateRow]] = { val entityType = PersistenceId.extractEntityType(persistenceId) r2dbcExecutor.selectOne(s"select [$persistenceId]")( - connection => - connection - .createStatement(selectStateSql(entityType)) - .bind(0, persistenceId), - row => + { connection => + val stmt = connection.createStatement(durableStateSql.selectStateSql(entityType)) + durableStateSql.bindForSelectStateSql(stmt, persistenceId) + + }, + (row: Row) => SerializedStateRow( persistenceId = persistenceId, revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), - dbTimestamp = row.get("db_timestamp", classOf[Instant]), + dbTimestamp = row.getTimestamp(), readDbTimestamp = Instant.EPOCH, // not needed here payload = getPayload(row), serId = row.get[Integer]("state_ser_id", classOf[Integer]), @@ -371,16 +223,14 @@ private[r2dbc] class PostgresDurableStateDao( def insertStatement(connection: Connection): Statement = { val stmt = connection - .createStatement(insertStateSql(entityType, additionalBindings)) - .bind(getAndIncIndex(), slice) - .bind(getAndIncIndex(), entityType) - .bind(getAndIncIndex(), state.persistenceId) - .bind(getAndIncIndex(), state.revision) - .bind(getAndIncIndex(), state.serId) - .bind(getAndIncIndex(), state.serManifest) - .bindPayloadOption(getAndIncIndex(), state.payload) - bindTags(stmt, getAndIncIndex()) - bindAdditionalColumns(stmt, additionalBindings) + .createStatement(durableStateSql.insertStateSql(entityType, additionalBindings)) + durableStateSql.bindInsertStateForUpsertState( + stmt, + getAndIncIndex, + slice, + entityType, + state, + additionalBindings) } def recoverDataIntegrityViolation[A](f: Future[A]): Future[A] = @@ -411,33 +261,13 @@ private[r2dbc] class PostgresDurableStateDao( def updateStatement(connection: Connection): Statement = { val stmt = connection - .createStatement(updateStateSql(entityType, updateTags = true, additionalBindings)) - .bind(getAndIncIndex(), state.revision) - .bind(getAndIncIndex(), state.serId) - .bind(getAndIncIndex(), state.serManifest) - .bindPayloadOption(getAndIncIndex(), state.payload) - bindTags(stmt, getAndIncIndex()) - bindAdditionalColumns(stmt, additionalBindings) - - if (settings.dbTimestampMonotonicIncreasing) { - if (settings.durableStateAssertSingleWriter) - stmt - .bind(getAndIncIndex(), state.persistenceId) - .bind(getAndIncIndex(), previousRevision) - else - stmt - .bind(getAndIncIndex(), state.persistenceId) - } else { - stmt - .bind(getAndIncIndex(), state.persistenceId) - .bind(getAndIncIndex(), previousRevision) - .bind(getAndIncIndex(), state.persistenceId) - - if (settings.durableStateAssertSingleWriter) - stmt.bind(getAndIncIndex(), previousRevision) - else - stmt - } + .createStatement(durableStateSql.updateStateSql(entityType, updateTags = true, additionalBindings)) + durableStateSql.binUpdateStateSqlForUpsertState( + stmt, + getAndIncIndex, + state, + additionalBindings, + previousRevision) } if (!changeHandlers.contains(entityType) && changeEvent.isEmpty) { @@ -506,18 +336,13 @@ private[r2dbc] class PostgresDurableStateDao( val slice = persistenceExt.sliceForPersistenceId(persistenceId) def insertDeleteMarkerStatement(connection: Connection): Statement = { - connection + + val stmt = connection .createStatement( - insertStateSql(entityType, Vector.empty) + durableStateSql.insertStateSql(entityType, Vector.empty) ) // FIXME should the additional columns be cleared (null)? Then they must allow NULL - .bind(0, slice) - .bind(1, entityType) - .bind(2, persistenceId) - .bind(3, revision) - .bind(4, 0) - .bind(5, "") - .bindPayloadOption(6, None) - .bindNull(7, classOf[Array[String]]) + + durableStateSql.bindDeleteStateForInsertState(stmt, slice, entityType, persistenceId, revision) } def recoverDataIntegrityViolation[A](f: Future[A]): Future[A] = @@ -545,32 +370,9 @@ private[r2dbc] class PostgresDurableStateDao( def updateStatement(connection: Connection): Statement = { val stmt = connection .createStatement( - updateStateSql(entityType, updateTags = false, Vector.empty) + durableStateSql.updateStateSql(entityType, updateTags = false, Vector.empty) ) // FIXME should the additional columns be cleared (null)? Then they must allow NULL - .bind(0, revision) - .bind(1, 0) - .bind(2, "") - .bindPayloadOption(3, None) - - if (settings.dbTimestampMonotonicIncreasing) { - if (settings.durableStateAssertSingleWriter) - stmt - .bind(4, persistenceId) - .bind(5, previousRevision) - else - stmt - .bind(4, persistenceId) - } else { - stmt - .bind(4, persistenceId) - .bind(5, previousRevision) - .bind(6, persistenceId) - - if (settings.durableStateAssertSingleWriter) - stmt.bind(7, previousRevision) - else - stmt - } + durableStateSql.bindUpdateStateSqlForDeleteState(stmt, revision, persistenceId, previousRevision) } r2dbcExecutor.withConnection(s"delete [$persistenceId]") { connection => @@ -609,10 +411,10 @@ private[r2dbc] class PostgresDurableStateDao( val result = r2dbcExecutor.withConnection(s"hard delete [$persistenceId]$changeHandlerHint") { connection => for { - updatedRows <- R2dbcExecutor.updateOneInTx( - connection - .createStatement(hardDeleteStateSql(entityType)) - .bind(0, persistenceId)) + updatedRows <- R2dbcExecutor.updateOneInTx { + val stmt = connection.createStatement(durableStateSql.hardDeleteStateSql(entityType)) + durableStateSql.bindForHardDeleteState(stmt, persistenceId) + } _ <- { val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli) writeChangeEventAndCallChangeHander(connection, updatedRows, entityType, change, changeEvent = None) @@ -649,23 +451,14 @@ private[r2dbc] class PostgresDurableStateDao( connection => { val stmt = connection .createStatement( - stateBySlicesRangeSql( + durableStateSql.stateBySlicesRangeSql( entityType, maxDbTimestampParam = toTimestamp.isDefined, behindCurrentTime, backtracking, minSlice, maxSlice)) - .bind(0, entityType) - .bind(1, fromTimestamp) - toTimestamp match { - case Some(until) => - stmt.bind(2, until) - stmt.bind(3, settings.querySettings.bufferSize) - case None => - stmt.bind(2, settings.querySettings.bufferSize) - } - stmt + durableStateSql.bindForStateBySlicesRangeSql(stmt, entityType, fromTimestamp, toTimestamp, behindCurrentTime) }, row => if (backtracking) { @@ -677,8 +470,8 @@ private[r2dbc] class PostgresDurableStateDao( SerializedStateRow( persistenceId = row.get("persistence_id", classOf[String]), revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), - dbTimestamp = row.get("db_timestamp", classOf[Instant]), - readDbTimestamp = row.get("read_db_timestamp", classOf[Instant]), + dbTimestamp = row.getTimestamp("db_timestamp"), + readDbTimestamp = row.getTimestamp("read_db_timestamp"), // payload = null => lazy loaded for backtracking (ugly, but not worth changing UpdatedDurableState in Akka) // payload = None => DeletedDurableState (no lazy loading) payload = if (isDeleted) None else null, @@ -690,8 +483,8 @@ private[r2dbc] class PostgresDurableStateDao( SerializedStateRow( persistenceId = row.get("persistence_id", classOf[String]), revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), - dbTimestamp = row.get("db_timestamp", classOf[Instant]), - readDbTimestamp = row.get("read_db_timestamp", classOf[Instant]), + dbTimestamp = row.getTimestamp(), + readDbTimestamp = row.getTimestamp("read_db_timestamp"), payload = getPayload(row), serId = row.get[Integer]("state_ser_id", classOf[Integer]), serManifest = row.get("state_ser_manifest", classOf[String]), @@ -741,22 +534,18 @@ private[r2dbc] class PostgresDurableStateDao( Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) } - private def readPersistenceIds( - afterId: Option[String], - limit: Long, - table: String): Future[immutable.IndexedSeq[String]] = { + def readPersistenceIds(afterId: Option[String], limit: Long, table: String): Future[immutable.IndexedSeq[String]] = { val result = r2dbcExecutor.select(s"select persistenceIds")( connection => afterId match { case Some(after) => - connection - .createStatement(allPersistenceIdsAfterSql(table)) - .bind(0, after) - .bind(1, limit) + val stmt = connection.createStatement(durableStateSql.allPersistenceIdsAfterSql(table)) + durableStateSql.bindForAllPersistenceIdsAfter(stmt, after, limit) + case None => - connection - .createStatement(allPersistenceIdsSql(table)) - .bind(0, limit) + val stmt = connection.createStatement(durableStateSql.allPersistenceIdsSql(table)) + durableStateSql.bindForAllPersistenceIdsSql(stmt, limit) + }, row => row.get("persistence_id", classOf[String])) @@ -772,16 +561,13 @@ private[r2dbc] class PostgresDurableStateDao( connection => afterId match { case Some(after) => - connection - .createStatement(persistenceIdsForEntityTypeAfterSql(table)) - .bind(0, entityType + likeStmtPostfix) - .bind(1, after) - .bind(2, limit) + val stmt = connection.createStatement(durableStateSql.persistenceIdsForEntityTypeAfterSql(table)) + durableStateSql.bindPersistenceIdsForEntityTypeAfter(stmt, entityType + likeStmtPostfix, after, limit) + case None => - connection - .createStatement(persistenceIdsForEntityTypeSql(table)) - .bind(0, entityType + likeStmtPostfix) - .bind(1, limit) + val stmt = connection.createStatement(durableStateSql.persistenceIdsForEntityTypeSql(table)) + durableStateSql.bindPersistenceIdsForEntityType(stmt, entityType + likeStmtPostfix, limit) + }, row => row.get("persistence_id", classOf[String])) @@ -805,7 +591,7 @@ private[r2dbc] class PostgresDurableStateDao( limit: Int): Future[Seq[Bucket]] = { val toTimestamp = { - val now = InstantFactory.now() // not important to use database time + val now = timestampCodec.now() // not important to use database time if (fromTimestamp == Instant.EPOCH) now else { @@ -816,13 +602,10 @@ private[r2dbc] class PostgresDurableStateDao( } val result = r2dbcExecutor.select(s"select bucket counts [$minSlice - $maxSlice]")( - connection => - connection - .createStatement(selectBucketsSql(entityType, minSlice, maxSlice)) - .bind(0, entityType) - .bind(1, fromTimestamp) - .bind(2, toTimestamp) - .bind(3, limit), + connection => { + val stmt = connection.createStatement(durableStateSql.selectBucketsSql(entityType, minSlice, maxSlice)) + durableStateSql.bindSelectBucketsForCoundBuckets(stmt, entityType, fromTimestamp, toTimestamp, limit) + }, row => { val bucketStartEpochSeconds = row.get("bucket", classOf[java.lang.Long]).toLong * 10 val count = row.get[java.lang.Long]("count", classOf[java.lang.Long]).toLong diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala index a7a7c8e3..01e175f3 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresJournalDao.scala @@ -10,12 +10,15 @@ import akka.annotation.InternalApi import akka.dispatch.ExecutionContexts import akka.persistence.Persistence import akka.persistence.r2dbc.R2dbcSettings -import akka.persistence.r2dbc.internal.JournalDao -import akka.persistence.r2dbc.internal.PayloadCodec -import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement -import akka.persistence.r2dbc.internal.R2dbcExecutor -import akka.persistence.r2dbc.internal.SerializedEventMetadata -import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.{ + JournalDao, + PayloadCodec, + R2dbcExecutor, + SerializedEventMetadata, + TimestampCodec +} +import akka.persistence.r2dbc.internal.TimestampCodec.RichRow +import akka.persistence.r2dbc.internal.postgres.sql.{ BaseJournalSql, PostgresJournalSql } import akka.persistence.typed.PersistenceId import io.r2dbc.spi.Connection import io.r2dbc.spi.ConnectionFactory @@ -23,8 +26,8 @@ import io.r2dbc.spi.Row import io.r2dbc.spi.Statement import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.time.Instant +import java.time.Instant import scala.concurrent.ExecutionContext import scala.concurrent.Future @@ -61,6 +64,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti extends JournalDao { import JournalDao.SerializedJournalRow + protected def log: Logger = PostgresJournalDao.log private val persistenceExt = Persistence(system) @@ -74,61 +78,9 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti protected val journalTable = journalSettings.journalTableWithSchema protected implicit val journalPayloadCodec: PayloadCodec = journalSettings.journalPayloadCodec + protected implicit val timestampCodec: TimestampCodec = journalSettings.timestampCodec - private val (insertEventWithParameterTimestampSql, insertEventWithTransactionTimestampSql) = { - val baseSql = - s"INSERT INTO $journalTable " + - "(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " + - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " - - // The subselect of the db_timestamp of previous seqNr for same pid is to ensure that db_timestamp is - // always increasing for a pid (time not going backwards). - // TODO we could skip the subselect when inserting seqNr 1 as a possible optimization - def timestampSubSelect = - s"(SELECT db_timestamp + '1 microsecond'::interval FROM $journalTable " + - "WHERE persistence_id = ? AND seq_nr = ?)" - - val insertEventWithParameterTimestampSql = { - if (journalSettings.dbTimestampMonotonicIncreasing) - sql"$baseSql ?) RETURNING db_timestamp" - else - sql"$baseSql GREATEST(?, $timestampSubSelect)) RETURNING db_timestamp" - } - - val insertEventWithTransactionTimestampSql = { - if (journalSettings.dbTimestampMonotonicIncreasing) - sql"$baseSql CURRENT_TIMESTAMP) RETURNING db_timestamp" - else - sql"$baseSql GREATEST(CURRENT_TIMESTAMP, $timestampSubSelect)) RETURNING db_timestamp" - } - - (insertEventWithParameterTimestampSql, insertEventWithTransactionTimestampSql) - } - - private val selectHighestSequenceNrSql = sql""" - SELECT MAX(seq_nr) from $journalTable - WHERE persistence_id = ? AND seq_nr >= ?""" - - private val selectLowestSequenceNrSql = - sql""" - SELECT MIN(seq_nr) from $journalTable - WHERE persistence_id = ?""" - - private val deleteEventsSql = sql""" - DELETE FROM $journalTable - WHERE persistence_id = ? AND seq_nr >= ? AND seq_nr <= ?""" - private val insertDeleteMarkerSql = sql""" - INSERT INTO $journalTable - (slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, deleted) - VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, ?, ?, ?, ?, ?, ?)""" - - private val deleteEventsByPersistenceIdBeforeTimestampSql = sql""" - DELETE FROM $journalTable - WHERE persistence_id = ? AND db_timestamp < ?""" - - private val deleteEventsBySliceBeforeTimestampSql = sql""" - DELETE FROM $journalTable - WHERE slice = ? AND entity_type = ? AND db_timestamp < ?""" + protected val journalSql: BaseJournalSql = new PostgresJournalSql(journalSettings) /** * All events must be for the same persistenceId. @@ -150,16 +102,19 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti // The MigrationTool defines the dbTimestamp to preserve the original event timestamp val useTimestampFromDb = events.head.dbTimestamp == Instant.EPOCH - val insertSql = - if (useTimestampFromDb) insertEventWithTransactionTimestampSql - else insertEventWithParameterTimestampSql + val insertSql = journalSql.insertSql(useTimestampFromDb) val totalEvents = events.size if (totalEvents == 1) { val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")( - connection => - bindInsertStatement(connection.createStatement(insertSql), events.head, useTimestampFromDb, previousSeqNr), - row => row.get(0, classOf[Instant])) + connection => { + journalSql.bindInsertForWriteEvent( + connection.createStatement(insertSql), + events.head, + useTimestampFromDb, + previousSeqNr) + }, + journalSql.parseInsertForWriteEvent) if (log.isDebugEnabled()) result.foreach { _ => log.debug("Wrote [{}] events for persistenceId [{}]", 1, persistenceId) @@ -170,9 +125,9 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti connection => events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) => stmt.add() - bindInsertStatement(stmt, write, useTimestampFromDb, previousSeqNr) + journalSql.bindInsertForWriteEvent(stmt, write, useTimestampFromDb, previousSeqNr) }, - row => row.get(0, classOf[Instant])) + journalSql.parseInsertForWriteEvent) if (log.isDebugEnabled()) result.foreach { _ => log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, persistenceId) @@ -188,12 +143,14 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti // The MigrationTool defines the dbTimestamp to preserve the original event timestamp val useTimestampFromDb = event.dbTimestamp == Instant.EPOCH - val insertSql = - if (useTimestampFromDb) insertEventWithTransactionTimestampSql - else insertEventWithParameterTimestampSql + val insertSql = journalSql.insertSql(useTimestampFromDb) - val stmt = bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr) - val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.get(0, classOf[Instant])) + val stmt = journalSql.bindInsertForWriteEvent( + connection.createStatement(insertSql), + event, + useTimestampFromDb, + previousSeqNr) + val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.getTimestamp()) if (log.isDebugEnabled()) result.foreach { _ => log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId) @@ -201,68 +158,13 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti result } - private def bindInsertStatement( - stmt: Statement, - write: SerializedJournalRow, - useTimestampFromDb: Boolean, - previousSeqNr: Long): Statement = { - stmt - .bind(0, write.slice) - .bind(1, write.entityType) - .bind(2, write.persistenceId) - .bind(3, write.seqNr) - .bind(4, write.writerUuid) - .bind(5, "") // FIXME event adapter - .bind(6, write.serId) - .bind(7, write.serManifest) - .bindPayload(8, write.payload.get) - - if (write.tags.isEmpty) - stmt.bindNull(9, classOf[Array[String]]) - else - stmt.bind(9, write.tags.toArray) - - // optional metadata - write.metadata match { - case Some(m) => - stmt - .bind(10, m.serId) - .bind(11, m.serManifest) - .bind(12, m.payload) - case None => - stmt - .bindNull(10, classOf[Integer]) - .bindNull(11, classOf[String]) - .bindNull(12, classOf[Array[Byte]]) - } - - if (useTimestampFromDb) { - if (!journalSettings.dbTimestampMonotonicIncreasing) - stmt - .bind(13, write.persistenceId) - .bind(14, previousSeqNr) - } else { - if (journalSettings.dbTimestampMonotonicIncreasing) - stmt - .bind(13, write.dbTimestamp) - else - stmt - .bind(13, write.dbTimestamp) - .bind(14, write.persistenceId) - .bind(15, previousSeqNr) - } - - stmt - } - override def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = { val result = r2dbcExecutor .select(s"select highest seqNr [$persistenceId]")( - connection => - connection - .createStatement(selectHighestSequenceNrSql) - .bind(0, persistenceId) - .bind(1, fromSequenceNr), + { connection => + val stmt = connection.createStatement(journalSql.selectHighestSequenceNrSql) + journalSql.bindSelectHighestSequenceNrSql(stmt, persistenceId, fromSequenceNr) + }, row => { val seqNr = row.get(0, classOf[java.lang.Long]) if (seqNr eq null) 0L else seqNr.longValue @@ -278,10 +180,10 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti override def readLowestSequenceNr(persistenceId: String): Future[Long] = { val result = r2dbcExecutor .select(s"select lowest seqNr [$persistenceId]")( - connection => - connection - .createStatement(selectLowestSequenceNrSql) - .bind(0, persistenceId), + { connection => + val stmt = connection.createStatement(journalSql.selectLowestSequenceNrSql) + journalSql.bindSelectLowestSequenceNrSql(stmt, persistenceId) + }, row => { val seqNr = row.get(0, classOf[java.lang.Long]) if (seqNr eq null) 0L else seqNr.longValue @@ -294,14 +196,14 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti result } - protected def highestSeqNrForDelete(persistenceId: String, toSequenceNr: Long): Future[Long] = { + private def highestSeqNrForDelete(persistenceId: String, toSequenceNr: Long): Future[Long] = { if (toSequenceNr == Long.MaxValue) readHighestSequenceNr(persistenceId, 0L) else Future.successful(toSequenceNr) } - protected def lowestSequenceNrForDelete(persistenceId: String, toSeqNr: Long, batchSize: Int): Future[Long] = { + private def lowestSequenceNrForDelete(persistenceId: String, toSeqNr: Long, batchSize: Int): Future[Long] = { if (toSeqNr <= batchSize) { Future.successful(1L) } else { @@ -314,18 +216,9 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti def insertDeleteMarkerStmt(deleteMarkerSeqNr: Long, connection: Connection): Statement = { val entityType = PersistenceId.extractEntityType(persistenceId) val slice = persistenceExt.sliceForPersistenceId(persistenceId) - connection - .createStatement(insertDeleteMarkerSql) - .bind(0, slice) - .bind(1, entityType) - .bind(2, persistenceId) - .bind(3, deleteMarkerSeqNr) - .bind(4, "") - .bind(5, "") - .bind(6, 0) - .bind(7, "") - .bindPayloadOption(8, None) - .bind(9, true) + val stmt = connection + .createStatement(journalSql.insertDeleteMarkerSql) + journalSql.bindForInsertDeleteMarkerSql(stmt, slice, entityType, persistenceId, deleteMarkerSeqNr) } def deleteBatch(from: Long, to: Long, lastBatch: Boolean): Future[Unit] = { @@ -333,14 +226,18 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti r2dbcExecutor .update(s"delete [$persistenceId] and insert marker") { connection => Vector( - connection.createStatement(deleteEventsSql).bind(0, persistenceId).bind(1, from).bind(2, to), + { + val stmt = connection.createStatement(journalSql.deleteEventsSql) + journalSql.bindForDeleteEventsSql(stmt, persistenceId, from, to) + }, insertDeleteMarkerStmt(to, connection)) } .map(_.head) } else { r2dbcExecutor .updateOne(s"delete [$persistenceId]") { connection => - connection.createStatement(deleteEventsSql).bind(0, persistenceId).bind(1, from).bind(2, to) + val stmt = connection.createStatement(journalSql.deleteEventsSql) + journalSql.bindForDeleteEventsSql(stmt, persistenceId, from, to) } }).map(deletedRows => if (log.isDebugEnabled) { @@ -374,10 +271,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti override def deleteEventsBefore(persistenceId: String, timestamp: Instant): Future[Unit] = { r2dbcExecutor .updateOne(s"delete [$persistenceId]") { connection => - connection - .createStatement(deleteEventsByPersistenceIdBeforeTimestampSql) - .bind(0, persistenceId) - .bind(1, timestamp) + journalSql.deleteEventsByPersistenceIdBeforeTimestamp(connection.createStatement, persistenceId, timestamp) } .map(deletedRows => log.debugN("Deleted [{}] events for persistenceId [{}], before [{}]", deletedRows, persistenceId, timestamp))( @@ -387,11 +281,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti override def deleteEventsBefore(entityType: String, slice: Int, timestamp: Instant): Future[Unit] = { r2dbcExecutor .updateOne(s"delete [$entityType]") { connection => - connection - .createStatement(deleteEventsBySliceBeforeTimestampSql) - .bind(0, slice) - .bind(1, entityType) - .bind(2, timestamp) + journalSql.deleteEventsBySliceBeforeTimestamp(connection.createStatement, slice, entityType, timestamp) } .map(deletedRows => log.debugN( diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseDurableStateSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseDurableStateSql.scala new file mode 100644 index 00000000..9dc5fb76 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseDurableStateSql.scala @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.postgres.sql + +import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow +import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao.EvaluatedAdditionalColumnBindings +import io.r2dbc.spi.Statement + +import java.time.Instant +import scala.collection.immutable +import scala.concurrent.duration.FiniteDuration + +trait BaseDurableStateSql { + def bindForStateBySlicesRangeSql( + stmt: Statement, + entityType: String, + fromTimestamp: Instant, + toTimestamp: Option[Instant], + behindCurrentTime: FiniteDuration): Statement + + def stateBySlicesRangeSql( + entityType: String, + maxDbTimestampParam: Boolean, + behindCurrentTime: FiniteDuration, + backtracking: Boolean, + minSlice: Int, + maxSlice: Int): String + + def bindForAllPersistenceIdsSql(stmt: Statement, limit: Long): Statement + + def allPersistenceIdsSql(table: String): String + + def bindForAllPersistenceIdsAfter(stmt: Statement, after: String, limit: Long): Statement + + def allPersistenceIdsAfterSql(table: String): String + + def bindPersistenceIdsForEntityType(stmt: Statement, str: String, limit: Long): Statement + + def persistenceIdsForEntityTypeSql(table: String): String + + def bindPersistenceIdsForEntityTypeAfter(stmt: Statement, str: String, after: String, limit: Long): Statement + + def bindSelectBucketsForCoundBuckets( + stmt: Statement, + entityType: String, + fromTimestamp: Instant, + toTimestamp: Instant, + limit: Int): Statement + + def persistenceIdsForEntityTypeAfterSql(table: String): String + + def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String + + def bindForSelectStateSql(stmt: Statement, persistenceId: String): Statement + + def selectStateSql(entityType: String): String + + def bindUpdateStateSqlForDeleteState( + stmt: Statement, + revision: Long, + persistenceId: String, + previousRevision: Long): _root_.io.r2dbc.spi.Statement + + def bindDeleteStateForInsertState( + stmt: Statement, + slice: Int, + entityType: String, + persistenceId: String, + revision: Long): Statement + + def binUpdateStateSqlForUpsertState( + stmt: Statement, + getAndIncIndex: () => Int, + state: SerializedStateRow, + additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings], + previousRevision: Long): Statement + + def bindInsertStateForUpsertState( + stmt: Statement, + getAndIncIndex: () => Int, + slice: Int, + entityType: String, + state: SerializedStateRow, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): Statement + + def updateStateSql( + entityType: String, + updateTags: Boolean, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String + + def insertStateSql( + entityType: String, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String + + def bindForHardDeleteState(stmt: Statement, persistenceId: String): Statement + + def hardDeleteStateSql(entityType: String): String + + protected def sliceCondition(minSlice: Int, maxSlice: Int): String = + s"slice in (${(minSlice to maxSlice).mkString(",")})" +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseJournalSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseJournalSql.scala new file mode 100644 index 00000000..cc9b32e3 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/BaseJournalSql.scala @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.postgres.sql + +import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow +import akka.persistence.r2dbc.internal.JournalDao +import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow +import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao.EvaluatedAdditionalColumnBindings +import io.r2dbc.spi.{ Row, Statement } + +import java.time.Instant +import scala.collection.immutable + +trait BaseJournalSql { + def bindForDeleteEventsSql(stmt: Statement, persistenceId: String, from: Long, to: Long): Statement + + def bindForInsertDeleteMarkerSql( + stmt: Statement, + slice: Int, + entityType: String, + persistenceId: String, + deleteMarkerSeqNr: Long): Statement + + val selectHighestSequenceNrSql: String + val selectLowestSequenceNrSql: String + val insertDeleteMarkerSql: String + val deleteEventsSql: String + + def bindSelectHighestSequenceNrSql(stmt: Statement, persistenceId: String, fromSequenceNr: Long): Statement + + def bindSelectLowestSequenceNrSql(stmt: Statement, persistenceId: String): Statement + + def parseInsertForWriteEvent(row: Row): Instant + + def deleteEventsBySliceBeforeTimestamp( + createStatement: String => Statement, + slice: Int, + entityType: String, + timestamp: Instant): Statement + + def deleteEventsByPersistenceIdBeforeTimestamp( + createStatement: String => Statement, + persistenceId: String, + timestamp: Instant): Statement + + def bindInsertForWriteEvent( + stmt: Statement, + write: SerializedJournalRow, + useTimestampFromDb: Boolean, + previousSeqNr: Long): Statement + + def insertSql(useTimestampFromDb: Boolean): String +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresDurableStateSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresDurableStateSql.scala new file mode 100644 index 00000000..810fafde --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresDurableStateSql.scala @@ -0,0 +1,379 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.postgres.sql + +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow +import akka.persistence.r2dbc.internal.{ PayloadCodec, TimestampCodec } +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao.EvaluatedAdditionalColumnBindings +import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn +import io.r2dbc.spi.Statement +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichStatement => TimestampRichStatement } + +import java.lang +import java.time.Instant +import scala.collection.immutable +import scala.concurrent.duration.{ Duration, FiniteDuration } + +class PostgresDurableStateSql(settings: R2dbcSettings)(implicit + statePayloadCodec: PayloadCodec, + timestampCodec: TimestampCodec) + extends BaseDurableStateSql { + def bindUpdateStateSqlForDeleteState( + stmt: Statement, + revision: Long, + persistenceId: String, + previousRevision: Long): _root_.io.r2dbc.spi.Statement = { + + stmt + .bind(0, revision) + .bind(1, 0) + .bind(2, "") + .bindPayloadOption(3, None) + + if (settings.dbTimestampMonotonicIncreasing) { + if (settings.durableStateAssertSingleWriter) + stmt + .bind(4, persistenceId) + .bind(5, previousRevision) + else + stmt + .bind(4, persistenceId) + } else { + stmt + .bind(4, persistenceId) + .bind(5, previousRevision) + .bind(6, persistenceId) + + if (settings.durableStateAssertSingleWriter) + stmt.bind(7, previousRevision) + else + stmt + } + + } + + def bindDeleteStateForInsertState( + stmt: Statement, + slice: Int, + entityType: String, + persistenceId: String, + revision: Long): Statement = { + stmt + .bind(0, slice) + .bind(1, entityType) + .bind(2, persistenceId) + .bind(3, revision) + .bind(4, 0) + .bind(5, "") + .bindPayloadOption(6, None) + .bindNull(7, classOf[Array[String]]) + } + + def binUpdateStateSqlForUpsertState( + stmt: Statement, + getAndIncIndex: () => Int, + state: SerializedStateRow, + additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings], + previousRevision: Long): Statement = { + stmt + .bind(getAndIncIndex(), state.revision) + .bind(getAndIncIndex(), state.serId) + .bind(getAndIncIndex(), state.serManifest) + .bindPayloadOption(getAndIncIndex(), state.payload) + bindTags(stmt, getAndIncIndex(), state) + bindAdditionalColumns(stmt, additionalBindings, getAndIncIndex) + + if (settings.dbTimestampMonotonicIncreasing) { + if (settings.durableStateAssertSingleWriter) + stmt + .bind(getAndIncIndex(), state.persistenceId) + .bind(getAndIncIndex(), previousRevision) + else + stmt + .bind(getAndIncIndex(), state.persistenceId) + } else { + stmt + .bind(getAndIncIndex(), state.persistenceId) + .bind(getAndIncIndex(), previousRevision) + .bind(getAndIncIndex(), state.persistenceId) + + if (settings.durableStateAssertSingleWriter) + stmt.bind(getAndIncIndex(), previousRevision) + else + stmt + } + } + + // duplicated for now + private def bindTags(stmt: Statement, i: Int, state: SerializedStateRow): Statement = { + if (state.tags.isEmpty) + stmt.bindNull(i, classOf[Array[String]]) + else + stmt.bind(i, state.tags.toArray) + } + + // duplicated for now + def bindAdditionalColumns( + stmt: Statement, + additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings], + getAndIncIndex: () => Int): Statement = { + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.BindValue(v)) => + stmt.bind(getAndIncIndex(), v) + case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindNull) => + stmt.bindNull(getAndIncIndex(), col.fieldClass) + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => + } + stmt + } + + def bindInsertStateForUpsertState( + stmt: Statement, + getAndIncIndex: () => Int, + slice: Int, + entityType: String, + state: SerializedStateRow, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]) = { + val st = stmt + .bind(getAndIncIndex(), slice) + .bind(getAndIncIndex(), entityType) + .bind(getAndIncIndex(), state.persistenceId) + .bind(getAndIncIndex(), state.revision) + .bind(getAndIncIndex(), state.serId) + .bind(getAndIncIndex(), state.serManifest) + .bindPayloadOption(getAndIncIndex(), state.payload) + bindTags(st, getAndIncIndex(), state) + bindAdditionalColumns(st, additionalBindings, getAndIncIndex) + } + + private def additionalUpdateParameters( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, _: AdditionalColumn.BindValue[_]) => + strB.append(", ").append(col.columnName).append(" = ?") + case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindNull) => + strB.append(", ").append(col.columnName).append(" = ?") + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => + } + strB.toString + } + } + + def updateStateSql( + entityType: String, + updateTags: Boolean, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + val timestamp = + if (settings.dbTimestampMonotonicIncreasing) + "CURRENT_TIMESTAMP" + else + "GREATEST(CURRENT_TIMESTAMP, " + + s"(SELECT db_timestamp + '1 microsecond'::interval FROM $stateTable WHERE persistence_id = ? AND revision = ?))" + + val revisionCondition = + if (settings.durableStateAssertSingleWriter) " AND revision = ?" + else "" + + val tags = if (updateTags) ", tags = ?" else "" + + val additionalParams = additionalUpdateParameters(additionalBindings) + sql""" + UPDATE $stateTable + SET revision = ?, state_ser_id = ?, state_ser_manifest = ?, state_payload = ?$tags$additionalParams, db_timestamp = $timestamp + WHERE persistence_id = ? + $revisionCondition""" + } + + private def additionalInsertColumns( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(c, _: AdditionalColumn.BindValue[_]) => + strB.append(", ").append(c.columnName) + case EvaluatedAdditionalColumnBindings(c, AdditionalColumn.BindNull) => + strB.append(", ").append(c.columnName) + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => + } + strB.toString + } + } + + private def additionalInsertParameters( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(_, _: AdditionalColumn.BindValue[_]) | + EvaluatedAdditionalColumnBindings(_, AdditionalColumn.BindNull) => + strB.append(", ?") + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => + } + strB.toString + } + } + + def insertStateSql( + entityType: String, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + + val stateTable = settings.getDurableStateTableWithSchema(entityType) + val additionalCols = additionalInsertColumns(additionalBindings) + val additionalParams = additionalInsertParameters(additionalBindings) + sql""" + INSERT INTO $stateTable + (slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp) + VALUES (?, ?, ?, ?, ?, ?, ?, ?$additionalParams, CURRENT_TIMESTAMP)""" + + } + + def bindForHardDeleteState(stmt: Statement, persistenceId: String): Statement = { + stmt.bind(0, persistenceId) + } + + def hardDeleteStateSql(entityType: String): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql"DELETE from $stateTable WHERE persistence_id = ?" + } + + override def selectStateSql(entityType: String): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql""" + SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp + FROM $stateTable WHERE persistence_id = ?""" + } + + override def bindForSelectStateSql(stmt: Statement, persistenceId: String): Statement = stmt.bind(0, persistenceId) + + override def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql""" + SELECT extract(EPOCH from db_timestamp)::BIGINT / 10 AS bucket, count(*) AS count + FROM $stateTable + WHERE entity_type = ? + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= ? AND db_timestamp <= ? + GROUP BY bucket ORDER BY bucket LIMIT ? + """ + } + + override def bindSelectBucketsForCoundBuckets( + stmt: Statement, + entityType: String, + fromTimestamp: Instant, + toTimestamp: Instant, + limit: Int): Statement = { + stmt + .bind(0, entityType) + .bindTimestamp(1, fromTimestamp) + .bindTimestamp(2, toTimestamp) + .bind(3, limit) + } + + override def persistenceIdsForEntityTypeAfterSql(table: String): String = + sql"SELECT persistence_id from $table WHERE persistence_id LIKE ? AND persistence_id > ? ORDER BY persistence_id LIMIT ?" + + override def bindPersistenceIdsForEntityTypeAfter( + stmt: Statement, + entityTypePluslikeStmtPostfix: String, + after: String, + limit: Long): Statement = { + stmt + .bind(0, entityTypePluslikeStmtPostfix) + .bind(1, after) + .bind(2, limit) + } + + override def persistenceIdsForEntityTypeSql(table: String): String = + sql"SELECT persistence_id from $table WHERE persistence_id LIKE ? ORDER BY persistence_id LIMIT ?" + + override def bindPersistenceIdsForEntityType( + stmt: Statement, + entityTypePlusLikeStmtPostfix: String, + limit: Long): Statement = { + stmt + .bind(0, entityTypePlusLikeStmtPostfix) + .bind(1, limit) + } + + override def bindForAllPersistenceIdsAfter(stmt: Statement, after: String, limit: Long): Statement = + stmt + .bind(0, after) + .bind(1, limit) + + override def allPersistenceIdsAfterSql(table: String): String = + sql"SELECT persistence_id from $table WHERE persistence_id > ? ORDER BY persistence_id LIMIT ?" + + override def allPersistenceIdsSql(table: String): String = + sql"SELECT persistence_id from $table ORDER BY persistence_id LIMIT ?" + + override def bindForAllPersistenceIdsSql(stmt: Statement, limit: Long): Statement = + stmt.bind(0, limit) + + protected def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = + if (behindCurrentTime > Duration.Zero) + s"AND db_timestamp < CURRENT_TIMESTAMP - interval '${behindCurrentTime.toMillis} milliseconds'" + else "" + + override def stateBySlicesRangeSql( + entityType: String, + maxDbTimestampParam: Boolean, + behindCurrentTime: FiniteDuration, + backtracking: Boolean, + minSlice: Int, + maxSlice: Int): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + def maxDbTimestampParamCondition = + if (maxDbTimestampParam) s"AND db_timestamp < ?" else "" + + val behindCurrentTimeIntervalCondition = behindCurrentTimeIntervalConditionFor(behindCurrentTime) + + val selectColumns = + if (backtracking) + "SELECT persistence_id, revision, db_timestamp, CURRENT_TIMESTAMP AS read_db_timestamp, state_ser_id " + else + "SELECT persistence_id, revision, db_timestamp, CURRENT_TIMESTAMP AS read_db_timestamp, state_ser_id, state_ser_manifest, state_payload " + + sql""" + $selectColumns + FROM $stateTable + WHERE entity_type = ? + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= ? $maxDbTimestampParamCondition $behindCurrentTimeIntervalCondition + ORDER BY db_timestamp, revision + LIMIT ?""" + } + + override def bindForStateBySlicesRangeSql( + stmt: Statement, + entityType: String, + fromTimestamp: Instant, + toTimestamp: Option[Instant], + behindCurrentTime: FiniteDuration): Statement = { + stmt + .bind(0, entityType) + .bind(1, fromTimestamp) + toTimestamp match { + case Some(until) => + stmt.bind(2, until) + stmt.bind(3, settings.querySettings.bufferSize) + case None => + stmt.bind(2, settings.querySettings.bufferSize) + } + stmt + } + +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresJournalSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresJournalSql.scala new file mode 100644 index 00000000..b389a33f --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/sql/PostgresJournalSql.scala @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.postgres.sql + +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow +import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import io.r2dbc.spi.{ Row, Statement } + +import java.time.Instant + +class PostgresJournalSql(journalSettings: R2dbcSettings)(implicit statePayloadCodec: PayloadCodec) + extends BaseJournalSql { + + private val journalTable = journalSettings.journalTableWithSchema + + private val (insertEventWithParameterTimestampSql, insertEventWithTransactionTimestampSql) = { + val baseSql = + s"INSERT INTO $journalTable " + + "(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " + + // The subselect of the db_timestamp of previous seqNr for same pid is to ensure that db_timestamp is + // always increasing for a pid (time not going backwards). + // TODO we could skip the subselect when inserting seqNr 1 as a possible optimization + def timestampSubSelect = + s"(SELECT db_timestamp + '1 microsecond'::interval FROM $journalTable " + + "WHERE persistence_id = ? AND seq_nr = ?)" + + val insertEventWithParameterTimestampSql = { + if (journalSettings.dbTimestampMonotonicIncreasing) + sql"$baseSql ?) RETURNING db_timestamp" + else + sql"$baseSql GREATEST(?, $timestampSubSelect)) RETURNING db_timestamp" + } + + val insertEventWithTransactionTimestampSql = { + if (journalSettings.dbTimestampMonotonicIncreasing) + sql"$baseSql CURRENT_TIMESTAMP) RETURNING db_timestamp" + else + sql"$baseSql GREATEST(CURRENT_TIMESTAMP, $timestampSubSelect)) RETURNING db_timestamp" + } + + (insertEventWithParameterTimestampSql, insertEventWithTransactionTimestampSql) + } + + override def insertSql(useTimestampFromDb: Boolean): String = if (useTimestampFromDb) + insertEventWithTransactionTimestampSql + else insertEventWithParameterTimestampSql + + private val deleteEventsBySliceBeforeTimestampSql = + sql""" + DELETE FROM $journalTable + WHERE slice = ? AND entity_type = ? AND db_timestamp < ?""" + + def deleteEventsBySliceBeforeTimestamp( + createStatement: String => Statement, + slice: Int, + entityType: String, + timestamp: Instant): Statement = { + createStatement(deleteEventsBySliceBeforeTimestampSql) + .bind(0, slice) + .bind(1, entityType) + .bind(2, timestamp) + } + + def deleteEventsByPersistenceIdBeforeTimestamp( + createStatement: String => Statement, + persistenceId: String, + timestamp: Instant): Statement = { + createStatement(deleteEventsByPersistenceIdBeforeTimestampSql) + .bind(0, persistenceId) + .bind(1, timestamp) + } + + private val deleteEventsByPersistenceIdBeforeTimestampSql = + sql""" + DELETE FROM $journalTable + WHERE persistence_id = ? AND db_timestamp < ?""" + + def bindInsertForWriteEvent( + stmt: Statement, + write: SerializedJournalRow, + useTimestampFromDb: Boolean, + previousSeqNr: Long): Statement = { + stmt + .bind(0, write.slice) + .bind(1, write.entityType) + .bind(2, write.persistenceId) + .bind(3, write.seqNr) + .bind(4, write.writerUuid) + .bind(5, "") // FIXME event adapter + .bind(6, write.serId) + .bind(7, write.serManifest) + .bindPayload(8, write.payload.get) + + if (write.tags.isEmpty) + stmt.bindNull(9, classOf[Array[String]]) + else + stmt.bind(9, write.tags.toArray) + + // optional metadata + write.metadata match { + case Some(m) => + stmt + .bind(10, m.serId) + .bind(11, m.serManifest) + .bind(12, m.payload) + case None => + stmt + .bindNull(10, classOf[Integer]) + .bindNull(11, classOf[String]) + .bindNull(12, classOf[Array[Byte]]) + } + + if (useTimestampFromDb) { + if (!journalSettings.dbTimestampMonotonicIncreasing) + stmt + .bind(13, write.persistenceId) + .bind(14, previousSeqNr) + } else { + if (journalSettings.dbTimestampMonotonicIncreasing) + stmt + .bind(13, write.dbTimestamp) + else + stmt + .bind(13, write.dbTimestamp) + .bind(14, write.persistenceId) + .bind(15, previousSeqNr) + } + + stmt + + } + + override def parseInsertForWriteEvent(row: Row): Instant = row.get(0, classOf[Instant]) + + override val selectHighestSequenceNrSql = + sql""" + SELECT MAX(seq_nr) from $journalTable + WHERE persistence_id = ? AND seq_nr >= ?""" + + override def bindSelectHighestSequenceNrSql(stmt: Statement, persistenceId: String, fromSequenceNr: Long): Statement = + stmt + .bind(0, persistenceId) + .bind(1, fromSequenceNr) + + override val selectLowestSequenceNrSql = + sql""" + SELECT MIN(seq_nr) from $journalTable + WHERE persistence_id = ?""" + + override def bindSelectLowestSequenceNrSql(stmt: Statement, persistenceId: String): Statement = + stmt + .bind(0, persistenceId) + + override val insertDeleteMarkerSql = + sql""" + INSERT INTO $journalTable + (slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, deleted) + VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, ?, ?, ?, ?, ?, ?)""" + + override def bindForInsertDeleteMarkerSql( + stmt: Statement, + slice: Int, + entityType: String, + persistenceId: String, + deleteMarkerSeqNr: Long): Statement = { + stmt + .bind(0, slice) + .bind(1, entityType) + .bind(2, persistenceId) + .bind(3, deleteMarkerSeqNr) + .bind(4, "") + .bind(5, "") + .bind(6, 0) + .bind(7, "") + .bindPayloadOption(8, None) + .bind(9, true) + } + + override val deleteEventsSql = + sql""" + DELETE FROM $journalTable + WHERE persistence_id = ? AND seq_nr >= ? AND seq_nr <= ?""" + + override def bindForDeleteEventsSql(stmt: Statement, persistenceId: String, from: Long, to: Long): Statement = + stmt.bind(0, persistenceId).bind(1, from).bind(2, to) +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala index 3b6ab6c9..009f1d9f 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala @@ -86,5 +86,5 @@ private[r2dbc] object SqlServerDialect extends Dialect { override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit system: ActorSystem[_]): DurableStateDao = - new SqlServerDurableStateDao(settings, connectionFactory)(system.executionContext, system) + new SqlServerDurableStateDao(settings, connectionFactory, this)(system.executionContext, system) } diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala index 9e974bcd..10414d5f 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala @@ -27,6 +27,11 @@ private[r2dbc] object SqlServerDialectHelper { @InternalApi private[r2dbc] class SqlServerDialectHelper(config: Config) { + /** + * @todo + * This helper should be converted into a implicit codec + */ + private val tagSeparator = config.getString("tag-separator") require(tagSeparator.length == 1, s"Tag separator '$tagSeparator' must be a single character.") @@ -44,17 +49,4 @@ private[r2dbc] class SqlServerDialectHelper(config: Config) { case entries => entries.split(tagSeparator).toSet } - private val zone = TimeZone.getTimeZone("UTC").toZoneId - - def nowInstant(): Instant = InstantFactory.now() - - def nowLocalDateTime(): LocalDateTime = LocalDateTime.ofInstant(nowInstant(), zone) - - def toDbTimestamp(timestamp: Instant): LocalDateTime = - LocalDateTime.ofInstant(timestamp, zone) - - def fromDbTimestamp(time: LocalDateTime): Instant = time - .atZone(zone) - .toInstant - } diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala index ef1c91a4..caecb951 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala @@ -16,16 +16,26 @@ import akka.persistence.query.DurableStateChange import akka.persistence.query.NoOffset import akka.persistence.query.UpdatedDurableState import akka.persistence.r2dbc.R2dbcSettings -import akka.persistence.r2dbc.internal.AdditionalColumnFactory +import akka.persistence.r2dbc.internal.{ + AdditionalColumnFactory, + ChangeHandlerFactory, + Dialect, + DurableStateDao, + InstantFactory, + PayloadCodec, + R2dbcExecutor, + TimestampCodec +} import akka.persistence.r2dbc.internal.BySliceQuery.Buckets import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket -import akka.persistence.r2dbc.internal.ChangeHandlerFactory -import akka.persistence.r2dbc.internal.DurableStateDao -import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow import akka.persistence.r2dbc.internal.PayloadCodec.RichRow -import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement -import akka.persistence.r2dbc.internal.R2dbcExecutor +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichRow => TimestampRichRow } +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichStatement => TimestampRichStatement } import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao +import akka.persistence.r2dbc.internal.postgres.sql.{ BaseDurableStateSql, PostgresDurableStateSql } +import akka.persistence.r2dbc.internal.sqlserver.sql.SqlServerDurableStateSql import akka.persistence.r2dbc.session.scaladsl.R2dbcSession import akka.persistence.r2dbc.state.ChangeHandlerException import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn @@ -71,688 +81,19 @@ private[r2dbc] object SqlServerDurableStateDao { * INTERNAL API */ @InternalApi -private[r2dbc] class SqlServerDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit - ec: ExecutionContext, - system: ActorSystem[_]) - extends DurableStateDao { - import DurableStateDao._ - import SqlServerDurableStateDao._ - protected def log: Logger = SqlServerDurableStateDao.log - - private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) - import helper._ - - private val persistenceExt = Persistence(system) - protected val r2dbcExecutor = new R2dbcExecutor( - connectionFactory, - log, - settings.logDbCallsExceeding, - settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) - - private implicit val statePayloadCodec: PayloadCodec = settings.durableStatePayloadCodec - - private lazy val additionalColumns: Map[String, immutable.IndexedSeq[AdditionalColumn[Any, Any]]] = { - settings.durableStateAdditionalColumnClasses.map { case (entityType, columnClasses) => - val instances = columnClasses.map(fqcn => AdditionalColumnFactory.create(system, fqcn)) - entityType -> instances - } - } - - private lazy val changeHandlers: Map[String, ChangeHandler[Any]] = { - settings.durableStateChangeHandlerClasses.map { case (entityType, fqcn) => - val handler = ChangeHandlerFactory.create(system, fqcn) - entityType -> handler - } - } - - private def selectStateSql(entityType: String): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - sql""" - SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp - FROM $stateTable WHERE persistence_id = @persistenceId""" - } - - private def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - - // group by column alias (bucket) needs a sub query - val subQuery = - s""" - | select TOP(@limit) CAST(DATEDIFF(s,'1970-01-01 00:00:00',db_timestamp) AS BIGINT) / 10 AS bucket - | FROM $stateTable - | WHERE entity_type = @entityType - | AND ${sliceCondition(minSlice, maxSlice)} - | AND db_timestamp >= @fromTimestamp AND db_timestamp <= @toTimestamp - |""".stripMargin - sql""" - SELECT bucket, count(*) as count from ($subQuery) as sub - GROUP BY bucket ORDER BY bucket - """ - } - - protected def sliceCondition(minSlice: Int, maxSlice: Int): String = - s"slice in (${(minSlice to maxSlice).mkString(",")})" - - private def insertStateSql( - entityType: String, - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - val additionalCols = additionalInsertColumns(additionalBindings) - val additionalParams = additionalInsertParameters(additionalBindings) - sql""" - INSERT INTO $stateTable - (slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp) - VALUES (@slice, @entityType, @persistenceId, @revision, @stateSerId, @stateSerManifest, @statePayload, @tags$additionalParams, @now)""" - } - - private def additionalInsertColumns( - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - if (additionalBindings.isEmpty) "" - else { - val strB = new lang.StringBuilder() - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(c, bindBalue) if bindBalue != AdditionalColumn.Skip => - strB.append(", ").append(c.columnName) - case EvaluatedAdditionalColumnBindings(_, _) => - } - strB.toString - } - } - - private def additionalInsertParameters( - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - if (additionalBindings.isEmpty) "" - else { - val strB = new lang.StringBuilder() - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(col, bindValue) if bindValue != AdditionalColumn.Skip => - strB.append(s", @${col.columnName}") - case EvaluatedAdditionalColumnBindings(_, _) => - } - strB.toString - } - } - - private def updateStateSql( - entityType: String, - updateTags: Boolean, - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - - val revisionCondition = - if (settings.durableStateAssertSingleWriter) " AND revision = @previousRevision" - else "" - - val tags = if (updateTags) ", tags = @tags" else "" - - val additionalParams = additionalUpdateParameters(additionalBindings) - sql""" - UPDATE $stateTable - SET revision = @revision, state_ser_id = @stateSerId, state_ser_manifest = @stateSerManifest, state_payload = @statePayload $tags $additionalParams, db_timestamp = @now - WHERE persistence_id = @persistenceId - $revisionCondition""" - } - - private def additionalUpdateParameters( - additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { - if (additionalBindings.isEmpty) "" - else { - val strB = new lang.StringBuilder() - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(col, binValue) if binValue != AdditionalColumn.Skip => - strB.append(", ").append(col.columnName).append(s" = @${col.columnName}") - case EvaluatedAdditionalColumnBindings(_, _) => - } - strB.toString - } - } - - private def hardDeleteStateSql(entityType: String): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - sql"DELETE from $stateTable WHERE persistence_id = @persistenceId" - } - - private def allPersistenceIdsSql(table: String): String = - sql"SELECT TOP(@limit) persistence_id from $table ORDER BY persistence_id" - - private def persistenceIdsForEntityTypeSql(table: String): String = - sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike ORDER BY persistence_id" - - private def allPersistenceIdsAfterSql(table: String): String = - sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id > @persistenceId ORDER BY persistence_id" - - private def persistenceIdsForEntityTypeAfterSql(table: String): String = - sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike AND persistence_id > @persistenceId ORDER BY persistence_id" - - protected def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = - if (behindCurrentTime > Duration.Zero) - s"AND db_timestamp < DATEADD(ms, -${behindCurrentTime.toMillis}, @now)" - else "" - - protected def stateBySlicesRangeSql( - entityType: String, - maxDbTimestampParam: Boolean, - behindCurrentTime: FiniteDuration, - backtracking: Boolean, - minSlice: Int, - maxSlice: Int): String = { - val stateTable = settings.getDurableStateTableWithSchema(entityType) - - def maxDbTimestampParamCondition = - if (maxDbTimestampParam) s"AND db_timestamp < @until" else "" - - val behindCurrentTimeIntervalCondition = behindCurrentTimeIntervalConditionFor(behindCurrentTime) - - val selectColumns = - if (backtracking) - "SELECT TOP(@limit) persistence_id, revision, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, state_ser_id " - else - "SELECT TOP(@limit) persistence_id, revision, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, state_ser_id, state_ser_manifest, state_payload " - - sql""" - $selectColumns - FROM $stateTable - WHERE entity_type = @entityType - AND ${sliceCondition(minSlice, maxSlice)} - AND db_timestamp >= @fromTimestamp $maxDbTimestampParamCondition $behindCurrentTimeIntervalCondition - ORDER BY db_timestamp, revision""" - } - - def readState(persistenceId: String): Future[Option[SerializedStateRow]] = { - val entityType = PersistenceId.extractEntityType(persistenceId) - r2dbcExecutor.selectOne(s"select [$persistenceId]")( - connection => - connection - .createStatement(selectStateSql(entityType)) - .bind("@persistenceId", persistenceId), - row => - SerializedStateRow( - persistenceId = persistenceId, - revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = Instant.EPOCH, // not needed here - payload = getPayload(row), - serId = row.get[Integer]("state_ser_id", classOf[Integer]), - serManifest = row.get("state_ser_manifest", classOf[String]), - tags = Set.empty // tags not fetched in queries (yet) - )) - } - - private def getPayload(row: Row): Option[Array[Byte]] = { - val serId = row.get("state_ser_id", classOf[Integer]) - val rowPayload = row.getPayload("state_payload") - if (serId == 0 && (rowPayload == null || util.Arrays.equals(statePayloadCodec.nonePayload, rowPayload))) - None // delete marker - else - Option(rowPayload) - } - - def upsertState(state: SerializedStateRow, value: Any): Future[Done] = { - require(state.revision > 0) - - def bindTags(stmt: Statement, name: String): Statement = { - if (state.tags.isEmpty) - stmt.bindNull(name, classOf[String]) - else - stmt.bind(name, tagsToDb(state.tags)) - } - - def bindAdditionalColumns( - stmt: Statement, - additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings]): Statement = { - additionalBindings.foreach { - case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindValue(v)) => - stmt.bind(s"@${col.columnName}", v) - case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindNull) => - stmt.bindNull(s"@${col.columnName}", col.fieldClass) - case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => - } - stmt - } - - def change = - new UpdatedDurableState[Any](state.persistenceId, state.revision, value, NoOffset, EmptyDbTimestamp.toEpochMilli) - - val entityType = PersistenceId.extractEntityType(state.persistenceId) - - val result = { - val additionalBindings = additionalColumns.get(entityType) match { - case None => Vector.empty[EvaluatedAdditionalColumnBindings] - case Some(columns) => - val slice = persistenceExt.sliceForPersistenceId(state.persistenceId) - val upsert = AdditionalColumn.Upsert(state.persistenceId, entityType, slice, state.revision, value) - columns.map(c => EvaluatedAdditionalColumnBindings(c, c.bind(upsert))) - } - - if (state.revision == 1) { - val slice = persistenceExt.sliceForPersistenceId(state.persistenceId) - - def insertStatement(connection: Connection): Statement = { - val stmt = connection - .createStatement(insertStateSql(entityType, additionalBindings)) - .bind("@slice", slice) - .bind("@entityType", entityType) - .bind("@persistenceId", state.persistenceId) - .bind("@revision", state.revision) - .bind("@stateSerId", state.serId) - .bind("@stateSerManifest", state.serManifest) - .bindPayloadOption("@statePayload", state.payload) - .bind("@now", nowLocalDateTime()) - bindTags(stmt, "@tags") - bindAdditionalColumns(stmt, additionalBindings) - } - - def recoverDataIntegrityViolation[A](f: Future[A]): Future[A] = - f.recoverWith { case _: R2dbcDataIntegrityViolationException => - Future.failed( - new IllegalStateException( - s"Insert failed: durable state for persistence id [${state.persistenceId}] already exists")) - } - - changeHandlers.get(entityType) match { - case None => - recoverDataIntegrityViolation(r2dbcExecutor.updateOne(s"insert [${state.persistenceId}]")(insertStatement)) - case Some(handler) => - r2dbcExecutor.withConnection(s"insert [${state.persistenceId}] with change handler") { connection => - for { - updatedRows <- recoverDataIntegrityViolation(R2dbcExecutor.updateOneInTx(insertStatement(connection))) - _ <- processChange(handler, connection, change) - } yield updatedRows - } - } - } else { - val previousRevision = state.revision - 1 - - def updateStatement(connection: Connection): Statement = { - - val query = updateStateSql(entityType, updateTags = true, additionalBindings) - val stmt = connection - .createStatement(query) - .bind("@revision", state.revision) - .bind("@stateSerId", state.serId) - .bind("@stateSerManifest", state.serManifest) - .bindPayloadOption("@statePayload", state.payload) - .bind("@now", nowLocalDateTime()) - .bind("@persistenceId", state.persistenceId) - bindTags(stmt, "@tags") - bindAdditionalColumns(stmt, additionalBindings) - - if (settings.durableStateAssertSingleWriter) { - stmt.bind("@previousRevision", previousRevision) - } - - stmt - } - - changeHandlers.get(entityType) match { - case None => - r2dbcExecutor.updateOne(s"update [${state.persistenceId}]")(updateStatement) - case Some(handler) => - r2dbcExecutor.withConnection(s"update [${state.persistenceId}] with change handler") { connection => - for { - updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection)) - _ <- if (updatedRows == 1) processChange(handler, connection, change) else FutureDone - } yield updatedRows - } - } - } - } - - result - .map { updatedRows => - if (updatedRows != 1) - throw new IllegalStateException( - s"Update failed: durable state for persistence id [${state.persistenceId}] could not be updated to revision [${state.revision}]") - else { - log.debug( - "Updated durable state for persistenceId [{}] to revision [{}]", - state.persistenceId, - state.revision) - Done - } - } - - } - - private def processChange( - handler: ChangeHandler[Any], - connection: Connection, - change: DurableStateChange[Any]): Future[Done] = { - val session = new R2dbcSession(connection) - - def excMessage(cause: Throwable): String = { - val (changeType, revision) = change match { - case upd: UpdatedDurableState[_] => "update" -> upd.revision - case del: DeletedDurableState[_] => "delete" -> del.revision - } - s"Change handler $changeType failed for [${change.persistenceId}] revision [$revision], due to ${cause.getMessage}" - } - - try handler.process(session, change).recoverWith { case cause => - Future.failed[Done](new ChangeHandlerException(excMessage(cause), cause)) - } catch { - case NonFatal(cause) => throw new ChangeHandlerException(excMessage(cause), cause) - } - } - - def deleteState(persistenceId: String, revision: Long): Future[Done] = { - if (revision == 0) { - hardDeleteState(persistenceId) - } else { - val result = { - val entityType = PersistenceId.extractEntityType(persistenceId) - def change = - new DeletedDurableState[Any](persistenceId, revision, NoOffset, EmptyDbTimestamp.toEpochMilli) - if (revision == 1) { - val slice = persistenceExt.sliceForPersistenceId(persistenceId) - - def insertDeleteMarkerStatement(connection: Connection): Statement = { - connection - .createStatement(insertStateSql(entityType, Vector.empty)) - .bind("@slice", slice) - .bind("@entityType", entityType) - .bind("@persistenceId", persistenceId) - .bind("@revision", revision) - .bind("@stateSerId", 0) - .bind("@stateSerManifest", "") - .bindPayloadOption("@statePayload", None) - .bindNull("@tags", classOf[String]) - .bind("@now", nowLocalDateTime()) - } - - def recoverDataIntegrityViolation[A](f: Future[A]): Future[A] = - f.recoverWith { case _: R2dbcDataIntegrityViolationException => - Future.failed(new IllegalStateException( - s"Insert delete marker with revision 1 failed: durable state for persistence id [$persistenceId] already exists")) - } - - val changeHandler = changeHandlers.get(entityType) - val changeHandlerHint = changeHandler.map(_ => " with change handler").getOrElse("") - - r2dbcExecutor.withConnection(s"insert delete marker [$persistenceId]$changeHandlerHint") { connection => - for { - updatedRows <- recoverDataIntegrityViolation( - R2dbcExecutor.updateOneInTx(insertDeleteMarkerStatement(connection))) - _ <- changeHandler match { - case None => FutureDone - case Some(handler) => processChange(handler, connection, change) - } - } yield updatedRows - } - - } else { - val previousRevision = revision - 1 - - def updateStatement(connection: Connection): Statement = { - connection - .createStatement( - updateStateSql(entityType, updateTags = false, Vector.empty) - ) // FIXME should the additional columns be cleared (null)? Then they must allow NULL - .bind("@revision", revision) - .bind("@stateSerId", 0) - .bind("@stateSerManifest", "") - .bindPayloadOption("@statePayload", None) - .bind("@now", nowLocalDateTime()) - .bind("@persistenceId", persistenceId) - .bind("@previousRevision", previousRevision) - } - - val changeHandler = changeHandlers.get(entityType) - val changeHandlerHint = changeHandler.map(_ => " with change handler").getOrElse("") - - r2dbcExecutor.withConnection(s"delete [$persistenceId]$changeHandlerHint") { connection => - for { - updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection)) - _ <- changeHandler match { - case None => FutureDone - case Some(handler) => processChange(handler, connection, change) - } - } yield updatedRows - } - } - } - - result.map { updatedRows => - if (updatedRows != 1) - throw new IllegalStateException( - s"Delete failed: durable state for persistence id [$persistenceId] could not be updated to revision [$revision]") - else { - log.debug("Deleted durable state for persistenceId [{}] to revision [{}]", persistenceId, revision) - Done - } - } - - } - } - - private def hardDeleteState(persistenceId: String): Future[Done] = { - val entityType = PersistenceId.extractEntityType(persistenceId) - - val changeHandler = changeHandlers.get(entityType) - val changeHandlerHint = changeHandler.map(_ => " with change handler").getOrElse("") - - val result = - r2dbcExecutor.withConnection(s"hard delete [$persistenceId]$changeHandlerHint") { connection => - for { - updatedRows <- R2dbcExecutor.updateOneInTx( - connection - .createStatement(hardDeleteStateSql(entityType)) - .bind("@persistenceId", persistenceId)) - _ <- changeHandler match { - case None => FutureDone - case Some(handler) => - val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli) - processChange(handler, connection, change) - } - } yield updatedRows - } - - if (log.isDebugEnabled()) - result.foreach(_ => log.debug("Hard deleted durable state for persistenceId [{}]", persistenceId)) - - result.map(_ => Done)(ExecutionContexts.parasitic) - } - - override def currentDbTimestamp(): Future[Instant] = Future.successful(nowInstant()) - - override def rowsBySlices( - entityType: String, - minSlice: Int, - maxSlice: Int, - fromTimestamp: Instant, - toTimestamp: Option[Instant], - behindCurrentTime: FiniteDuration, - backtracking: Boolean): Source[SerializedStateRow, NotUsed] = { - val result = r2dbcExecutor.select( - s"select stateBySlices [${settings.querySettings.bufferSize}, $entityType, $fromTimestamp, $toTimestamp, $minSlice - $maxSlice]")( - connection => { - val query = stateBySlicesRangeSql( - entityType, - maxDbTimestampParam = toTimestamp.isDefined, - behindCurrentTime, - backtracking, - minSlice, - maxSlice) - val stmt = connection - .createStatement(query) - .bind("@entityType", entityType) - .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) - - stmt.bind("@limit", settings.querySettings.bufferSize) - - if (behindCurrentTime > Duration.Zero) { - stmt.bind("@now", nowLocalDateTime()) - } - - toTimestamp.foreach(until => stmt.bind("@until", toDbTimestamp(until))) - - stmt - }, - row => - if (backtracking) { - val serId = row.get[Integer]("state_ser_id", classOf[Integer]) - // would have been better with an explicit deleted column as in the journal table, - // but not worth the schema change - val isDeleted = serId == 0 - - SerializedStateRow( - persistenceId = row.get("persistence_id", classOf[String]), - revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), - // payload = null => lazy loaded for backtracking (ugly, but not worth changing UpdatedDurableState in Akka) - // payload = None => DeletedDurableState (no lazy loading) - payload = if (isDeleted) None else null, - serId = 0, - serManifest = "", - tags = Set.empty // tags not fetched in queries (yet) - ) - } else - SerializedStateRow( - persistenceId = row.get("persistence_id", classOf[String]), - revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), - payload = getPayload(row), - serId = row.get[Integer]("state_ser_id", classOf[Integer]), - serManifest = row.get("state_ser_manifest", classOf[String]), - tags = Set.empty // tags not fetched in queries (yet) - )) - - if (log.isDebugEnabled) - result.foreach(rows => - log.debugN("Read [{}] durable states from slices [{} - {}]", rows.size, minSlice, maxSlice)) - - Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) - } - - def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = { - if (settings.durableStateTableByEntityTypeWithSchema.isEmpty) - persistenceIds(afterId, limit, settings.durableStateTableWithSchema) - else { - def readFromCustomTables( - acc: immutable.IndexedSeq[String], - remainingTables: Vector[String]): Future[immutable.IndexedSeq[String]] = { - if (acc.size >= limit) { - Future.successful(acc) - } else if (remainingTables.isEmpty) { - Future.successful(acc) - } else { - readPersistenceIds(afterId, limit, remainingTables.head).flatMap { ids => - readFromCustomTables(acc ++ ids, remainingTables.tail) - } - } - } - - val customTables = settings.durableStateTableByEntityTypeWithSchema.toVector.sortBy(_._1).map(_._2) - val ids = for { - fromDefaultTable <- readPersistenceIds(afterId, limit, settings.durableStateTableWithSchema) - fromCustomTables <- readFromCustomTables(Vector.empty, customTables) - } yield { - (fromDefaultTable ++ fromCustomTables).sorted - } - - Source.futureSource(ids.map(Source(_))).take(limit).mapMaterializedValue(_ => NotUsed) - } - } - - def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed] = { - val result = readPersistenceIds(afterId, limit, table) - - Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) - } - - private def readPersistenceIds( - afterId: Option[String], - limit: Long, - table: String): Future[immutable.IndexedSeq[String]] = { - val result = r2dbcExecutor.select(s"select persistenceIds")( - connection => - afterId match { - case Some(after) => - connection - .createStatement(allPersistenceIdsAfterSql(table)) - .bind("@persistenceId", after) - .bind("@limit", limit) - case None => - connection - .createStatement(allPersistenceIdsSql(table)) - .bind("@limit", limit) - }, - row => row.get("persistence_id", classOf[String])) - - if (log.isDebugEnabled) - result.foreach(rows => log.debug("Read [{}] persistence ids", rows.size)) - result - } - - def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = { - val table = settings.getDurableStateTableWithSchema(entityType) - val likeStmtPostfix = PersistenceId.DefaultSeparator + "%" - val result = r2dbcExecutor.select(s"select persistenceIds by entity type")( - connection => - afterId match { - case Some(afterPersistenceId) => - connection - .createStatement(persistenceIdsForEntityTypeAfterSql(table)) - .bind("@persistenceIdLike", entityType + likeStmtPostfix) - .bind("@persistenceId", afterPersistenceId) - .bind("@limit", limit) - case None => - connection - .createStatement(persistenceIdsForEntityTypeSql(table)) - .bind("@persistenceIdLike", entityType + likeStmtPostfix) - .bind("@limit", limit) - }, - row => row.get("persistence_id", classOf[String])) - - if (log.isDebugEnabled) - result.foreach(rows => log.debug("Read [{}] persistence ids by entity type [{}]", rows.size, entityType)) - - Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) - } - - /** - * Is this correct? - */ - override def countBucketsMayChange: Boolean = true - - override def countBuckets( - entityType: String, - minSlice: Int, - maxSlice: Int, - fromTimestamp: Instant, - limit: Int): Future[Seq[Bucket]] = { +private[r2dbc] class SqlServerDurableStateDao( + settings: R2dbcSettings, + connectionFactory: ConnectionFactory, + dialect: Dialect)(implicit ec: ExecutionContext, system: ActorSystem[_]) + extends PostgresDurableStateDao(settings, connectionFactory, dialect) { - val toTimestamp = { - val nowTimestamp = nowInstant() - if (fromTimestamp == Instant.EPOCH) - nowTimestamp - else { - // max buckets, just to have some upper bound - val t = fromTimestamp.plusSeconds(Buckets.BucketDurationSeconds * limit + Buckets.BucketDurationSeconds) - if (t.isAfter(nowTimestamp)) nowTimestamp else t - } - } + override def log: Logger = SqlServerDurableStateDao.log - val result = r2dbcExecutor.select(s"select bucket counts [$entityType $minSlice - $maxSlice]")( - connection => - connection - .createStatement(selectBucketsSql(entityType, minSlice, maxSlice)) - .bind("@entityType", entityType) - .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) - .bind("@toTimestamp", toDbTimestamp(toTimestamp)) - .bind("@limit", limit), - row => { - val bucketStartEpochSeconds = row.get("bucket", classOf[java.lang.Long]).toLong * 10 - val count = row.get[java.lang.Long]("count", classOf[java.lang.Long]).toLong - Bucket(bucketStartEpochSeconds, count) - }) + protected override implicit val statePayloadCodec: PayloadCodec = settings.durableStatePayloadCodec + protected override implicit val timestampCodec: TimestampCodec = settings.timestampCodec - if (log.isDebugEnabled) - result.foreach(rows => log.debugN("Read [{}] bucket counts from slices [{} - {}]", rows.size, minSlice, maxSlice)) + override val durableStateSql: BaseDurableStateSql = new SqlServerDurableStateSql(settings) - result + override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.now()) - } } diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala index 7b345cee..6f23f231 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala @@ -11,12 +11,17 @@ import akka.dispatch.ExecutionContexts import akka.persistence.Persistence import akka.persistence.r2dbc.R2dbcSettings import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement -import akka.persistence.r2dbc.internal.JournalDao -import akka.persistence.r2dbc.internal.PayloadCodec -import akka.persistence.r2dbc.internal.R2dbcExecutor -import akka.persistence.r2dbc.internal.SerializedEventMetadata +import akka.persistence.r2dbc.internal.{ + InstantFactory, + JournalDao, + PayloadCodec, + R2dbcExecutor, + SerializedEventMetadata +} import akka.persistence.r2dbc.internal.Sql.Interpolation import akka.persistence.r2dbc.internal.postgres.PostgresJournalDao +import akka.persistence.r2dbc.internal.postgres.sql.BaseJournalSql +import akka.persistence.r2dbc.internal.sqlserver.sql.SqlServerJournalSql import akka.persistence.typed.PersistenceId import io.r2dbc.spi.Connection import io.r2dbc.spi.ConnectionFactory @@ -39,18 +44,6 @@ private[r2dbc] object SqlServerJournalDao { val TRUE = 1 - def readMetadata(row: Row): Option[SerializedEventMetadata] = { - row.get("meta_payload", classOf[Array[Byte]]) match { - case null => None - case metaPayload => - Some( - SerializedEventMetadata( - serId = row.get[Integer]("meta_ser_id", classOf[Integer]), - serManifest = row.get("meta_ser_manifest", classOf[String]), - metaPayload)) - } - } - } /** @@ -64,259 +57,11 @@ private[r2dbc] class SqlServerJournalDao(settings: R2dbcSettings, connectionFact system: ActorSystem[_]) extends PostgresJournalDao(settings, connectionFactory) { - private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) - import helper._ - require(settings.useAppTimestamp, "SqlServer requires akka.persistence.r2dbc.use-app-timestamp=on") require(settings.useAppTimestamp, "SqlServer requires akka.persistence.r2dbc.db-timestamp-monotonic-increasing = off") - import JournalDao.SerializedJournalRow override def log: Logger = SqlServerJournalDao.log - private val persistenceExt = Persistence(system) - -// protected val r2dbcExecutor = -// new R2dbcExecutor( -// connectionFactory, -// log, -// settings.logDbCallsExceeding, -// settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) - - //protected val journalTable = settings.journalTableWithSchema - //protected implicit val journalPayloadCodec: PayloadCodec = settings.journalPayloadCodec - - /** - * VALUES (@slice, @entityType, @persistenceId, @seqNr, @writer, @adapterManifest, @eventSerId, @eventSerManifest, - * @eventPayload, - * @tags, - * @metaSerId, - * @metaSerManifest, - * @metaSerPayload, - * @dbTimestamp) - */ - private val insertEventWithParameterTimestampSql = sql""" - INSERT INTO $journalTable - (slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) - OUTPUT inserted.db_timestamp - VALUES (@slice, @entityType, @persistenceId, @seqNr, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @tags, @metaSerId, @metaSerManifest, @metaSerPayload, @dbTimestamp)""" - - private val selectHighestSequenceNrSql = sql""" - SELECT MAX(seq_nr) as max_seq_nr from $journalTable - WHERE persistence_id = @persistenceId AND seq_nr >= @seqNr""" - - private val selectLowestSequenceNrSql = - sql""" - SELECT MIN(seq_nr) as min_seq_nr from $journalTable - WHERE persistence_id = @persistenceId""" - - private val deleteEventsSql = sql""" - DELETE FROM $journalTable - WHERE persistence_id = @persistenceId AND seq_nr >= @from AND seq_nr <= @to""" - - private val insertDeleteMarkerSql = sql""" - INSERT INTO $journalTable(slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, deleted) - VALUES(@slice, @entityType, @persistenceId, @deleteMarkerSeqNr, @now, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @deleted )""" - - /** - * All events must be for the same persistenceId. - * - * The returned timestamp should be the `db_timestamp` column and it is used in published events when that feature is - * enabled. - * - * Note for implementing future database dialects: If a database dialect can't efficiently return the timestamp column - * it can return `JournalDao.EmptyDbTimestamp` when the pub-sub feature is disabled. When enabled it would have to use - * a select (in same transaction). - */ - override def writeEvents(events: Seq[SerializedJournalRow]): Future[Instant] = { - require(events.nonEmpty) - - // it's always the same persistenceId for all events - val persistenceId = events.head.persistenceId - - def bind(stmt: Statement, write: SerializedJournalRow): Statement = { - stmt - .bind(0, write.slice) - .bind(1, write.entityType) - .bind(2, write.persistenceId) - .bind(3, write.seqNr) - .bind(4, write.writerUuid) - .bind(5, "") // FIXME event adapter - .bind(6, write.serId) - .bind(7, write.serManifest) - .bindPayload(8, write.payload.get) - - if (write.tags.isEmpty) - stmt.bindNull(9, classOf[String]) - else - stmt.bind(9, tagsToDb(write.tags)) - - // optional metadata - write.metadata match { - case Some(m) => - stmt - .bind(10, m.serId) - .bind(11, m.serManifest) - .bind(12, m.payload) - case None => - stmt - .bindNull(10, classOf[Integer]) - .bindNull(11, classOf[String]) - .bindNull(12, classOf[Array[Byte]]) - } - stmt.bind(13, toDbTimestamp(write.dbTimestamp)) - } - - val insertSql = insertEventWithParameterTimestampSql - - val totalEvents = events.size - if (totalEvents == 1) { - val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")( - connection => bind(connection.createStatement(insertSql), events.head), - row => fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime]))) - if (log.isDebugEnabled()) - result.foreach { _ => - log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId) - } - result - } else { - val result = r2dbcExecutor.updateInBatchReturning(s"batch insert [$persistenceId], [$totalEvents] events")( - connection => - events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) => - stmt.add() - bind(stmt, write) - }, - row => fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime]))) - if (log.isDebugEnabled()) - result.foreach { _ => - log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId) - } - result.map(_.head)(ExecutionContexts.parasitic) - } - } - - override def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = { - val result = r2dbcExecutor - .select(s"select highest seqNr [$persistenceId]")( - connection => - connection - .createStatement(selectHighestSequenceNrSql) - .bind("@persistenceId", persistenceId) - .bind("@seqNr", fromSequenceNr), - row => { - val seqNr = row.get("max_seq_nr", classOf[java.lang.Long]) - if (seqNr eq null) 0L else seqNr.longValue - }) - .map(r => if (r.isEmpty) 0L else r.head)(ExecutionContexts.parasitic) - - if (log.isDebugEnabled) - result.foreach(seqNr => log.debug("Highest sequence nr for persistenceId [{}]: [{}]", persistenceId, seqNr)) - - result - } - - override def readLowestSequenceNr(persistenceId: String): Future[Long] = { - val result = r2dbcExecutor - .select(s"select lowest seqNr [$persistenceId]")( - connection => - connection - .createStatement(selectLowestSequenceNrSql) - .bind("@persistenceId", persistenceId), - row => { - val seqNr = row.get("min_seq_nr", classOf[java.lang.Long]) - if (seqNr eq null) 0L else seqNr.longValue - }) - .map(r => if (r.isEmpty) 0L else r.head)(ExecutionContexts.parasitic) - - if (log.isDebugEnabled) - result.foreach(seqNr => log.debug("Lowest sequence nr for persistenceId [{}]: [{}]", persistenceId, seqNr)) - - result - } - -// private def highestSeqNrForDelete(persistenceId: String, toSequenceNr: Long): Future[Long] = { -// if (toSequenceNr == Long.MaxValue) -// readHighestSequenceNr(persistenceId, 0L) -// else -// Future.successful(toSequenceNr) -// } - -// private def lowestSequenceNrForDelete(persistenceId: String, toSeqNr: Long, batchSize: Int): Future[Long] = { -// if (toSeqNr <= batchSize) { -// Future.successful(1L) -// } else { -// readLowestSequenceNr(persistenceId) -// } -// } - - override def deleteEventsTo(persistenceId: String, toSequenceNr: Long, resetSequenceNumber: Boolean): Future[Unit] = { - - def insertDeleteMarkerStmt(deleteMarkerSeqNr: Long, connection: Connection): Statement = { - val entityType = PersistenceId.extractEntityType(persistenceId) - val slice = persistenceExt.sliceForPersistenceId(persistenceId) - connection - .createStatement(insertDeleteMarkerSql) - .bind("@slice", slice) - .bind("@entityType", entityType) - .bind("@persistenceId", persistenceId) - .bind("@deleteMarkerSeqNr", deleteMarkerSeqNr) - .bind("@writer", "") - .bind("@adapterManifest", "") - .bind("@eventSerId", 0) - .bind("@eventSerManifest", "") - .bindPayloadOption("@eventPayload", None) - .bind("@deleted", SqlServerJournalDao.TRUE) - .bind("@now", nowLocalDateTime()) - } - - def deleteBatch(from: Long, to: Long, lastBatch: Boolean): Future[Unit] = { - (if (lastBatch && !resetSequenceNumber) { - r2dbcExecutor - .update(s"delete [$persistenceId] and insert marker") { connection => - Vector( - connection - .createStatement(deleteEventsSql) - .bind("@persistenceId", persistenceId) - .bind("@from", from) - .bind("@to", to), - insertDeleteMarkerStmt(to, connection)) - } - .map(_.head) - } else { - r2dbcExecutor - .updateOne(s"delete [$persistenceId]") { connection => - connection - .createStatement(deleteEventsSql) - .bind(0, persistenceId) - .bind(1, from) - .bind(2, to) - } - }).map(deletedRows => - if (log.isDebugEnabled) { - log.debugN( - "Deleted [{}] events for persistenceId [{}], from seq num [{}] to [{}]", - deletedRows, - persistenceId, - from, - to) - })(ExecutionContexts.parasitic) - } - - val batchSize = settings.cleanupSettings.eventsJournalDeleteBatchSize - - def deleteInBatches(from: Long, maxTo: Long): Future[Unit] = { - if (from + batchSize > maxTo) { - deleteBatch(from, maxTo, true) - } else { - val to = from + batchSize - 1 - deleteBatch(from, to, false).flatMap(_ => deleteInBatches(to + 1, maxTo)) - } - } - - for { - toSeqNr <- highestSeqNrForDelete(persistenceId, toSequenceNr) - fromSeqNr <- lowestSequenceNrForDelete(persistenceId, toSeqNr, batchSize) - _ <- deleteInBatches(fromSeqNr, toSeqNr) - } yield () - } + override val journalSql: BaseJournalSql = new SqlServerJournalSql(settings) } diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala index c6059ffc..7ba659b0 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala @@ -13,11 +13,15 @@ import akka.persistence.r2dbc.internal.BySliceQuery.Buckets import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow import akka.persistence.r2dbc.internal.PayloadCodec.RichRow -import akka.persistence.r2dbc.internal.InstantFactory -import akka.persistence.r2dbc.internal.PayloadCodec -import akka.persistence.r2dbc.internal.QueryDao -import akka.persistence.r2dbc.internal.R2dbcExecutor -import akka.persistence.r2dbc.internal.SerializedEventMetadata +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichRow => TimstampRichRow, RichStatement } +import akka.persistence.r2dbc.internal.{ + InstantFactory, + PayloadCodec, + QueryDao, + R2dbcExecutor, + SerializedEventMetadata, + TimestampCodec +} import akka.persistence.r2dbc.internal.Sql.Interpolation import akka.persistence.typed.PersistenceId import akka.stream.scaladsl.Source @@ -74,7 +78,9 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor protected def log: Logger = SqlServerQueryDao.log protected val journalTable = settings.journalTableWithSchema + protected implicit val journalPayloadCodec: PayloadCodec = settings.journalPayloadCodec + protected implicit val timestampCodec: TimestampCodec = settings.timestampCodec protected def eventsBySlicesRangeSql( toDbTimestampParam: Boolean, @@ -86,7 +92,7 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor def toDbTimestampParamCondition = if (toDbTimestampParam) "AND db_timestamp <= @until" else "" - def localNow = toDbTimestamp(nowInstant()) + def localNow = timestampCodec.encode(InstantFactory.now()).asInstanceOf[LocalDateTime] def behindCurrentTimeIntervalCondition = if (behindCurrentTime > Duration.Zero) @@ -165,7 +171,7 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor settings.logDbCallsExceeding, settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) - override def currentDbTimestamp(): Future[Instant] = Future.successful(nowInstant()) + override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.now()) override def rowsBySlices( entityType: String, @@ -186,8 +192,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor minSlice, maxSlice)) .bind("@entityType", entityType) - .bind("@from", toDbTimestamp(fromTimestamp)) - toTimestamp.foreach(t => stmt.bind("@until", toDbTimestamp(t))) + .bindTimestamp("@from", fromTimestamp) + toTimestamp.foreach(t => stmt.bindTimestamp("@until", t)) stmt.bind("@limit", settings.querySettings.bufferSize) }, row => @@ -197,8 +203,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor entityType, persistenceId = row.get("persistence_id", classOf[String]), seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + dbTimestamp = row.getTimestamp(), + readDbTimestamp = row.getTimestamp("read_db_timestamp"), payload = None, // lazy loaded for backtracking serId = row.get[Integer]("event_ser_id", classOf[Integer]), serManifest = "", @@ -211,8 +217,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor entityType, persistenceId = row.get("persistence_id", classOf[String]), seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + dbTimestamp = row.getTimestamp(), + readDbTimestamp = row.getTimestamp("read_db_timestamp"), payload = Some(row.getPayload("event_payload")), serId = row.get[Integer]("event_ser_id", classOf[Integer]), serManifest = row.get("event_ser_manifest", classOf[String]), @@ -249,8 +255,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor connection .createStatement(selectBucketsSql(minSlice, maxSlice)) .bind("@entityType", entityType) - .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) - .bind("@toTimestamp", toDbTimestamp(toTimestamp)) + .bindTimestamp("@fromTimestamp", fromTimestamp) + .bindTimestamp("@toTimestamp", toTimestamp) .bind("@limit", limit), row => { val bucketStartEpochSeconds = row.get[java.lang.Long]("bucket", classOf[java.lang.Long]).toLong * 10 @@ -276,7 +282,7 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor .createStatement(selectTimestampOfEventSql) .bind("@persistenceId", persistenceId) .bind("@seqNr", seqNr), - row => fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime]))) + row => row.getTimestamp()) } override def loadEvent( @@ -301,8 +307,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor entityType = row.get("entity_type", classOf[String]), persistenceId, seqNr, - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + dbTimestamp = row.getTimestamp(), + readDbTimestamp = row.getTimestamp("read_db_timestamp"), payload, serId = row.get[Integer]("event_ser_id", classOf[Integer]), serManifest = row.get("event_ser_manifest", classOf[String]), @@ -330,8 +336,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor entityType = row.get("entity_type", classOf[String]), persistenceId = row.get("persistence_id", classOf[String]), seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), - dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), - readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + dbTimestamp = row.getTimestamp(), + readDbTimestamp = row.getTimestamp("read_db_timestamp"), payload = Some(row.getPayload("event_payload")), serId = row.get[Integer]("event_ser_id", classOf[Integer]), serManifest = row.get("event_ser_manifest", classOf[String]), diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala index 14f2c3ad..6a9e5225 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala @@ -14,11 +14,12 @@ import akka.persistence.r2dbc.R2dbcSettings import akka.persistence.r2dbc.internal.BySliceQuery.Buckets import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket import akka.persistence.r2dbc.internal.PayloadCodec.RichRow +import akka.persistence.r2dbc.internal.TimestampCodec.{ + RichRow => TimstampRichRow, + RichStatement => TimestampRichStatement +} import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement -import akka.persistence.r2dbc.internal.InstantFactory -import akka.persistence.r2dbc.internal.PayloadCodec -import akka.persistence.r2dbc.internal.R2dbcExecutor -import akka.persistence.r2dbc.internal.SnapshotDao +import akka.persistence.r2dbc.internal.{ InstantFactory, PayloadCodec, R2dbcExecutor, SnapshotDao, TimestampCodec } import akka.persistence.r2dbc.internal.Sql.Interpolation import akka.persistence.typed.PersistenceId import akka.stream.scaladsl.Source @@ -57,6 +58,7 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac protected val snapshotTable = settings.snapshotsTableWithSchema private implicit val snapshotPayloadCodec: PayloadCodec = settings.snapshotPayloadCodec + private implicit val timestampCodec: TimestampCodec = settings.timestampCodec protected val r2dbcExecutor = new R2dbcExecutor( connectionFactory, log, @@ -192,12 +194,12 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac private def collectSerializedSnapshot(entityType: String, row: Row): SerializedSnapshotRow = { val writeTimestamp = row.get[java.lang.Long]("write_timestamp", classOf[java.lang.Long]) val dbTimestamp = - if (settings.querySettings.startFromSnapshotEnabled) - fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])) match { + if (settings.querySettings.startFromSnapshotEnabled) { + row.getTimestamp() match { case null => Instant.ofEpochMilli(writeTimestamp) case t => t } - else + } else Instant.ofEpochMilli(writeTimestamp) val tags = if (settings.querySettings.startFromSnapshotEnabled) @@ -284,7 +286,7 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac if (settings.querySettings.startFromSnapshotEnabled) { statement - .bind("@dbTimestamp", toDbTimestamp(serializedRow.dbTimestamp)) + .bindTimestamp("@dbTimestamp", serializedRow.dbTimestamp) .bind("@tags", tagsToDb(serializedRow.tags)) } @@ -318,7 +320,7 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac /** * This is used from `BySliceQuery`, i.e. only if settings.querySettings.startFromSnapshotEnabled */ - override def currentDbTimestamp(): Future[Instant] = Future.successful(nowInstant()) + override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.now()) /** * This is used from `BySliceQuery`, i.e. only if settings.querySettings.startFromSnapshotEnabled @@ -336,7 +338,7 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac val stmt = connection .createStatement(snapshotsBySlicesRangeSql(minSlice, maxSlice)) .bind("@entityType", entityType) - .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) + .bindTimestamp("@fromTimestamp", fromTimestamp) .bind("@bufferSize", settings.querySettings.bufferSize) stmt }, @@ -380,8 +382,8 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac connection .createStatement(selectBucketsSql(entityType, minSlice, maxSlice)) .bind("@entityType", entityType) - .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) - .bind("@toTimestamp", toDbTimestamp(toTimestamp)) + .bindTimestamp("@fromTimestamp", fromTimestamp) + .bindTimestamp("@toTimestamp", toTimestamp) .bind("@limit", limit), row => { val bucketStartEpochSeconds = row.get("bucket", classOf[java.lang.Long]).toLong * 10 diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerDurableStateSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerDurableStateSql.scala new file mode 100644 index 00000000..fcad4713 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerDurableStateSql.scala @@ -0,0 +1,350 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.sqlserver.sql + +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow +import akka.persistence.r2dbc.internal.{ PayloadCodec, TimestampCodec } +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichStatement => TimestampRichStatement } +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao.EvaluatedAdditionalColumnBindings +import akka.persistence.r2dbc.internal.postgres.sql.BaseDurableStateSql +import akka.persistence.r2dbc.internal.sqlserver.SqlServerDialectHelper +import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn +import io.r2dbc.spi.Statement + +import java.lang +import java.time.{ Instant, LocalDateTime } +import scala.collection.immutable +import scala.concurrent.duration.{ Duration, FiniteDuration } + +class SqlServerDurableStateSql(settings: R2dbcSettings)(implicit + statePayloadCodec: PayloadCodec, + timestampCodec: TimestampCodec) + extends BaseDurableStateSql { + + private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) + + import helper._ + + def bindUpdateStateSqlForDeleteState( + stmt: Statement, + revision: Long, + persistenceId: String, + previousRevision: Long): _root_.io.r2dbc.spi.Statement = { + stmt + .bind("@revision", revision) + .bind("@stateSerId", 0) + .bind("@stateSerManifest", "") + .bindPayloadOption("@statePayload", None) + .bind("@now", timestampCodec.now()) + .bind("@persistenceId", persistenceId) + .bind("@previousRevision", previousRevision) + } + + def bindDeleteStateForInsertState( + stmt: Statement, + slice: Int, + entityType: String, + persistenceId: String, + revision: Long): Statement = { + + stmt + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", persistenceId) + .bind("@revision", revision) + .bind("@stateSerId", 0) + .bind("@stateSerManifest", "") + .bindPayloadOption("@statePayload", None) + .bindNull("@tags", classOf[String]) + .bind("@now", timestampCodec.now()) + + } + + def binUpdateStateSqlForUpsertState( + stmt: Statement, + getAndIncIndex: () => Int, + state: SerializedStateRow, + additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings], + previousRevision: Long): Statement = { + stmt + .bind("@revision", state.revision) + .bind("@stateSerId", state.serId) + .bind("@stateSerManifest", state.serManifest) + .bindPayloadOption("@statePayload", state.payload) + .bind("@now", timestampCodec.now()) + .bind("@persistenceId", state.persistenceId) + bindTags(stmt, "@tags", state) + bindAdditionalColumns(stmt, additionalBindings) + + if (settings.durableStateAssertSingleWriter) { + stmt.bind("@previousRevision", previousRevision) + } + + stmt + + } + + def bindTags(stmt: Statement, name: String, state: SerializedStateRow): Statement = { + if (state.tags.isEmpty) + stmt.bindNull(name, classOf[String]) + else + stmt.bind(name, tagsToDb(state.tags)) + } + + def bindInsertStateForUpsertState( + stmt: Statement, + getAndIncIndex: () => Int, + slice: Int, + entityType: String, + state: SerializedStateRow, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): Statement = { + + stmt + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", state.persistenceId) + .bind("@revision", state.revision) + .bind("@stateSerId", state.serId) + .bind("@stateSerManifest", state.serManifest) + .bindPayloadOption("@statePayload", state.payload) + .bind("@now", timestampCodec.now()) + bindTags(stmt, "@tags", state) + bindAdditionalColumns(stmt, additionalBindings) + } + + def updateStateSql( + entityType: String, + updateTags: Boolean, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + val revisionCondition = + if (settings.durableStateAssertSingleWriter) " AND revision = @previousRevision" + else "" + + val tags = if (updateTags) ", tags = @tags" else "" + + val additionalParams = additionalUpdateParameters(additionalBindings) + sql""" + UPDATE $stateTable + SET revision = @revision, state_ser_id = @stateSerId, state_ser_manifest = @stateSerManifest, state_payload = @statePayload $tags $additionalParams, db_timestamp = @now + WHERE persistence_id = @persistenceId + $revisionCondition""" + + } + + override def insertStateSql( + entityType: String, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + val additionalCols = additionalInsertColumns(additionalBindings) + val additionalParams = additionalInsertParameters(additionalBindings) + sql""" + INSERT INTO $stateTable + (slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp) + VALUES (@slice, @entityType, @persistenceId, @revision, @stateSerId, @stateSerManifest, @statePayload, @tags$additionalParams, @now)""" + } + + def bindForHardDeleteState(stmt: Statement, persistenceId: String): Statement = { + stmt.bind("@persistenceId", persistenceId) + } + + def hardDeleteStateSql(entityType: String): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql"DELETE from $stateTable WHERE persistence_id = @persistenceId" + } + + private def additionalInsertColumns( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(c, bindBalue) if bindBalue != AdditionalColumn.Skip => + strB.append(", ").append(c.columnName) + case EvaluatedAdditionalColumnBindings(_, _) => + } + strB.toString + } + } + + private def additionalInsertParameters( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, bindValue) if bindValue != AdditionalColumn.Skip => + strB.append(s", @${col.columnName}") + case EvaluatedAdditionalColumnBindings(_, _) => + } + strB.toString + } + } + + private def additionalUpdateParameters( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, binValue) if binValue != AdditionalColumn.Skip => + strB.append(", ").append(col.columnName).append(s" = @${col.columnName}") + case EvaluatedAdditionalColumnBindings(_, _) => + } + strB.toString + } + } + + def bindAdditionalColumns( + stmt: Statement, + additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings]): Statement = { + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindValue(v)) => + stmt.bind(s"@${col.columnName}", v) + case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindNull) => + stmt.bindNull(s"@${col.columnName}", col.fieldClass) + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => + } + stmt + } + + override def bindForSelectStateSql(stmt: Statement, persistenceId: String): Statement = + stmt.bind("@persistenceId", persistenceId) + + override def selectStateSql(entityType: String): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql""" + SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp + FROM $stateTable WHERE persistence_id = @persistenceId""" + } + + override def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + // group by column alias (bucket) needs a sub query + val subQuery = + s""" + | select TOP(@limit) CAST(DATEDIFF(s,'1970-01-01 00:00:00',db_timestamp) AS BIGINT) / 10 AS bucket + | FROM $stateTable + | WHERE entity_type = @entityType + | AND ${sliceCondition(minSlice, maxSlice)} + | AND db_timestamp >= @fromTimestamp AND db_timestamp <= @toTimestamp + |""".stripMargin + sql""" + SELECT bucket, count(*) as count from ($subQuery) as sub + GROUP BY bucket ORDER BY bucket + """ + } + + override def bindSelectBucketsForCoundBuckets( + stmt: Statement, + entityType: String, + fromTimestamp: Instant, + toTimestamp: Instant, + limit: Int): Statement = stmt + .bind("@entityType", entityType) + .bindTimestamp("@fromTimestamp", fromTimestamp) + .bindTimestamp("@toTimestamp", toTimestamp) + .bind("@limit", limit) + + override def persistenceIdsForEntityTypeAfterSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike AND persistence_id > @persistenceId ORDER BY persistence_id" + + override def bindPersistenceIdsForEntityTypeAfter( + stmt: Statement, + entityTypePlusLikeStmtPostfix: String, + afterPersistenceId: String, + limit: Long): Statement = { + stmt + .bind("@persistenceIdLike", entityTypePlusLikeStmtPostfix) + .bind("@persistenceId", afterPersistenceId) + .bind("@limit", limit) + } + + override def persistenceIdsForEntityTypeSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike ORDER BY persistence_id" + + override def bindPersistenceIdsForEntityType( + stmt: Statement, + entityTypePlusLikeStmtPostfix: String, + limit: Long): Statement = + stmt + .bind("@persistenceIdLike", entityTypePlusLikeStmtPostfix) + .bind("@limit", limit) + + override def allPersistenceIdsAfterSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id > @persistenceId ORDER BY persistence_id" + + override def bindForAllPersistenceIdsAfter(stmt: Statement, after: String, limit: Long): Statement = stmt + .bind("@persistenceId", after) + .bind("@limit", limit) + + override def bindForAllPersistenceIdsSql(stmt: Statement, limit: Long): Statement = + stmt.bind("@limit", limit) + + override def allPersistenceIdsSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table ORDER BY persistence_id" + + override def bindForStateBySlicesRangeSql( + stmt: Statement, + entityType: String, + fromTimestamp: Instant, + toTimestamp: Option[Instant], + behindCurrentTime: FiniteDuration): Statement = { + stmt + .bind("@entityType", entityType) + .bindTimestamp("@fromTimestamp", fromTimestamp) + + stmt.bind("@limit", settings.querySettings.bufferSize) + + if (behindCurrentTime > Duration.Zero) { + stmt.bind("@now", timestampCodec.now()) + } + + toTimestamp.foreach(until => stmt.bindTimestamp("@until", until)) + + stmt + } + + private def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = + if (behindCurrentTime > Duration.Zero) + s"AND db_timestamp < DATEADD(ms, -${behindCurrentTime.toMillis}, @now)" + else "" + + override def stateBySlicesRangeSql( + entityType: String, + maxDbTimestampParam: Boolean, + behindCurrentTime: FiniteDuration, + backtracking: Boolean, + minSlice: Int, + maxSlice: Int): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + def maxDbTimestampParamCondition = + if (maxDbTimestampParam) s"AND db_timestamp < @until" else "" + + val behindCurrentTimeIntervalCondition = behindCurrentTimeIntervalConditionFor(behindCurrentTime) + + val selectColumns = + if (backtracking) + "SELECT TOP(@limit) persistence_id, revision, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, state_ser_id " + else + "SELECT TOP(@limit) persistence_id, revision, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, state_ser_id, state_ser_manifest, state_payload " + + sql""" + $selectColumns + FROM $stateTable + WHERE entity_type = @entityType + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= @fromTimestamp $maxDbTimestampParamCondition $behindCurrentTimeIntervalCondition + ORDER BY db_timestamp, revision""" + } + +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerJournalSql.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerJournalSql.scala new file mode 100644 index 00000000..f27886e7 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/sql/SqlServerJournalSql.scala @@ -0,0 +1,161 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.sqlserver.sql + +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.{ JournalDao, PayloadCodec, TimestampCodec } +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.TimestampCodec.{ RichRow, RichStatement => TimestampRichStatement } +import akka.persistence.r2dbc.internal.postgres.sql.BaseJournalSql +import akka.persistence.r2dbc.internal.sqlserver.{ SqlServerDialectHelper, SqlServerJournalDao } +import io.r2dbc.spi.{ Row, Statement } + +import java.time.{ Instant, LocalDateTime } + +class SqlServerJournalSql(journalSettings: R2dbcSettings)(implicit + statePayloadCodec: PayloadCodec, + timestampCodec: TimestampCodec) + extends BaseJournalSql { + + private val helper = new SqlServerDialectHelper(journalSettings.connectionFactorySettings.config) + + private val journalTable = journalSettings.journalTableWithSchema + + private val deleteEventsBySliceBeforeTimestampSql = + sql""" + DELETE FROM $journalTable + WHERE slice = @slice AND entity_type = @entityType AND db_timestamp < @dbTimestamp""" + + override def deleteEventsBySliceBeforeTimestamp( + createStatement: String => Statement, + slice: Int, + entityType: String, + timestamp: Instant): Statement = { + createStatement(deleteEventsBySliceBeforeTimestampSql) + .bind("@slice", slice) + .bind("@entityType", entityType) + .bindTimestamp("@dbTimestamp", timestamp) + } + + private val deleteEventsByPersistenceIdBeforeTimestampSql = + sql""" + DELETE FROM $journalTable + WHERE persistence_id = @persistenceId AND db_timestamp < @timestamp""" + + // can this be inherited by PostgresJournalSql? + def deleteEventsByPersistenceIdBeforeTimestamp( + createStatement: String => Statement, + persistenceId: String, + timestamp: Instant): Statement = { + createStatement(deleteEventsByPersistenceIdBeforeTimestampSql) + .bind("@persistenceId", persistenceId) + .bindTimestamp("@timestamp", timestamp) + } + + override def bindInsertForWriteEvent( + stmt: Statement, + write: JournalDao.SerializedJournalRow, + useTimestampFromDb: Boolean, + previousSeqNr: Long): Statement = { + + stmt + .bind("@slice", write.slice) + .bind("@entityType", write.entityType) + .bind("@persistenceId", write.persistenceId) + .bind("@seqNr", write.seqNr) + .bind("@writer", write.writerUuid) + .bind("@adapterManifest", "") // FIXME event adapter + .bind("@eventSerId", write.serId) + .bind("@eventSerManifest", write.serManifest) + .bindPayload("@eventPayload", write.payload.get) + + if (write.tags.isEmpty) + stmt.bindNull("@tags", classOf[String]) + else + stmt.bind("@tags", helper.tagsToDb(write.tags)) + + // optional metadata + write.metadata match { + case Some(m) => + stmt + .bind("@metaSerId", m.serId) + .bind("@metaSerManifest", m.serManifest) + .bind("@metaSerPayload", m.payload) + case None => + stmt + .bindNull("@metaSerId", classOf[Integer]) + .bindNull("@metaSerManifest", classOf[String]) + .bindNull("@metaSerPayload", classOf[Array[Byte]]) + } + stmt.bindTimestamp("@dbTimestamp", write.dbTimestamp) + } + + /** + * Param `useTimestampFromDb` is ignored in sqlserver + */ + override def insertSql(useTimestampFromDb: Boolean): String = + sql""" + INSERT INTO $journalTable + (slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) + OUTPUT inserted.db_timestamp + VALUES (@slice, @entityType, @persistenceId, @seqNr, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @tags, @metaSerId, @metaSerManifest, @metaSerPayload, @dbTimestamp)""" + + override def parseInsertForWriteEvent(row: Row): Instant = row.getTimestamp() + + override val selectHighestSequenceNrSql = + sql""" + SELECT MAX(seq_nr) as max_seq_nr from $journalTable + WHERE persistence_id = @persistenceId AND seq_nr >= @seqNr""" + + override def bindSelectHighestSequenceNrSql(stmt: Statement, persistenceId: String, fromSequenceNr: Long): Statement = + stmt + .bind("@persistenceId", persistenceId) + .bind("@seqNr", fromSequenceNr) + + override val selectLowestSequenceNrSql = + sql""" + SELECT MIN(seq_nr) as min_seq_nr from $journalTable + WHERE persistence_id = @persistenceId""" + + override def bindSelectLowestSequenceNrSql(stmt: Statement, persistenceId: String): Statement = + stmt + .bind("@persistenceId", persistenceId) + + override val insertDeleteMarkerSql = + sql""" + INSERT INTO $journalTable(slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, deleted) + VALUES(@slice, @entityType, @persistenceId, @deleteMarkerSeqNr, @now, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @deleted )""" + + override def bindForInsertDeleteMarkerSql( + stmt: Statement, + slice: Int, + entityType: String, + persistenceId: String, + deleteMarkerSeqNr: Long): Statement = { + stmt + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", persistenceId) + .bind("@deleteMarkerSeqNr", deleteMarkerSeqNr) + .bind("@writer", "") + .bind("@adapterManifest", "") + .bind("@eventSerId", 0) + .bind("@eventSerManifest", "") + .bindPayloadOption("@eventPayload", None) + .bind("@deleted", SqlServerJournalDao.TRUE) + .bind("@now", timestampCodec.now()) + } + + override val deleteEventsSql = + sql""" + DELETE FROM $journalTable + WHERE persistence_id = @persistenceId AND seq_nr >= @from AND seq_nr <= @to""" + + override def bindForDeleteEventsSql(stmt: Statement, persistenceId: String, from: Long, to: Long): Statement = stmt + .bind("@persistenceId", persistenceId) + .bind("@from", from) + .bind("@to", to) +}