diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 6c5a55e7..b422de6e 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -177,6 +177,40 @@ jobs: cp .jvmopts-ci .jvmopts sbt -Dconfig.resource=application-h2.conf test + test-sqlserver: + name: Run test with SQL Server + runs-on: ubuntu-22.04 + if: github.repository == 'akka/akka-persistence-r2dbc' + steps: + - name: Checkout + uses: actions/checkout@v3.1.0 + with: + fetch-depth: 0 + + - name: Checkout GitHub merge + if: github.event.pull_request + run: |- + git fetch origin pull/${{ github.event.pull_request.number }}/merge:scratch + git checkout scratch + + - name: Cache Coursier cache + uses: coursier/cache-action@v6.4.0 + + - name: Set up JDK 11 + uses: coursier/setup-action@v1.3.0 + with: + jvm: temurin:1.11.0 + + - name: Start DB + run: |- + docker compose -f docker/docker-compose-sqlserver.yml up --wait + docker exec -i sqlserver-db /opt/mssql-tools/bin/sqlcmd -S localhost -U SA -P '' -d master < ddl-scripts/create_tables_sqlserver.sql + + - name: sbt test + run: |- + cp .jvmopts-ci .jvmopts + sbt -Dconfig.resource=application-sqlserver.conf test + test-docs: name: Docs runs-on: ubuntu-22.04 diff --git a/build.sbt b/build.sbt index 35747a71..a8a2c80b 100644 --- a/build.sbt +++ b/build.sbt @@ -139,7 +139,9 @@ lazy val docs = project Preprocess / siteSubdirName := s"api/akka-persistence-r2dbc/${projectInfoVersion.value}", Preprocess / sourceDirectory := (LocalRootProject / ScalaUnidoc / unidoc / target).value, Paradox / siteSubdirName := s"docs/akka-persistence-r2dbc/${projectInfoVersion.value}", - paradoxGroups := Map("Language" -> Seq("Java", "Scala"), "Dialect" -> Seq("Postgres", "Yugabyte", "H2")), + paradoxGroups := Map( + "Language" -> Seq("Java", "Scala"), + "Dialect" -> Seq("SQL Server", "Postgres", "Yugabyte", "H2")), Compile / paradoxProperties ++= Map( "project.url" -> "https://doc.akka.io/docs/akka-persistence-r2dbc/current/", "canonical.base_url" -> "https://doc.akka.io/docs/akka-persistence-r2dbc/current", diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index fb6ec574..d8129390 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -357,6 +357,33 @@ akka.persistence.r2dbc { // #connection-settings-h2 } + # Defaults for SQL Server + sqlserver = ${akka.persistence.r2dbc.default-connection-pool} + sqlserver { + dialect = "sqlserver" + driver = "mssql" + + // #connection-settings-sqlserver + # the connection can be configured with a url, eg: "r2dbc:sqlserver://:1433/" + url = "" + + # The connection options to be used. Ignored if 'url' is non-empty + host = "localhost" + + port = 1433 + database = "master" + user = "SA" + password = "" + + # Maximum time to create a new connection. + connect-timeout = 3 seconds + + # Used to encode tags to and from db. Tags must not contain this separator. + tag-separator = "," + + // #connection-settings-sqlserver + } + # Assign the connection factory for the dialect you want to use, then override specific fields # connection-factory = ${akka.persistence.r2dbc.postgres} # connection-factory { @@ -368,7 +395,7 @@ akka.persistence.r2dbc { # updates of the same persistenceId there might be a performance gain to # set this to `on`. Note that many databases use the system clock and that can # move backwards when the system clock is adjusted. 
- # Ignored for H2 + # Ignored for H2 and sqlserver db-timestamp-monotonic-increasing = off # Enable this to generate timestamps from the Akka client side instead of using database timestamps. diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala index f8efa52d..49e8bfc1 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala @@ -8,6 +8,7 @@ import akka.actor.typed.ActorSystem import akka.annotation.InternalApi import akka.persistence.r2dbc.ConnectionPoolSettings import akka.persistence.r2dbc.internal.h2.H2Dialect +import akka.persistence.r2dbc.internal.sqlserver.SqlServerDialect import akka.persistence.r2dbc.internal.postgres.PostgresDialect import akka.persistence.r2dbc.internal.postgres.YugabyteDialect import akka.util.Helpers.toRootLowerCase @@ -24,12 +25,13 @@ private[r2dbc] object ConnectionFactorySettings { def apply(config: Config): ConnectionFactorySettings = { val dialect: Dialect = toRootLowerCase(config.getString("dialect")) match { - case "yugabyte" => YugabyteDialect: Dialect - case "postgres" => PostgresDialect: Dialect - case "h2" => H2Dialect: Dialect + case "yugabyte" => YugabyteDialect: Dialect + case "postgres" => PostgresDialect: Dialect + case "h2" => H2Dialect: Dialect + case "sqlserver" => SqlServerDialect: Dialect case other => throw new IllegalArgumentException( - s"Unknown dialect [$other]. Supported dialects are [postgres, yugabyte, h2].") + s"Unknown dialect [$other]. Supported dialects are [postgres, yugabyte, h2, sqlserver].") } // pool settings are common to all dialects but defined inline in the connection factory block diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/PayloadCodec.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/PayloadCodec.scala index e4770594..0edd385b 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/PayloadCodec.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/PayloadCodec.scala @@ -42,11 +42,20 @@ import io.r2dbc.spi.Statement def bindPayload(index: Int, payload: Array[Byte]): Statement = statement.bind(index, codec.encode(payload)) + def bindPayload(name: String, payload: Array[Byte]): Statement = + statement.bind(name, codec.encode(payload)) + def bindPayloadOption(index: Int, payloadOption: Option[Array[Byte]]): Statement = payloadOption match { case Some(payload) => bindPayload(index, payload) case None => bindPayload(index, codec.nonePayload) } + + def bindPayloadOption(name: String, payloadOption: Option[Array[Byte]]): Statement = + payloadOption match { + case Some(payload) => bindPayload(name, payload) + case None => bindPayload(name, codec.nonePayload) + } } implicit class RichRow(val row: Row)(implicit codec: PayloadCodec) extends AnyRef { def getPayload(name: String): Array[Byte] = codec.decode(row.get(name, codec.payloadClass)) diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala new file mode 100644 index 00000000..3b6ab6c9 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. 
+ */ + +package akka.persistence.r2dbc.internal.sqlserver + +import akka.actor.typed.ActorSystem +import akka.annotation.InternalApi +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal._ +import akka.util.JavaDurationConverters.JavaDurationOps +import com.typesafe.config.Config +import io.r2dbc.spi.ConnectionFactories +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.ConnectionFactoryOptions + +import java.time.{ Duration => JDuration } +import scala.concurrent.duration.FiniteDuration + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] object SqlServerDialect extends Dialect { + + private[r2dbc] final class SqlServerConnectionFactorySettings(config: Config) { + val urlOption: Option[String] = + Option(config.getString("url")) + .filter(_.trim.nonEmpty) + + val driver: String = config.getString("driver") + val host: String = config.getString("host") + val port: Int = config.getInt("port") + val user: String = config.getString("user") + val password: String = config.getString("password") + val database: String = config.getString("database") + val connectTimeout: FiniteDuration = config.getDuration("connect-timeout").asScala + + } + + override def name: String = "sqlserver" + + override def adaptSettings(settings: R2dbcSettings): R2dbcSettings = { + val res = settings + // app timestamp is db timestamp because sqlserver does not provide a transaction timestamp + .withUseAppTimestamp(true) + // saw flaky tests where the Instant.now was smaller then the db timestamp AFTER the insert + .withDbTimestampMonotonicIncreasing(false) + res + } + + override def createConnectionFactory(config: Config): ConnectionFactory = { + + val settings = new SqlServerConnectionFactorySettings(config) + val builder = + settings.urlOption match { + case Some(url) => + ConnectionFactoryOptions + .builder() + .from(ConnectionFactoryOptions.parse(url)) + case _ => + ConnectionFactoryOptions + .builder() + .option(ConnectionFactoryOptions.DRIVER, settings.driver) + .option(ConnectionFactoryOptions.HOST, settings.host) + .option(ConnectionFactoryOptions.PORT, Integer.valueOf(settings.port)) + .option(ConnectionFactoryOptions.USER, settings.user) + .option(ConnectionFactoryOptions.PASSWORD, settings.password) + .option(ConnectionFactoryOptions.DATABASE, settings.database) + .option(ConnectionFactoryOptions.CONNECT_TIMEOUT, JDuration.ofMillis(settings.connectTimeout.toMillis)) + } + ConnectionFactories.get(builder.build()) + } + + override def createJournalDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + system: ActorSystem[_]): JournalDao = + new SqlServerJournalDao(settings, connectionFactory)(system.executionContext, system) + + override def createQueryDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + system: ActorSystem[_]): QueryDao = + new SqlServerQueryDao(settings, connectionFactory)(system.executionContext, system) + + override def createSnapshotDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + system: ActorSystem[_]): SnapshotDao = + new SqlServerSnapshotDao(settings, connectionFactory)(system.executionContext, system) + + override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + system: ActorSystem[_]): DurableStateDao = + new SqlServerDurableStateDao(settings, connectionFactory)(system.executionContext, system) +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala 
b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala new file mode 100644 index 00000000..9e974bcd --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialectHelper.scala @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.sqlserver + +import akka.annotation.InternalApi +import akka.persistence.r2dbc.internal.InstantFactory +import com.typesafe.config.Config +import io.r2dbc.spi.Row + +import java.time.Instant +import java.time.LocalDateTime +import java.util.TimeZone + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] object SqlServerDialectHelper { + def apply(config: Config) = new SqlServerDialectHelper(config) +} + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] class SqlServerDialectHelper(config: Config) { + + private val tagSeparator = config.getString("tag-separator") + + require(tagSeparator.length == 1, s"Tag separator '$tagSeparator' must be a single character.") + + def tagsToDb(tags: Set[String]): String = { + if (tags.exists(_.contains(tagSeparator))) { + throw new IllegalArgumentException( + s"A tag in [$tags] contains the character '$tagSeparator' which is reserved. Please change `akka.persistence.r2dbc.sqlserver.tag-separator` to a character that is not contained by any of your tags.") + } + tags.mkString(tagSeparator) + } + + def tagsFromDb(row: Row): Set[String] = row.get("tags", classOf[String]) match { + case null => Set.empty[String] + case entries => entries.split(tagSeparator).toSet + } + + private val zone = TimeZone.getTimeZone("UTC").toZoneId + + def nowInstant(): Instant = InstantFactory.now() + + def nowLocalDateTime(): LocalDateTime = LocalDateTime.ofInstant(nowInstant(), zone) + + def toDbTimestamp(timestamp: Instant): LocalDateTime = + LocalDateTime.ofInstant(timestamp, zone) + + def fromDbTimestamp(time: LocalDateTime): Instant = time + .atZone(zone) + .toInstant + +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala new file mode 100644 index 00000000..ef1c91a4 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDurableStateDao.scala @@ -0,0 +1,758 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. 
+ */ + +package akka.persistence.r2dbc.internal.sqlserver + +import akka.Done +import akka.NotUsed +import akka.actor.typed.ActorSystem +import akka.actor.typed.scaladsl.LoggerOps +import akka.annotation.InternalApi +import akka.dispatch.ExecutionContexts +import akka.persistence.Persistence +import akka.persistence.query.DeletedDurableState +import akka.persistence.query.DurableStateChange +import akka.persistence.query.NoOffset +import akka.persistence.query.UpdatedDurableState +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.AdditionalColumnFactory +import akka.persistence.r2dbc.internal.BySliceQuery.Buckets +import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket +import akka.persistence.r2dbc.internal.ChangeHandlerFactory +import akka.persistence.r2dbc.internal.DurableStateDao +import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.PayloadCodec.RichRow +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.R2dbcExecutor +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.r2dbc.session.scaladsl.R2dbcSession +import akka.persistence.r2dbc.state.ChangeHandlerException +import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn +import akka.persistence.r2dbc.state.scaladsl.ChangeHandler +import akka.persistence.typed.PersistenceId +import akka.stream.scaladsl.Source +import io.r2dbc.spi.Connection +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.R2dbcDataIntegrityViolationException +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.lang +import java.time.Instant +import java.time.LocalDateTime +import java.util +import java.util.TimeZone +import scala.collection.immutable +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.Duration +import scala.concurrent.duration.FiniteDuration +import scala.util.control.NonFatal + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] object SqlServerDurableStateDao { + + private val log: Logger = LoggerFactory.getLogger(classOf[SqlServerDurableStateDao]) + + private final case class EvaluatedAdditionalColumnBindings( + additionalColumn: AdditionalColumn[_, _], + binding: AdditionalColumn.Binding[_]) + + val FutureDone: Future[Done] = Future.successful(Done) +} + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] class SqlServerDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + ec: ExecutionContext, + system: ActorSystem[_]) + extends DurableStateDao { + import DurableStateDao._ + import SqlServerDurableStateDao._ + protected def log: Logger = SqlServerDurableStateDao.log + + private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) + import helper._ + + private val persistenceExt = Persistence(system) + protected val r2dbcExecutor = new R2dbcExecutor( + connectionFactory, + log, + settings.logDbCallsExceeding, + settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) + + private implicit val statePayloadCodec: PayloadCodec = settings.durableStatePayloadCodec + + private lazy val additionalColumns: Map[String, immutable.IndexedSeq[AdditionalColumn[Any, Any]]] = { + settings.durableStateAdditionalColumnClasses.map { case (entityType, columnClasses) => + val instances = columnClasses.map(fqcn => AdditionalColumnFactory.create(system, fqcn)) + 
entityType -> instances + } + } + + private lazy val changeHandlers: Map[String, ChangeHandler[Any]] = { + settings.durableStateChangeHandlerClasses.map { case (entityType, fqcn) => + val handler = ChangeHandlerFactory.create(system, fqcn) + entityType -> handler + } + } + + private def selectStateSql(entityType: String): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql""" + SELECT revision, state_ser_id, state_ser_manifest, state_payload, db_timestamp + FROM $stateTable WHERE persistence_id = @persistenceId""" + } + + private def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + // group by column alias (bucket) needs a sub query + val subQuery = + s""" + | select TOP(@limit) CAST(DATEDIFF(s,'1970-01-01 00:00:00',db_timestamp) AS BIGINT) / 10 AS bucket + | FROM $stateTable + | WHERE entity_type = @entityType + | AND ${sliceCondition(minSlice, maxSlice)} + | AND db_timestamp >= @fromTimestamp AND db_timestamp <= @toTimestamp + |""".stripMargin + sql""" + SELECT bucket, count(*) as count from ($subQuery) as sub + GROUP BY bucket ORDER BY bucket + """ + } + + protected def sliceCondition(minSlice: Int, maxSlice: Int): String = + s"slice in (${(minSlice to maxSlice).mkString(",")})" + + private def insertStateSql( + entityType: String, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + val additionalCols = additionalInsertColumns(additionalBindings) + val additionalParams = additionalInsertParameters(additionalBindings) + sql""" + INSERT INTO $stateTable + (slice, entity_type, persistence_id, revision, state_ser_id, state_ser_manifest, state_payload, tags$additionalCols, db_timestamp) + VALUES (@slice, @entityType, @persistenceId, @revision, @stateSerId, @stateSerManifest, @statePayload, @tags$additionalParams, @now)""" + } + + private def additionalInsertColumns( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(c, bindBalue) if bindBalue != AdditionalColumn.Skip => + strB.append(", ").append(c.columnName) + case EvaluatedAdditionalColumnBindings(_, _) => + } + strB.toString + } + } + + private def additionalInsertParameters( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, bindValue) if bindValue != AdditionalColumn.Skip => + strB.append(s", @${col.columnName}") + case EvaluatedAdditionalColumnBindings(_, _) => + } + strB.toString + } + } + + private def updateStateSql( + entityType: String, + updateTags: Boolean, + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + val revisionCondition = + if (settings.durableStateAssertSingleWriter) " AND revision = @previousRevision" + else "" + + val tags = if (updateTags) ", tags = @tags" else "" + + val additionalParams = additionalUpdateParameters(additionalBindings) + sql""" + UPDATE $stateTable + SET revision = @revision, state_ser_id = @stateSerId, state_ser_manifest = @stateSerManifest, 
state_payload = @statePayload $tags $additionalParams, db_timestamp = @now + WHERE persistence_id = @persistenceId + $revisionCondition""" + } + + private def additionalUpdateParameters( + additionalBindings: immutable.IndexedSeq[EvaluatedAdditionalColumnBindings]): String = { + if (additionalBindings.isEmpty) "" + else { + val strB = new lang.StringBuilder() + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, binValue) if binValue != AdditionalColumn.Skip => + strB.append(", ").append(col.columnName).append(s" = @${col.columnName}") + case EvaluatedAdditionalColumnBindings(_, _) => + } + strB.toString + } + } + + private def hardDeleteStateSql(entityType: String): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + sql"DELETE from $stateTable WHERE persistence_id = @persistenceId" + } + + private def allPersistenceIdsSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table ORDER BY persistence_id" + + private def persistenceIdsForEntityTypeSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike ORDER BY persistence_id" + + private def allPersistenceIdsAfterSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id > @persistenceId ORDER BY persistence_id" + + private def persistenceIdsForEntityTypeAfterSql(table: String): String = + sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike AND persistence_id > @persistenceId ORDER BY persistence_id" + + protected def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String = + if (behindCurrentTime > Duration.Zero) + s"AND db_timestamp < DATEADD(ms, -${behindCurrentTime.toMillis}, @now)" + else "" + + protected def stateBySlicesRangeSql( + entityType: String, + maxDbTimestampParam: Boolean, + behindCurrentTime: FiniteDuration, + backtracking: Boolean, + minSlice: Int, + maxSlice: Int): String = { + val stateTable = settings.getDurableStateTableWithSchema(entityType) + + def maxDbTimestampParamCondition = + if (maxDbTimestampParam) s"AND db_timestamp < @until" else "" + + val behindCurrentTimeIntervalCondition = behindCurrentTimeIntervalConditionFor(behindCurrentTime) + + val selectColumns = + if (backtracking) + "SELECT TOP(@limit) persistence_id, revision, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, state_ser_id " + else + "SELECT TOP(@limit) persistence_id, revision, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, state_ser_id, state_ser_manifest, state_payload " + + sql""" + $selectColumns + FROM $stateTable + WHERE entity_type = @entityType + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= @fromTimestamp $maxDbTimestampParamCondition $behindCurrentTimeIntervalCondition + ORDER BY db_timestamp, revision""" + } + + def readState(persistenceId: String): Future[Option[SerializedStateRow]] = { + val entityType = PersistenceId.extractEntityType(persistenceId) + r2dbcExecutor.selectOne(s"select [$persistenceId]")( + connection => + connection + .createStatement(selectStateSql(entityType)) + .bind("@persistenceId", persistenceId), + row => + SerializedStateRow( + persistenceId = persistenceId, + revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = Instant.EPOCH, // not needed here + payload = getPayload(row), + serId = 
row.get[Integer]("state_ser_id", classOf[Integer]), + serManifest = row.get("state_ser_manifest", classOf[String]), + tags = Set.empty // tags not fetched in queries (yet) + )) + } + + private def getPayload(row: Row): Option[Array[Byte]] = { + val serId = row.get("state_ser_id", classOf[Integer]) + val rowPayload = row.getPayload("state_payload") + if (serId == 0 && (rowPayload == null || util.Arrays.equals(statePayloadCodec.nonePayload, rowPayload))) + None // delete marker + else + Option(rowPayload) + } + + def upsertState(state: SerializedStateRow, value: Any): Future[Done] = { + require(state.revision > 0) + + def bindTags(stmt: Statement, name: String): Statement = { + if (state.tags.isEmpty) + stmt.bindNull(name, classOf[String]) + else + stmt.bind(name, tagsToDb(state.tags)) + } + + def bindAdditionalColumns( + stmt: Statement, + additionalBindings: IndexedSeq[EvaluatedAdditionalColumnBindings]): Statement = { + additionalBindings.foreach { + case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindValue(v)) => + stmt.bind(s"@${col.columnName}", v) + case EvaluatedAdditionalColumnBindings(col, AdditionalColumn.BindNull) => + stmt.bindNull(s"@${col.columnName}", col.fieldClass) + case EvaluatedAdditionalColumnBindings(_, AdditionalColumn.Skip) => + } + stmt + } + + def change = + new UpdatedDurableState[Any](state.persistenceId, state.revision, value, NoOffset, EmptyDbTimestamp.toEpochMilli) + + val entityType = PersistenceId.extractEntityType(state.persistenceId) + + val result = { + val additionalBindings = additionalColumns.get(entityType) match { + case None => Vector.empty[EvaluatedAdditionalColumnBindings] + case Some(columns) => + val slice = persistenceExt.sliceForPersistenceId(state.persistenceId) + val upsert = AdditionalColumn.Upsert(state.persistenceId, entityType, slice, state.revision, value) + columns.map(c => EvaluatedAdditionalColumnBindings(c, c.bind(upsert))) + } + + if (state.revision == 1) { + val slice = persistenceExt.sliceForPersistenceId(state.persistenceId) + + def insertStatement(connection: Connection): Statement = { + val stmt = connection + .createStatement(insertStateSql(entityType, additionalBindings)) + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", state.persistenceId) + .bind("@revision", state.revision) + .bind("@stateSerId", state.serId) + .bind("@stateSerManifest", state.serManifest) + .bindPayloadOption("@statePayload", state.payload) + .bind("@now", nowLocalDateTime()) + bindTags(stmt, "@tags") + bindAdditionalColumns(stmt, additionalBindings) + } + + def recoverDataIntegrityViolation[A](f: Future[A]): Future[A] = + f.recoverWith { case _: R2dbcDataIntegrityViolationException => + Future.failed( + new IllegalStateException( + s"Insert failed: durable state for persistence id [${state.persistenceId}] already exists")) + } + + changeHandlers.get(entityType) match { + case None => + recoverDataIntegrityViolation(r2dbcExecutor.updateOne(s"insert [${state.persistenceId}]")(insertStatement)) + case Some(handler) => + r2dbcExecutor.withConnection(s"insert [${state.persistenceId}] with change handler") { connection => + for { + updatedRows <- recoverDataIntegrityViolation(R2dbcExecutor.updateOneInTx(insertStatement(connection))) + _ <- processChange(handler, connection, change) + } yield updatedRows + } + } + } else { + val previousRevision = state.revision - 1 + + def updateStatement(connection: Connection): Statement = { + + val query = updateStateSql(entityType, updateTags = true, additionalBindings) + 
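// bind the named parameters of the UPDATE; @previousRevision is only bound below when durable-state assert-single-writer is enabled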
val stmt = connection + .createStatement(query) + .bind("@revision", state.revision) + .bind("@stateSerId", state.serId) + .bind("@stateSerManifest", state.serManifest) + .bindPayloadOption("@statePayload", state.payload) + .bind("@now", nowLocalDateTime()) + .bind("@persistenceId", state.persistenceId) + bindTags(stmt, "@tags") + bindAdditionalColumns(stmt, additionalBindings) + + if (settings.durableStateAssertSingleWriter) { + stmt.bind("@previousRevision", previousRevision) + } + + stmt + } + + changeHandlers.get(entityType) match { + case None => + r2dbcExecutor.updateOne(s"update [${state.persistenceId}]")(updateStatement) + case Some(handler) => + r2dbcExecutor.withConnection(s"update [${state.persistenceId}] with change handler") { connection => + for { + updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection)) + _ <- if (updatedRows == 1) processChange(handler, connection, change) else FutureDone + } yield updatedRows + } + } + } + } + + result + .map { updatedRows => + if (updatedRows != 1) + throw new IllegalStateException( + s"Update failed: durable state for persistence id [${state.persistenceId}] could not be updated to revision [${state.revision}]") + else { + log.debug( + "Updated durable state for persistenceId [{}] to revision [{}]", + state.persistenceId, + state.revision) + Done + } + } + + } + + private def processChange( + handler: ChangeHandler[Any], + connection: Connection, + change: DurableStateChange[Any]): Future[Done] = { + val session = new R2dbcSession(connection) + + def excMessage(cause: Throwable): String = { + val (changeType, revision) = change match { + case upd: UpdatedDurableState[_] => "update" -> upd.revision + case del: DeletedDurableState[_] => "delete" -> del.revision + } + s"Change handler $changeType failed for [${change.persistenceId}] revision [$revision], due to ${cause.getMessage}" + } + + try handler.process(session, change).recoverWith { case cause => + Future.failed[Done](new ChangeHandlerException(excMessage(cause), cause)) + } catch { + case NonFatal(cause) => throw new ChangeHandlerException(excMessage(cause), cause) + } + } + + def deleteState(persistenceId: String, revision: Long): Future[Done] = { + if (revision == 0) { + hardDeleteState(persistenceId) + } else { + val result = { + val entityType = PersistenceId.extractEntityType(persistenceId) + def change = + new DeletedDurableState[Any](persistenceId, revision, NoOffset, EmptyDbTimestamp.toEpochMilli) + if (revision == 1) { + val slice = persistenceExt.sliceForPersistenceId(persistenceId) + + def insertDeleteMarkerStatement(connection: Connection): Statement = { + connection + .createStatement(insertStateSql(entityType, Vector.empty)) + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", persistenceId) + .bind("@revision", revision) + .bind("@stateSerId", 0) + .bind("@stateSerManifest", "") + .bindPayloadOption("@statePayload", None) + .bindNull("@tags", classOf[String]) + .bind("@now", nowLocalDateTime()) + } + + def recoverDataIntegrityViolation[A](f: Future[A]): Future[A] = + f.recoverWith { case _: R2dbcDataIntegrityViolationException => + Future.failed(new IllegalStateException( + s"Insert delete marker with revision 1 failed: durable state for persistence id [$persistenceId] already exists")) + } + + val changeHandler = changeHandlers.get(entityType) + val changeHandlerHint = changeHandler.map(_ => " with change handler").getOrElse("") + + r2dbcExecutor.withConnection(s"insert delete marker [$persistenceId]$changeHandlerHint") 
{ connection => + for { + updatedRows <- recoverDataIntegrityViolation( + R2dbcExecutor.updateOneInTx(insertDeleteMarkerStatement(connection))) + _ <- changeHandler match { + case None => FutureDone + case Some(handler) => processChange(handler, connection, change) + } + } yield updatedRows + } + + } else { + val previousRevision = revision - 1 + + def updateStatement(connection: Connection): Statement = { + connection + .createStatement( + updateStateSql(entityType, updateTags = false, Vector.empty) + ) // FIXME should the additional columns be cleared (null)? Then they must allow NULL + .bind("@revision", revision) + .bind("@stateSerId", 0) + .bind("@stateSerManifest", "") + .bindPayloadOption("@statePayload", None) + .bind("@now", nowLocalDateTime()) + .bind("@persistenceId", persistenceId) + .bind("@previousRevision", previousRevision) + } + + val changeHandler = changeHandlers.get(entityType) + val changeHandlerHint = changeHandler.map(_ => " with change handler").getOrElse("") + + r2dbcExecutor.withConnection(s"delete [$persistenceId]$changeHandlerHint") { connection => + for { + updatedRows <- R2dbcExecutor.updateOneInTx(updateStatement(connection)) + _ <- changeHandler match { + case None => FutureDone + case Some(handler) => processChange(handler, connection, change) + } + } yield updatedRows + } + } + } + + result.map { updatedRows => + if (updatedRows != 1) + throw new IllegalStateException( + s"Delete failed: durable state for persistence id [$persistenceId] could not be updated to revision [$revision]") + else { + log.debug("Deleted durable state for persistenceId [{}] to revision [{}]", persistenceId, revision) + Done + } + } + + } + } + + private def hardDeleteState(persistenceId: String): Future[Done] = { + val entityType = PersistenceId.extractEntityType(persistenceId) + + val changeHandler = changeHandlers.get(entityType) + val changeHandlerHint = changeHandler.map(_ => " with change handler").getOrElse("") + + val result = + r2dbcExecutor.withConnection(s"hard delete [$persistenceId]$changeHandlerHint") { connection => + for { + updatedRows <- R2dbcExecutor.updateOneInTx( + connection + .createStatement(hardDeleteStateSql(entityType)) + .bind("@persistenceId", persistenceId)) + _ <- changeHandler match { + case None => FutureDone + case Some(handler) => + val change = new DeletedDurableState[Any](persistenceId, 0L, NoOffset, EmptyDbTimestamp.toEpochMilli) + processChange(handler, connection, change) + } + } yield updatedRows + } + + if (log.isDebugEnabled()) + result.foreach(_ => log.debug("Hard deleted durable state for persistenceId [{}]", persistenceId)) + + result.map(_ => Done)(ExecutionContexts.parasitic) + } + + override def currentDbTimestamp(): Future[Instant] = Future.successful(nowInstant()) + + override def rowsBySlices( + entityType: String, + minSlice: Int, + maxSlice: Int, + fromTimestamp: Instant, + toTimestamp: Option[Instant], + behindCurrentTime: FiniteDuration, + backtracking: Boolean): Source[SerializedStateRow, NotUsed] = { + val result = r2dbcExecutor.select( + s"select stateBySlices [${settings.querySettings.bufferSize}, $entityType, $fromTimestamp, $toTimestamp, $minSlice - $maxSlice]")( + connection => { + val query = stateBySlicesRangeSql( + entityType, + maxDbTimestampParam = toTimestamp.isDefined, + behindCurrentTime, + backtracking, + minSlice, + maxSlice) + val stmt = connection + .createStatement(query) + .bind("@entityType", entityType) + .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) + + stmt.bind("@limit", 
settings.querySettings.bufferSize) + + if (behindCurrentTime > Duration.Zero) { + stmt.bind("@now", nowLocalDateTime()) + } + + toTimestamp.foreach(until => stmt.bind("@until", toDbTimestamp(until))) + + stmt + }, + row => + if (backtracking) { + val serId = row.get[Integer]("state_ser_id", classOf[Integer]) + // would have been better with an explicit deleted column as in the journal table, + // but not worth the schema change + val isDeleted = serId == 0 + + SerializedStateRow( + persistenceId = row.get("persistence_id", classOf[String]), + revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + // payload = null => lazy loaded for backtracking (ugly, but not worth changing UpdatedDurableState in Akka) + // payload = None => DeletedDurableState (no lazy loading) + payload = if (isDeleted) None else null, + serId = 0, + serManifest = "", + tags = Set.empty // tags not fetched in queries (yet) + ) + } else + SerializedStateRow( + persistenceId = row.get("persistence_id", classOf[String]), + revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]), + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + payload = getPayload(row), + serId = row.get[Integer]("state_ser_id", classOf[Integer]), + serManifest = row.get("state_ser_manifest", classOf[String]), + tags = Set.empty // tags not fetched in queries (yet) + )) + + if (log.isDebugEnabled) + result.foreach(rows => + log.debugN("Read [{}] durable states from slices [{} - {}]", rows.size, minSlice, maxSlice)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = { + if (settings.durableStateTableByEntityTypeWithSchema.isEmpty) + persistenceIds(afterId, limit, settings.durableStateTableWithSchema) + else { + def readFromCustomTables( + acc: immutable.IndexedSeq[String], + remainingTables: Vector[String]): Future[immutable.IndexedSeq[String]] = { + if (acc.size >= limit) { + Future.successful(acc) + } else if (remainingTables.isEmpty) { + Future.successful(acc) + } else { + readPersistenceIds(afterId, limit, remainingTables.head).flatMap { ids => + readFromCustomTables(acc ++ ids, remainingTables.tail) + } + } + } + + val customTables = settings.durableStateTableByEntityTypeWithSchema.toVector.sortBy(_._1).map(_._2) + val ids = for { + fromDefaultTable <- readPersistenceIds(afterId, limit, settings.durableStateTableWithSchema) + fromCustomTables <- readFromCustomTables(Vector.empty, customTables) + } yield { + (fromDefaultTable ++ fromCustomTables).sorted + } + + Source.futureSource(ids.map(Source(_))).take(limit).mapMaterializedValue(_ => NotUsed) + } + } + + def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed] = { + val result = readPersistenceIds(afterId, limit, table) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + private def readPersistenceIds( + afterId: Option[String], + limit: Long, + table: String): Future[immutable.IndexedSeq[String]] = { + val result = r2dbcExecutor.select(s"select persistenceIds")( + connection => + afterId match { + case Some(after) => + connection + 
.createStatement(allPersistenceIdsAfterSql(table)) + .bind("@persistenceId", after) + .bind("@limit", limit) + case None => + connection + .createStatement(allPersistenceIdsSql(table)) + .bind("@limit", limit) + }, + row => row.get("persistence_id", classOf[String])) + + if (log.isDebugEnabled) + result.foreach(rows => log.debug("Read [{}] persistence ids", rows.size)) + result + } + + def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = { + val table = settings.getDurableStateTableWithSchema(entityType) + val likeStmtPostfix = PersistenceId.DefaultSeparator + "%" + val result = r2dbcExecutor.select(s"select persistenceIds by entity type")( + connection => + afterId match { + case Some(afterPersistenceId) => + connection + .createStatement(persistenceIdsForEntityTypeAfterSql(table)) + .bind("@persistenceIdLike", entityType + likeStmtPostfix) + .bind("@persistenceId", afterPersistenceId) + .bind("@limit", limit) + case None => + connection + .createStatement(persistenceIdsForEntityTypeSql(table)) + .bind("@persistenceIdLike", entityType + likeStmtPostfix) + .bind("@limit", limit) + }, + row => row.get("persistence_id", classOf[String])) + + if (log.isDebugEnabled) + result.foreach(rows => log.debug("Read [{}] persistence ids by entity type [{}]", rows.size, entityType)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + /** + * Is this correct? + */ + override def countBucketsMayChange: Boolean = true + + override def countBuckets( + entityType: String, + minSlice: Int, + maxSlice: Int, + fromTimestamp: Instant, + limit: Int): Future[Seq[Bucket]] = { + + val toTimestamp = { + val nowTimestamp = nowInstant() + if (fromTimestamp == Instant.EPOCH) + nowTimestamp + else { + // max buckets, just to have some upper bound + val t = fromTimestamp.plusSeconds(Buckets.BucketDurationSeconds * limit + Buckets.BucketDurationSeconds) + if (t.isAfter(nowTimestamp)) nowTimestamp else t + } + } + + val result = r2dbcExecutor.select(s"select bucket counts [$entityType $minSlice - $maxSlice]")( + connection => + connection + .createStatement(selectBucketsSql(entityType, minSlice, maxSlice)) + .bind("@entityType", entityType) + .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) + .bind("@toTimestamp", toDbTimestamp(toTimestamp)) + .bind("@limit", limit), + row => { + val bucketStartEpochSeconds = row.get("bucket", classOf[java.lang.Long]).toLong * 10 + val count = row.get[java.lang.Long]("count", classOf[java.lang.Long]).toLong + Bucket(bucketStartEpochSeconds, count) + }) + + if (log.isDebugEnabled) + result.foreach(rows => log.debugN("Read [{}] bucket counts from slices [{} - {}]", rows.size, minSlice, maxSlice)) + + result + + } +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala new file mode 100644 index 00000000..807cc025 --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerJournalDao.scala @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. 
+ */ + +package akka.persistence.r2dbc.internal.sqlserver + +import akka.actor.typed.ActorSystem +import akka.actor.typed.scaladsl.LoggerOps +import akka.annotation.InternalApi +import akka.dispatch.ExecutionContexts +import akka.persistence.Persistence +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.JournalDao +import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.R2dbcExecutor +import akka.persistence.r2dbc.internal.SerializedEventMetadata +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.typed.PersistenceId +import io.r2dbc.spi.Connection +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.time.Instant +import java.time.LocalDateTime +import scala.concurrent.ExecutionContext +import scala.concurrent.Future + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] object SqlServerJournalDao { + private val log: Logger = LoggerFactory.getLogger(classOf[SqlServerJournalDao]) + + val TRUE = 1 + + def readMetadata(row: Row): Option[SerializedEventMetadata] = { + row.get("meta_payload", classOf[Array[Byte]]) match { + case null => None + case metaPayload => + Some( + SerializedEventMetadata( + serId = row.get[Integer]("meta_ser_id", classOf[Integer]), + serManifest = row.get("meta_ser_manifest", classOf[String]), + metaPayload)) + } + } + +} + +/** + * INTERNAL API + * + * Class for doing db interaction outside of an actor to avoid mistakes in future callbacks + */ +@InternalApi +private[r2dbc] class SqlServerJournalDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + ec: ExecutionContext, + system: ActorSystem[_]) + extends JournalDao { + + private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) + import helper._ + + require(settings.useAppTimestamp, "SqlServer requires akka.persistence.r2dbc.use-app-timestamp=on") + require(settings.useAppTimestamp, "SqlServer requires akka.persistence.r2dbc.db-timestamp-monotonic-increasing = off") + + import JournalDao.SerializedJournalRow + protected def log: Logger = SqlServerJournalDao.log + + private val persistenceExt = Persistence(system) + + protected val r2dbcExecutor = + new R2dbcExecutor( + connectionFactory, + log, + settings.logDbCallsExceeding, + settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) + + protected val journalTable = settings.journalTableWithSchema + protected implicit val journalPayloadCodec: PayloadCodec = settings.journalPayloadCodec + + private val insertEventWithParameterTimestampSql = sql""" + INSERT INTO $journalTable + (slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) + OUTPUT inserted.db_timestamp + VALUES (@slice, @entityType, @persistenceId, @seqNr, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @tags, @metaSerId, @metaSerManifest, @metaSerPayload, @dbTimestamp)""" + + private val selectHighestSequenceNrSql = sql""" + SELECT MAX(seq_nr) as max_seq_nr from $journalTable + WHERE persistence_id = @persistenceId AND seq_nr >= @seqNr""" + + private val selectLowestSequenceNrSql = + sql""" + SELECT MIN(seq_nr) as min_seq_nr from $journalTable + WHERE persistence_id = @persistenceId""" + + private 
val deleteEventsSql = sql""" + DELETE FROM $journalTable + WHERE persistence_id = @persistenceId AND seq_nr >= @from AND seq_nr <= @to""" + + private val insertDeleteMarkerSql = sql""" + INSERT INTO $journalTable(slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, deleted) + VALUES(@slice, @entityType, @persistenceId, @deleteMarkerSeqNr, @now, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @deleted )""" + + /** + * All events must be for the same persistenceId. + * + * The returned timestamp should be the `db_timestamp` column and it is used in published events when that feature is + * enabled. + * + * Note for implementing future database dialects: If a database dialect can't efficiently return the timestamp column + * it can return `JournalDao.EmptyDbTimestamp` when the pub-sub feature is disabled. When enabled it would have to use + * a select (in same transaction). + */ + def writeEvents(events: Seq[SerializedJournalRow]): Future[Instant] = { + require(events.nonEmpty) + + // it's always the same persistenceId for all events + val persistenceId = events.head.persistenceId + + def bind(stmt: Statement, write: SerializedJournalRow): Statement = { + stmt + .bind("@slice", write.slice) + .bind("@entityType", write.entityType) + .bind("@persistenceId", write.persistenceId) + .bind("@seqNr", write.seqNr) + .bind("@writer", write.writerUuid) + .bind("@adapterManifest", "") // FIXME event adapter + .bind("@eventSerId", write.serId) + .bind("@eventSerManifest", write.serManifest) + .bindPayload("@eventPayload", write.payload.get) + + if (write.tags.isEmpty) + stmt.bindNull("@tags", classOf[String]) + else + stmt.bind("@tags", tagsToDb(write.tags)) + + // optional metadata + write.metadata match { + case Some(m) => + stmt + .bind("@metaSerId", m.serId) + .bind("@metaSerManifest", m.serManifest) + .bind("@metaSerPayload", m.payload) + case None => + stmt + .bindNull("@metaSerId", classOf[Integer]) + .bindNull("@metaSerManifest", classOf[String]) + .bindNull("@metaSerPayload", classOf[Array[Byte]]) + } + stmt.bind("@dbTimestamp", toDbTimestamp(write.dbTimestamp)) + } + + val insertSql = insertEventWithParameterTimestampSql + + val totalEvents = events.size + if (totalEvents == 1) { + val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")( + connection => bind(connection.createStatement(insertSql), events.head), + row => fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime]))) + if (log.isDebugEnabled()) + result.foreach { _ => + log.debug("Wrote [{}] events for persistenceId [{}]", 1, events.head.persistenceId) + } + result + } else { + val result = r2dbcExecutor.updateInBatchReturning(s"batch insert [$persistenceId], [$totalEvents] events")( + connection => + events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) => + stmt.add() + bind(stmt, write) + }, + row => fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime]))) + if (log.isDebugEnabled()) + result.foreach { _ => + log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, events.head.persistenceId) + } + result.map(_.head)(ExecutionContexts.parasitic) + } + } + + def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = { + val result = r2dbcExecutor + .select(s"select highest seqNr [$persistenceId]")( + connection => + connection + .createStatement(selectHighestSequenceNrSql) + .bind("@persistenceId", persistenceId) + .bind("@seqNr", 
fromSequenceNr), + row => { + val seqNr = row.get("max_seq_nr", classOf[java.lang.Long]) + if (seqNr eq null) 0L else seqNr.longValue + }) + .map(r => if (r.isEmpty) 0L else r.head)(ExecutionContexts.parasitic) + + if (log.isDebugEnabled) + result.foreach(seqNr => log.debug("Highest sequence nr for persistenceId [{}]: [{}]", persistenceId, seqNr)) + + result + } + + def readLowestSequenceNr(persistenceId: String): Future[Long] = { + val result = r2dbcExecutor + .select(s"select lowest seqNr [$persistenceId]")( + connection => + connection + .createStatement(selectLowestSequenceNrSql) + .bind("@persistenceId", persistenceId), + row => { + val seqNr = row.get("min_seq_nr", classOf[java.lang.Long]) + if (seqNr eq null) 0L else seqNr.longValue + }) + .map(r => if (r.isEmpty) 0L else r.head)(ExecutionContexts.parasitic) + + if (log.isDebugEnabled) + result.foreach(seqNr => log.debug("Lowest sequence nr for persistenceId [{}]: [{}]", persistenceId, seqNr)) + + result + } + + private def highestSeqNrForDelete(persistenceId: String, toSequenceNr: Long): Future[Long] = { + if (toSequenceNr == Long.MaxValue) + readHighestSequenceNr(persistenceId, 0L) + else + Future.successful(toSequenceNr) + } + + private def lowestSequenceNrForDelete(persistenceId: String, toSeqNr: Long, batchSize: Int): Future[Long] = { + if (toSeqNr <= batchSize) { + Future.successful(1L) + } else { + readLowestSequenceNr(persistenceId) + } + } + + def deleteEventsTo(persistenceId: String, toSequenceNr: Long, resetSequenceNumber: Boolean): Future[Unit] = { + + def insertDeleteMarkerStmt(deleteMarkerSeqNr: Long, connection: Connection): Statement = { + val entityType = PersistenceId.extractEntityType(persistenceId) + val slice = persistenceExt.sliceForPersistenceId(persistenceId) + connection + .createStatement(insertDeleteMarkerSql) + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", persistenceId) + .bind("@deleteMarkerSeqNr", deleteMarkerSeqNr) + .bind("@writer", "") + .bind("@adapterManifest", "") + .bind("@eventSerId", 0) + .bind("@eventSerManifest", "") + .bindPayloadOption("@eventPayload", None) + .bind("@deleted", SqlServerJournalDao.TRUE) + .bind("@now", nowLocalDateTime()) + } + + def deleteBatch(from: Long, to: Long, lastBatch: Boolean): Future[Unit] = { + (if (lastBatch && !resetSequenceNumber) { + r2dbcExecutor + .update(s"delete [$persistenceId] and insert marker") { connection => + Vector( + connection + .createStatement(deleteEventsSql) + .bind("@persistenceId", persistenceId) + .bind("@from", from) + .bind("@to", to), + insertDeleteMarkerStmt(to, connection)) + } + .map(_.head) + } else { + r2dbcExecutor + .updateOne(s"delete [$persistenceId]") { connection => + connection + .createStatement(deleteEventsSql) + .bind(0, persistenceId) + .bind(1, from) + .bind(2, to) + } + }).map(deletedRows => + if (log.isDebugEnabled) { + log.debugN( + "Deleted [{}] events for persistenceId [{}], from seq num [{}] to [{}]", + deletedRows, + persistenceId, + from, + to) + })(ExecutionContexts.parasitic) + } + + val batchSize = settings.cleanupSettings.eventsJournalDeleteBatchSize + + def deleteInBatches(from: Long, maxTo: Long): Future[Unit] = { + if (from + batchSize > maxTo) { + deleteBatch(from, maxTo, true) + } else { + val to = from + batchSize - 1 + deleteBatch(from, to, false).flatMap(_ => deleteInBatches(to + 1, maxTo)) + } + } + + for { + toSeqNr <- highestSeqNrForDelete(persistenceId, toSequenceNr) + fromSeqNr <- lowestSequenceNrForDelete(persistenceId, toSeqNr, batchSize) + _ <- 
deleteInBatches(fromSeqNr, toSeqNr) + } yield () + } + +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala new file mode 100644 index 00000000..c6059ffc --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerQueryDao.scala @@ -0,0 +1,396 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.sqlserver + +import akka.NotUsed +import akka.actor.typed.ActorSystem +import akka.actor.typed.scaladsl.LoggerOps +import akka.annotation.InternalApi +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.BySliceQuery.Buckets +import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket +import akka.persistence.r2dbc.internal.JournalDao.SerializedJournalRow +import akka.persistence.r2dbc.internal.PayloadCodec.RichRow +import akka.persistence.r2dbc.internal.InstantFactory +import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.QueryDao +import akka.persistence.r2dbc.internal.R2dbcExecutor +import akka.persistence.r2dbc.internal.SerializedEventMetadata +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.typed.PersistenceId +import akka.stream.scaladsl.Source +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.Row +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.time.Instant +import java.time.LocalDateTime +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.Duration +import scala.concurrent.duration.FiniteDuration + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] object SqlServerQueryDao { + private val log: Logger = LoggerFactory.getLogger(classOf[SqlServerQueryDao]) + def setFromDb[T](array: Array[T]): Set[T] = array match { + case null => Set.empty[T] + case entries => entries.toSet + } + + def readMetadata(row: Row): Option[SerializedEventMetadata] = { + row.get("meta_payload", classOf[Array[Byte]]) match { + case null => None + case metaPayload => + Some( + SerializedEventMetadata( + serId = row.get[Integer]("meta_ser_id", classOf[Integer]), + serManifest = row.get("meta_ser_manifest", classOf[String]), + metaPayload)) + } + } +} + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + ec: ExecutionContext, + system: ActorSystem[_]) + extends QueryDao { + import SqlServerQueryDao.readMetadata + + private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) + import helper._ + + private val FALSE = "0" + + protected def log: Logger = SqlServerQueryDao.log + protected val journalTable = settings.journalTableWithSchema + protected implicit val journalPayloadCodec: PayloadCodec = settings.journalPayloadCodec + + protected def eventsBySlicesRangeSql( + toDbTimestampParam: Boolean, + behindCurrentTime: FiniteDuration, + backtracking: Boolean, + minSlice: Int, + maxSlice: Int): String = { + + def toDbTimestampParamCondition = + if (toDbTimestampParam) "AND db_timestamp <= @until" else "" + + def localNow = toDbTimestamp(nowInstant()) + + def behindCurrentTimeIntervalCondition = + if (behindCurrentTime > Duration.Zero) + s"AND db_timestamp < DATEADD(ms, -${behindCurrentTime.toMillis}, CAST('$localNow' as datetime2(6)))" + else "" + + val selectColumns = { + if 
(backtracking) + "SELECT TOP(@limit) slice, persistence_id, seq_nr, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, tags, event_ser_id " + else + "SELECT TOP(@limit) slice, persistence_id, seq_nr, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, tags, event_ser_id, event_ser_manifest, event_payload, meta_ser_id, meta_ser_manifest, meta_payload " + } + + sql""" + $selectColumns + FROM $journalTable + WHERE entity_type = @entityType + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= @from $toDbTimestampParamCondition $behindCurrentTimeIntervalCondition + AND deleted = $FALSE + ORDER BY db_timestamp, seq_nr""" + } + + protected def sliceCondition(minSlice: Int, maxSlice: Int): String = + s"slice in (${(minSlice to maxSlice).mkString(",")})" + + private def selectBucketsSql(minSlice: Int, maxSlice: Int): String = { + sql""" + SELECT TOP(@limit) bucket, count(*) as count from + (select DATEDIFF(s,'1970-01-01 00:00:00', db_timestamp)/10 as bucket + FROM $journalTable + WHERE entity_type = @entityType + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= @fromTimestamp AND db_timestamp <= @toTimestamp + AND deleted = $FALSE) as sub + GROUP BY bucket ORDER BY bucket + """ + } + + private val selectTimestampOfEventSql = sql""" + SELECT db_timestamp FROM $journalTable + WHERE persistence_id = @persistenceId AND seq_nr = @seqNr AND deleted = $FALSE""" + + protected val selectOneEventSql = sql""" + SELECT slice, entity_type, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, event_ser_id, event_ser_manifest, event_payload, meta_ser_id, meta_ser_manifest, meta_payload, tags + FROM $journalTable + WHERE persistence_id = @persistenceId AND seq_nr = @seqNr AND deleted = $FALSE""" + + protected val selectOneEventWithoutPayloadSql = sql""" + SELECT slice, entity_type, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, event_ser_id, event_ser_manifest, meta_ser_id, meta_ser_manifest, meta_payload, tags + FROM $journalTable + WHERE persistence_id = @persistenceId AND seq_nr = @seqNr AND deleted = $FALSE""" + + protected val selectEventsSql = sql""" + SELECT TOP(@limit) slice, entity_type, persistence_id, seq_nr, db_timestamp, SYSUTCDATETIME() AS read_db_timestamp, event_ser_id, event_ser_manifest, event_payload, writer, adapter_manifest, meta_ser_id, meta_ser_manifest, meta_payload, tags + from $journalTable + WHERE persistence_id = @persistenceId AND seq_nr >= @from AND seq_nr <= @to + AND deleted = $FALSE + ORDER BY seq_nr""" + + private val allPersistenceIdsSql = + sql"SELECT DISTINCT TOP(@limit) persistence_id from $journalTable ORDER BY persistence_id" + + private val persistenceIdsForEntityTypeSql = + sql"SELECT DISTINCT TOP(@limit) persistence_id from $journalTable WHERE persistence_id LIKE @entityTypeLike ORDER BY persistence_id" + + private val allPersistenceIdsAfterSql = + sql"SELECT DISTINCT TOP(@limit) persistence_id from $journalTable WHERE persistence_id > @after ORDER BY persistence_id" + + private val persistenceIdsForEntityTypeAfterSql = + sql"SELECT DISTINCT TOP(@limit) persistence_id from $journalTable WHERE persistence_id LIKE @entityTypeLike AND persistence_id > @after ORDER BY persistence_id" + + protected val r2dbcExecutor = new R2dbcExecutor( + connectionFactory, + log, + settings.logDbCallsExceeding, + settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) + + override def currentDbTimestamp(): Future[Instant] = Future.successful(nowInstant()) + + override def rowsBySlices( + entityType: String, + minSlice: Int, + 
maxSlice: Int, + fromTimestamp: Instant, + toTimestamp: Option[Instant], + behindCurrentTime: FiniteDuration, + backtracking: Boolean): Source[SerializedJournalRow, NotUsed] = { + val result = r2dbcExecutor.select(s"select eventsBySlices [$minSlice - $maxSlice]")( + connection => { + val stmt = connection + .createStatement( + eventsBySlicesRangeSql( + toDbTimestampParam = toTimestamp.isDefined, + behindCurrentTime, + backtracking, + minSlice, + maxSlice)) + .bind("@entityType", entityType) + .bind("@from", toDbTimestamp(fromTimestamp)) + toTimestamp.foreach(t => stmt.bind("@until", toDbTimestamp(t))) + stmt.bind("@limit", settings.querySettings.bufferSize) + }, + row => + if (backtracking) + SerializedJournalRow( + slice = row.get[Integer]("slice", classOf[Integer]), + entityType, + persistenceId = row.get("persistence_id", classOf[String]), + seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + payload = None, // lazy loaded for backtracking + serId = row.get[Integer]("event_ser_id", classOf[Integer]), + serManifest = "", + writerUuid = "", // not needed in this query + tags = tagsFromDb(row), + metadata = None) + else + SerializedJournalRow( + slice = row.get[Integer]("slice", classOf[Integer]), + entityType, + persistenceId = row.get("persistence_id", classOf[String]), + seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + payload = Some(row.getPayload("event_payload")), + serId = row.get[Integer]("event_ser_id", classOf[Integer]), + serManifest = row.get("event_ser_manifest", classOf[String]), + writerUuid = "", // not needed in this query + tags = tagsFromDb(row), + metadata = readMetadata(row))) + + if (log.isDebugEnabled) + result.foreach(rows => log.debugN("Read [{}] events from slices [{} - {}]", rows.size, minSlice, maxSlice)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + override def countBuckets( + entityType: String, + minSlice: Int, + maxSlice: Int, + fromTimestamp: Instant, + limit: Int): Future[Seq[Bucket]] = { + + val toTimestamp = { + val now = InstantFactory.now() // not important to use database time + if (fromTimestamp == Instant.EPOCH) + now + else { + // max buckets, just to have some upper bound + val t = fromTimestamp.plusSeconds(Buckets.BucketDurationSeconds * limit + Buckets.BucketDurationSeconds) + if (t.isAfter(now)) now else t + } + } + + val result = r2dbcExecutor.select(s"select bucket counts [$minSlice - $maxSlice]")( + connection => + connection + .createStatement(selectBucketsSql(minSlice, maxSlice)) + .bind("@entityType", entityType) + .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) + .bind("@toTimestamp", toDbTimestamp(toTimestamp)) + .bind("@limit", limit), + row => { + val bucketStartEpochSeconds = row.get[java.lang.Long]("bucket", classOf[java.lang.Long]).toLong * 10 + val count = row.get[java.lang.Long]("count", classOf[java.lang.Long]).toLong + Bucket(bucketStartEpochSeconds, count) + }) + + if (log.isDebugEnabled) + result.foreach(rows => log.debugN("Read [{}] bucket counts from slices [{} - {}]", rows.size, minSlice, maxSlice)) + + result + } + + /** + * Events are append only + */ + override def countBucketsMayChange: Boolean
= false + + override def timestampOfEvent(persistenceId: String, seqNr: Long): Future[Option[Instant]] = { + r2dbcExecutor.selectOne("select timestampOfEvent")( + connection => + connection + .createStatement(selectTimestampOfEventSql) + .bind("@persistenceId", persistenceId) + .bind("@seqNr", seqNr), + row => fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime]))) + } + + override def loadEvent( + persistenceId: String, + seqNr: Long, + includePayload: Boolean): Future[Option[SerializedJournalRow]] = + r2dbcExecutor.selectOne(s"select one event ($persistenceId, $seqNr, $includePayload)")( + connection => { + val selectSql = if (includePayload) selectOneEventSql else selectOneEventWithoutPayloadSql + connection + .createStatement(selectSql) + .bind("@persistenceId", persistenceId) + .bind("@seqNr", seqNr) + }, + row => { + val payload = + if (includePayload) + Some(row.getPayload("event_payload")) + else None + SerializedJournalRow( + slice = row.get[Integer]("slice", classOf[Integer]), + entityType = row.get("entity_type", classOf[String]), + persistenceId, + seqNr, + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + payload, + serId = row.get[Integer]("event_ser_id", classOf[Integer]), + serManifest = row.get("event_ser_manifest", classOf[String]), + writerUuid = "", // not needed in this query + tags = tagsFromDb(row), + metadata = readMetadata(row)) + }) + + override def eventsByPersistenceId( + persistenceId: String, + fromSequenceNr: Long, + toSequenceNr: Long): Source[SerializedJournalRow, NotUsed] = { + + val result = r2dbcExecutor.select(s"select eventsByPersistenceId [$persistenceId]")( + connection => + connection + .createStatement(selectEventsSql) + .bind("@persistenceId", persistenceId) + .bind("@from", fromSequenceNr) + .bind("@to", toSequenceNr) + .bind("@limit", settings.querySettings.bufferSize), + row => + SerializedJournalRow( + slice = row.get[Integer]("slice", classOf[Integer]), + entityType = row.get("entity_type", classOf[String]), + persistenceId = row.get("persistence_id", classOf[String]), + seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), + dbTimestamp = fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])), + readDbTimestamp = fromDbTimestamp(row.get("read_db_timestamp", classOf[LocalDateTime])), + payload = Some(row.getPayload("event_payload")), + serId = row.get[Integer]("event_ser_id", classOf[Integer]), + serManifest = row.get("event_ser_manifest", classOf[String]), + writerUuid = row.get("writer", classOf[String]), + tags = tagsFromDb(row), + metadata = readMetadata(row))) + + if (log.isDebugEnabled) + result.foreach(rows => log.debug("Read [{}] events for persistenceId [{}]", rows.size, persistenceId)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + override def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed] = { + val likeStmtPostfix = PersistenceId.DefaultSeparator + "%" + + val result = r2dbcExecutor.select(s"select persistenceIds by entity type")( + connection => + afterId match { + case Some(after) => + connection + .createStatement(persistenceIdsForEntityTypeAfterSql) + .bind("@entityTypeLike", entityType + likeStmtPostfix) + .bind("@after", after) + .bind("@limit", limit) + case None => + connection + .createStatement(persistenceIdsForEntityTypeSql) + .bind("@entityTypeLike", entityType +
likeStmtPostfix) + .bind("@limit", limit) + }, + row => row.get("persistence_id", classOf[String])) + + if (log.isDebugEnabled) + result.foreach(rows => log.debug("Read [{}] persistence ids by entity type [{}]", rows.size, entityType)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + override def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed] = { + val result = r2dbcExecutor.select(s"select persistenceIds")( + connection => + afterId match { + case Some(after) => + connection + .createStatement(allPersistenceIdsAfterSql) + .bind("@after", after) + .bind("@limit", limit) + case None => + connection + .createStatement(allPersistenceIdsSql) + .bind("@limit", limit) + }, + row => row.get("persistence_id", classOf[String])) + + if (log.isDebugEnabled) + result.foreach(rows => log.debug("Read [{}] persistence ids", rows.size)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + +} diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala new file mode 100644 index 00000000..14f2c3ad --- /dev/null +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerSnapshotDao.scala @@ -0,0 +1,397 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal.sqlserver + +import akka.NotUsed +import akka.actor.typed.ActorSystem +import akka.actor.typed.scaladsl.LoggerOps +import akka.annotation.InternalApi +import akka.dispatch.ExecutionContexts +import akka.persistence.SnapshotSelectionCriteria +import akka.persistence.r2dbc.R2dbcSettings +import akka.persistence.r2dbc.internal.BySliceQuery.Buckets +import akka.persistence.r2dbc.internal.BySliceQuery.Buckets.Bucket +import akka.persistence.r2dbc.internal.PayloadCodec.RichRow +import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement +import akka.persistence.r2dbc.internal.InstantFactory +import akka.persistence.r2dbc.internal.PayloadCodec +import akka.persistence.r2dbc.internal.R2dbcExecutor +import akka.persistence.r2dbc.internal.SnapshotDao +import akka.persistence.r2dbc.internal.Sql.Interpolation +import akka.persistence.typed.PersistenceId +import akka.stream.scaladsl.Source +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.Row +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.time.Instant +import java.time.LocalDateTime +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration + +/** + * INTERNAL API + */ +private[r2dbc] object SqlServerSnapshotDao { + private val log: Logger = LoggerFactory.getLogger(classOf[SqlServerSnapshotDao]) +} + +/** + * INTERNAL API + */ +@InternalApi +private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit + ec: ExecutionContext, + system: ActorSystem[_]) + extends SnapshotDao { + import SnapshotDao._ + + private val helper = SqlServerDialectHelper(settings.connectionFactorySettings.config) + import helper._ + + protected def log: Logger = SqlServerSnapshotDao.log + + protected val snapshotTable = settings.snapshotsTableWithSchema + private implicit val snapshotPayloadCodec: PayloadCodec = settings.snapshotPayloadCodec + protected val r2dbcExecutor = new R2dbcExecutor( + connectionFactory, + log, + settings.logDbCallsExceeding, + 
settings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system) + + protected def createUpsertSql: String = { + if (settings.querySettings.startFromSnapshotEnabled) + sql""" + UPDATE $snapshotTable SET + seq_nr = @seqNr, + db_timestamp = @dbTimestamp, + write_timestamp = @writeTimestamp, + snapshot = @snapshot, + ser_id = @serId, + tags = @tags, + ser_manifest = @serManifest, + meta_payload = @metaPayload, + meta_ser_id = @metaSerId, + meta_ser_manifest = @metaSerManifest + where persistence_id = @persistenceId + if @@ROWCOUNT = 0 + INSERT INTO $snapshotTable + (slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest, db_timestamp, tags) + VALUES (@slice, @entityType, @persistenceId, @seqNr, @writeTimestamp, @snapshot, @serId, @serManifest, @metaPayload, @metaSerId, @metaSerManifest, @dbTimestamp, @tags) + """ + else + sql""" + UPDATE $snapshotTable SET + seq_nr = @seqNr, + write_timestamp = @writeTimestamp, + snapshot = @snapshot, + ser_id = @serId, + tags = @tags, + ser_manifest = @serManifest, + meta_payload = @metaPayload, + meta_ser_id = @metaSerId, + meta_ser_manifest = @metaSerManifest + where persistence_id = @persistenceId + if @@ROWCOUNT = 0 + INSERT INTO $snapshotTable + (slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest, tags) + VALUES (@slice, @entityType, @persistenceId, @seqNr, @writeTimestamp, @snapshot, @serId, @serManifest, @metaPayload, @metaSerId, @metaSerManifest, @tags) + """ + } + + private val upsertSql = createUpsertSql + + private def selectSql(criteria: SnapshotSelectionCriteria, pid: String): String = { + val maxSeqNrCondition = + if (criteria.maxSequenceNr != Long.MaxValue) " AND seq_nr <= @maxSeqNr" + else "" + + val minSeqNrCondition = + if (criteria.minSequenceNr > 0L) " AND seq_nr >= @minSeqNr" + else "" + + val maxTimestampCondition = + if (criteria.maxTimestamp != Long.MaxValue) " AND write_timestamp <= @maxTimestamp" + else "" + + val minTimestampCondition = + if (criteria.minTimestamp != 0L) " AND write_timestamp >= @minTimestamp" + else "" + + if (settings.querySettings.startFromSnapshotEnabled) + sql""" + SELECT TOP(1) slice, persistence_id, seq_nr, db_timestamp, write_timestamp, snapshot, ser_id, ser_manifest, tags, meta_payload, meta_ser_id, meta_ser_manifest + FROM $snapshotTable + WHERE persistence_id = @persistenceId + $maxSeqNrCondition $minSeqNrCondition $maxTimestampCondition $minTimestampCondition + """ + else + sql""" + SELECT TOP (1) slice, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest + FROM $snapshotTable + WHERE persistence_id = @persistenceId + $maxSeqNrCondition $minSeqNrCondition $maxTimestampCondition $minTimestampCondition + """ + } + + private def deleteSql(criteria: SnapshotSelectionCriteria): String = { + val maxSeqNrCondition = + if (criteria.maxSequenceNr != Long.MaxValue) " AND seq_nr <= @maxSeqNr" + else "" + + val minSeqNrCondition = + if (criteria.minSequenceNr > 0L) " AND seq_nr >= @minSeqNr" + else "" + + val maxTimestampCondition = + if (criteria.maxTimestamp != Long.MaxValue) " AND write_timestamp <= @maxTimestamp" + else "" + + val minTimestampCondition = + if (criteria.minTimestamp != 0L) " AND write_timestamp >= @minTimestamp" + else "" + + sql""" + DELETE FROM $snapshotTable + WHERE persistence_id = @persistenceId + $maxSeqNrCondition $minSeqNrCondition 
$maxTimestampCondition $minTimestampCondition""" + } + + protected def snapshotsBySlicesRangeSql(minSlice: Int, maxSlice: Int): String = + sql""" + SELECT TOP(@bufferSize) slice, persistence_id, seq_nr, db_timestamp, write_timestamp, snapshot, ser_id, ser_manifest, tags, meta_payload, meta_ser_id, meta_ser_manifest + FROM $snapshotTable + WHERE entity_type = @entityType + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= @fromTimestamp + ORDER BY db_timestamp, seq_nr + """ + + private def selectBucketsSql(entityType: String, minSlice: Int, maxSlice: Int): String = { + val subQuery = + s""" + select TOP(@limit) CAST(DATEDIFF(s,'1970-01-01 00:00:00',db_timestamp) AS BIGINT) / 10 AS bucket + FROM $snapshotTable + WHERE entity_type = @entityType + AND ${sliceCondition(minSlice, maxSlice)} + AND db_timestamp >= @fromTimestamp AND db_timestamp <= @toTimestamp + """.stripMargin + sql""" + SELECT bucket, count(*) as count from ($subQuery) as sub + GROUP BY bucket ORDER BY bucket + """ + } + + protected def sliceCondition(minSlice: Int, maxSlice: Int): String = + s"slice in (${(minSlice to maxSlice).mkString(",")})" + + private def collectSerializedSnapshot(entityType: String, row: Row): SerializedSnapshotRow = { + val writeTimestamp = row.get[java.lang.Long]("write_timestamp", classOf[java.lang.Long]) + val dbTimestamp = + if (settings.querySettings.startFromSnapshotEnabled) + fromDbTimestamp(row.get("db_timestamp", classOf[LocalDateTime])) match { + case null => Instant.ofEpochMilli(writeTimestamp) + case t => t + } + else + Instant.ofEpochMilli(writeTimestamp) + val tags = + if (settings.querySettings.startFromSnapshotEnabled) + tagsFromDb(row) + else + Set.empty[String] + + SerializedSnapshotRow( + slice = row.get[Integer]("slice", classOf[Integer]), + entityType, + persistenceId = row.get("persistence_id", classOf[String]), + seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), + dbTimestamp, + writeTimestamp, + snapshot = row.getPayload("snapshot"), + serializerId = row.get[Integer]("ser_id", classOf[Integer]), + serializerManifest = row.get("ser_manifest", classOf[String]), + tags, + metadata = { + val metaSerializerId = row.get("meta_ser_id", classOf[Integer]) + if (metaSerializerId eq null) None + else + Some( + SerializedSnapshotMetadata( + row.get("meta_payload", classOf[Array[Byte]]), + metaSerializerId, + row.get("meta_ser_manifest", classOf[String]))) + }) + } + + override def load( + persistenceId: String, + criteria: SnapshotSelectionCriteria): Future[Option[SerializedSnapshotRow]] = { + val entityType = PersistenceId.extractEntityType(persistenceId) + r2dbcExecutor + .select(s"select snapshot [$persistenceId], criteria: [$criteria]")( + { connection => + val sql = selectSql(criteria, persistenceId) + val statement = connection + .createStatement(sql) + .bind("@persistenceId", persistenceId) + + if (criteria.maxSequenceNr != Long.MaxValue) statement.bind("@maxSeqNr", criteria.maxSequenceNr) + if (criteria.minSequenceNr > 0L) statement.bind("@minSeqNr", criteria.minSequenceNr) + if (criteria.maxTimestamp != Long.MaxValue) statement.bind("@maxTimestamp", criteria.maxTimestamp) + if (criteria.minTimestamp > 0L) statement.bind("@minTimestamp", criteria.minTimestamp) + + statement + }, + collectSerializedSnapshot(entityType, _)) + .map(_.headOption)(ExecutionContexts.parasitic) + + } + + def store(serializedRow: SerializedSnapshotRow): Future[Unit] = { + r2dbcExecutor + .updateOne(s"upsert snapshot [${serializedRow.persistenceId}], sequence number 
[${serializedRow.seqNr}]") { + connection => + val statement = + connection + .createStatement(upsertSql) + .bind("@slice", serializedRow.slice) + .bind("@entityType", serializedRow.entityType) + .bind("@persistenceId", serializedRow.persistenceId) + .bind("@seqNr", serializedRow.seqNr) + .bind("@writeTimestamp", serializedRow.writeTimestamp) + .bindPayload("@snapshot", serializedRow.snapshot) + .bind("@serId", serializedRow.serializerId) + .bind("@serManifest", serializedRow.serializerManifest) + .bind("@tags", tagsToDb(serializedRow.tags)) + + serializedRow.metadata match { + case Some(SerializedSnapshotMetadata(serializedMeta, serializerId, serializerManifest)) => + statement + .bind("@metaPayload", serializedMeta) + .bind("@metaSerId", serializerId) + .bind("@metaSerManifest", serializerManifest) + case None => + statement + .bindNull("@metaPayload", classOf[Array[Byte]]) + .bindNull("@metaSerId", classOf[Integer]) + .bindNull("@metaSerManifest", classOf[String]) + } + + if (settings.querySettings.startFromSnapshotEnabled) { + statement + .bind("@dbTimestamp", toDbTimestamp(serializedRow.dbTimestamp)) + .bind("@tags", tagsToDb(serializedRow.tags)) + } + + statement + } + .map(_ => ())(ExecutionContexts.parasitic) + } + + def delete(persistenceId: String, criteria: SnapshotSelectionCriteria): Future[Unit] = { + r2dbcExecutor.updateOne(s"delete snapshot [$persistenceId], criteria [$criteria]") { connection => + val statement = connection + .createStatement(deleteSql(criteria)) + .bind("@persistenceId", persistenceId) + + if (criteria.maxSequenceNr != Long.MaxValue) { + statement.bind("@maxSeqNr", criteria.maxSequenceNr) + } + if (criteria.minSequenceNr > 0L) { + statement.bind("@minSeqNr", criteria.minSequenceNr) + } + if (criteria.maxTimestamp != Long.MaxValue) { + statement.bind("@maxTimestamp", criteria.maxTimestamp) + } + if (criteria.minTimestamp > 0L) { + statement.bind("@minTimestamp", criteria.minTimestamp) + } + statement + } + }.map(_ => ())(ExecutionContexts.parasitic) + + /** + * This is used from `BySliceQuery`, i.e. only if settings.querySettings.startFromSnapshotEnabled + */ + override def currentDbTimestamp(): Future[Instant] = Future.successful(nowInstant()) + + /** + * This is used from `BySliceQuery`, i.e. only if settings.querySettings.startFromSnapshotEnabled + */ + override def rowsBySlices( + entityType: String, + minSlice: Int, + maxSlice: Int, + fromTimestamp: Instant, + toTimestamp: Option[Instant], + behindCurrentTime: FiniteDuration, + backtracking: Boolean): Source[SerializedSnapshotRow, NotUsed] = { + val result = r2dbcExecutor.select(s"select snapshotsBySlices [$minSlice - $maxSlice]")( + connection => { + val stmt = connection + .createStatement(snapshotsBySlicesRangeSql(minSlice, maxSlice)) + .bind("@entityType", entityType) + .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) + .bind("@bufferSize", settings.querySettings.bufferSize) + stmt + }, + collectSerializedSnapshot(entityType, _)) + + if (log.isDebugEnabled) + result.foreach(rows => log.debugN("Read [{}] snapshots from slices [{} - {}]", rows.size, minSlice, maxSlice)) + + Source.futureSource(result.map(Source(_))).mapMaterializedValue(_ => NotUsed) + } + + /** + * Counts for a bucket may become inaccurate when existing snapshots are updated since the timestamp is changed. This + * is used from `BySliceQuery`, i.e. only if settings.querySettings.startFromSnapshotEnabled + */ + override def countBucketsMayChange: Boolean = true + + /** + * This is used from `BySliceQuery`, i.e. 
only if settings.querySettings.startFromSnapshotEnabled + */ + override def countBuckets( + entityType: String, + minSlice: Int, + maxSlice: Int, + fromTimestamp: Instant, + limit: Int): Future[Seq[Bucket]] = { + + val toTimestamp = { + val now = InstantFactory.now() // not important to use database time + if (fromTimestamp == Instant.EPOCH) + now + else { + // max buckets, just to have some upper bound + val t = fromTimestamp.plusSeconds(Buckets.BucketDurationSeconds * limit + Buckets.BucketDurationSeconds) + if (t.isAfter(now)) now else t + } + } + + val result = r2dbcExecutor.select(s"select bucket counts [$minSlice - $maxSlice]")( + connection => + connection + .createStatement(selectBucketsSql(entityType, minSlice, maxSlice)) + .bind("@entityType", entityType) + .bind("@fromTimestamp", toDbTimestamp(fromTimestamp)) + .bind("@toTimestamp", toDbTimestamp(toTimestamp)) + .bind("@limit", limit), + row => { + val bucketStartEpochSeconds = row.get("bucket", classOf[java.lang.Long]).toLong * 10 + val count = row.get[java.lang.Long]("count", classOf[java.lang.Long]).toLong + Bucket(bucketStartEpochSeconds, count) + }) + + if (log.isDebugEnabled) + result.foreach(rows => log.debugN("Read [{}] bucket counts from slices [{} - {}]", rows.size, minSlice, maxSlice)) + + result + } +} diff --git a/core/src/test/java/akka/persistence/r2dbc/state/JavadslChangeHandlerSqlServer.java b/core/src/test/java/akka/persistence/r2dbc/state/JavadslChangeHandlerSqlServer.java new file mode 100644 index 00000000..8cc206a4 --- /dev/null +++ b/core/src/test/java/akka/persistence/r2dbc/state/JavadslChangeHandlerSqlServer.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.state; + +import akka.Done; +import akka.persistence.query.DurableStateChange; +import akka.persistence.query.UpdatedDurableState; +import akka.persistence.r2dbc.session.javadsl.R2dbcSession; +import akka.persistence.r2dbc.state.javadsl.ChangeHandler; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + +public class JavadslChangeHandlerSqlServer implements ChangeHandler<String> { + @Override + public CompletionStage<Done> process(R2dbcSession session, DurableStateChange<String> change) { + if (change instanceof UpdatedDurableState) { + UpdatedDurableState<String> upd = (UpdatedDurableState<String>) change; + return session + .updateOne( + session + .createStatement("insert into changes_test (pid, rev, the_value) values (@pid, @rev, @theValue)") + .bind("@pid", upd.persistenceId()) + .bind("@rev", upd.revision()) + .bind("@theValue", upd.value())) + .thenApply(n -> Done.getInstance()); + } else { + return CompletableFuture.completedFuture(Done.getInstance()); + } + } +} diff --git a/core/src/test/resources/application-sqlserver.conf b/core/src/test/resources/application-sqlserver.conf new file mode 100644 index 00000000..524cf3da --- /dev/null +++ b/core/src/test/resources/application-sqlserver.conf @@ -0,0 +1 @@ +akka.persistence.r2dbc.connection-factory = ${akka.persistence.r2dbc.sqlserver} \ No newline at end of file diff --git a/core/src/test/resources/logback-test.xml b/core/src/test/resources/logback-test.xml index 106ce86f..9a6b2121 100644 --- a/core/src/test/resources/logback-test.xml +++ b/core/src/test/resources/logback-test.xml @@ -14,6 +14,7 @@ + @@ -23,4 +24,4 @@ - - + \ No newline at end of file diff --git a/core/src/test/scala/akka/persistence/r2dbc/TestConfig.scala b/core/src/test/scala/akka/persistence/r2dbc/TestConfig.scala index 00efdcda..4fee669a 100644 ---
a/core/src/test/scala/akka/persistence/r2dbc/TestConfig.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/TestConfig.scala @@ -27,6 +27,9 @@ object TestConfig { trace-logging = on } """) + case "sqlserver" => + // defaults are fine + ConfigFactory.empty() } // fallback to default here so that connection-factory can be overridden diff --git a/core/src/test/scala/akka/persistence/r2dbc/internal/R2dbcExecutorSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/internal/R2dbcExecutorSpec.scala index 2dfa31a2..5bf1b96b 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/internal/R2dbcExecutorSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/internal/R2dbcExecutorSpec.scala @@ -48,6 +48,7 @@ class R2dbcExecutorSpec case class Row(col: String) // need pg_sleep or similar + // should we add sqlserver here? private def canBeTestedWithDialect: Boolean = r2dbcSettings.connectionFactorySettings.dialect == PostgresDialect || r2dbcSettings.connectionFactorySettings.dialect == YugabyteDialect diff --git a/core/src/test/scala/akka/persistence/r2dbc/internal/SqlServerDialectSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/internal/SqlServerDialectSpec.scala new file mode 100644 index 00000000..864c432b --- /dev/null +++ b/core/src/test/scala/akka/persistence/r2dbc/internal/SqlServerDialectSpec.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2022 - 2023 Lightbend Inc. + */ + +package akka.persistence.r2dbc.internal + +import akka.persistence.r2dbc.internal.sqlserver.SqlServerDialectHelper +import com.typesafe.config.ConfigFactory +import org.scalatest.TestSuite +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatest.Assertions._ + +class SqlServerDialectSpec extends AnyWordSpec with TestSuite with Matchers { + + "Helper" should { + "throw if tag contains separator character" in { + val conf = ConfigFactory.parseString("""{ + | tag-separator = "|" + |} + |""".stripMargin) + val tag = "some|tag" + assertThrows[IllegalArgumentException] { + SqlServerDialectHelper(conf).tagsToDb(Set(tag)) + } + } + + } +} diff --git a/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTagsSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTagsSpec.scala index 06c5a8e6..55d9f09c 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTagsSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTagsSpec.scala @@ -5,7 +5,6 @@ package akka.persistence.r2dbc.journal import scala.concurrent.duration._ - import akka.Done import akka.actor.testkit.typed.scaladsl.LogCapturing import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit @@ -15,6 +14,7 @@ import akka.persistence.r2dbc.TestActors.Persister import akka.persistence.r2dbc.TestConfig import akka.persistence.r2dbc.TestData import akka.persistence.r2dbc.TestDbLifecycle +import akka.persistence.r2dbc.internal.sqlserver.SqlServerDialectHelper import akka.persistence.typed.PersistenceId import org.scalatest.wordspec.AnyWordSpecLike @@ -63,6 +63,8 @@ class PersistTagsSpec case null => Set.empty[String] case tags: Array[_] => tags.toSet.asInstanceOf[Set[String]] } + } else if (settings.dialectName == "sqlserver") { + SqlServerDialectHelper(settings.connectionFactorySettings.config).tagsFromDb(row) } else { row.get("tags", classOf[Array[String]]) match { case null => Set.empty[String] diff --git a/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTimestampSpec.scala 
b/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTimestampSpec.scala index e9bbc17b..10429860 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTimestampSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/journal/PersistTimestampSpec.scala @@ -5,9 +5,9 @@ package akka.persistence.r2dbc.journal import java.time.Instant - +import java.time.LocalDateTime +import java.time.ZoneId import scala.concurrent.duration._ - import akka.Done import akka.actor.testkit.typed.scaladsl.LogCapturing import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit @@ -80,7 +80,7 @@ class PersistTimestampSpec Row( pid = row.get("persistence_id", classOf[String]), seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]), - dbTimestamp = row.get("db_timestamp", classOf[Instant]), + dbTimestamp = row.get("db_timestamp", classOf[LocalDateTime]).atZone(ZoneId.systemDefault()).toInstant, event) }) .futureValue diff --git a/core/src/test/scala/akka/persistence/r2dbc/journal/R2dbcJournalSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/journal/R2dbcJournalSpec.scala index e57454cc..c7b5c1db 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/journal/R2dbcJournalSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/journal/R2dbcJournalSpec.scala @@ -24,6 +24,7 @@ object R2dbcJournalSpec { def testConfig(): Config = { ConfigFactory .parseString(s""" + akka.loglevel=DEBUG # allow java serialization when testing akka.actor.allow-java-serialization = on akka.actor.warn-about-java-serializer-usage = off diff --git a/core/src/test/scala/akka/persistence/r2dbc/query/CurrentPersistenceIdsQuerySpec.scala b/core/src/test/scala/akka/persistence/r2dbc/query/CurrentPersistenceIdsQuerySpec.scala index 1c70ee62..3f630fcf 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/query/CurrentPersistenceIdsQuerySpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/query/CurrentPersistenceIdsQuerySpec.scala @@ -50,7 +50,6 @@ class CurrentPersistenceIdsQuerySpec override protected def beforeAll(): Unit = { super.beforeAll() - val probe = createTestProbe[Done]() pids.foreach { pid => val persister = spawn(TestActors.Persister(pid)) diff --git a/core/src/test/scala/akka/persistence/r2dbc/query/EventsBySliceBacktrackingSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/query/EventsBySliceBacktrackingSpec.scala index 66c9fbe9..049da62e 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/query/EventsBySliceBacktrackingSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/query/EventsBySliceBacktrackingSpec.scala @@ -5,10 +5,9 @@ package akka.persistence.r2dbc.query import java.time.Instant +import java.time.LocalDateTime import java.time.temporal.ChronoUnit - import scala.concurrent.duration._ - import akka.actor.testkit.typed.scaladsl.LogCapturing import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit import akka.actor.typed.ActorSystem @@ -27,6 +26,7 @@ import akka.persistence.r2dbc.TestData import akka.persistence.r2dbc.TestDbLifecycle import akka.persistence.r2dbc.internal.EnvelopeOrigin import akka.persistence.r2dbc.internal.InstantFactory +import akka.persistence.r2dbc.query.EventsBySliceBacktrackingSpec.config import akka.persistence.r2dbc.query.scaladsl.R2dbcReadJournal import akka.persistence.typed.PersistenceId import akka.serialization.SerializationExtension @@ -36,6 +36,8 @@ import com.typesafe.config.ConfigFactory import org.scalatest.wordspec.AnyWordSpecLike import org.slf4j.LoggerFactory +import java.util.TimeZone + object 
EventsBySliceBacktrackingSpec { private val BufferSize = 10 // small buffer for testing @@ -66,22 +68,45 @@ class EventsBySliceBacktrackingSpec // to be able to store events with specific timestamps private def writeEvent(slice: Int, persistenceId: String, seqNr: Long, timestamp: Instant, event: String): Unit = { log.debugN("Write test event [{}] [{}] [{}] at time [{}]", persistenceId, seqNr, event, timestamp) - val insertEventSql = sql""" - INSERT INTO ${settings.journalTableWithSchema} - (slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload) - VALUES (?, ?, ?, ?, ?, '', '', ?, '', ?)""" + val insertEventSql = + if (settings.dialectName == "sqlserver") { + sql""" + INSERT INTO ${settings.journalTableWithSchema} + (slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload) + VALUES (@slice, @entityType, @persistenceId, @seqNr, @dbTimestamp, '', '', @eventSerId, '', @eventPayload)""" + } else { + sql""" + INSERT INTO ${settings.journalTableWithSchema} + (slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload) + VALUES (?, ?, ?, ?, ?, '', '', ?, '', ?)""" + } + val entityType = PersistenceId.extractEntityType(persistenceId) - val result = r2dbcExecutor.updateOne("test writeEvent") { connection => - connection - .createStatement(insertEventSql) - .bind(0, slice) - .bind(1, entityType) - .bind(2, persistenceId) - .bind(3, seqNr) - .bind(4, timestamp) - .bind(5, stringSerializer.identifier) - .bindPayload(6, stringSerializer.toBinary(event)) + val result = if (settings.dialectName == "sqlserver") { + r2dbcExecutor.updateOne("test writeEvent") { connection => + connection + .createStatement(insertEventSql) + .bind("@slice", slice) + .bind("@entityType", entityType) + .bind("@persistenceId", persistenceId) + .bind("@seqNr", seqNr) + .bind("@dbTimestamp", LocalDateTime.ofInstant(timestamp, TimeZone.getTimeZone("UTC").toZoneId)) + .bind("@eventSerId", stringSerializer.identifier) + .bindPayload("@eventPayload", stringSerializer.toBinary(event)) + } + } else { + r2dbcExecutor.updateOne("test writeEvent") { connection => + connection + .createStatement(insertEventSql) + .bind(0, slice) + .bind(1, entityType) + .bind(2, persistenceId) + .bind(3, seqNr) + .bind(4, timestamp) + .bind(5, stringSerializer.identifier) + .bindPayload(6, stringSerializer.toBinary(event)) + } } result.futureValue shouldBe 1 } diff --git a/core/src/test/scala/akka/persistence/r2dbc/state/CurrentPersistenceIdsQuerySpec.scala b/core/src/test/scala/akka/persistence/r2dbc/state/CurrentPersistenceIdsQuerySpec.scala index 5a879756..0d3f91ad 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/state/CurrentPersistenceIdsQuerySpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/state/CurrentPersistenceIdsQuerySpec.scala @@ -63,13 +63,17 @@ class CurrentPersistenceIdsQuerySpec private val customPid1 = nextPid(customEntityType) private val customPid2 = nextPid(customEntityType) + val createTable = if (r2dbcSettings.dialectName == "sqlserver") { + s"IF object_id('$customTable') is null SELECT * into $customTable from durable_state where persistence_id = '';" + } else { + s"create table if not exists $customTable as select * from durable_state where persistence_id = ''" + } + override protected def beforeAll(): Unit = { super.beforeAll() Await.result( - r2dbcExecutor.executeDdl("beforeAll create 
durable_state_test")( - _.createStatement( - s"create table if not exists $customTable as select * from durable_state where persistence_id = ''")), + r2dbcExecutor.executeDdl("beforeAll create durable_state_test")(_.createStatement(createTable)), 20.seconds) Await.result( diff --git a/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreAdditionalColumnSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreAdditionalColumnSpec.scala index f96c492f..46b2dfed 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreAdditionalColumnSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreAdditionalColumnSpec.scala @@ -68,26 +68,37 @@ class DurableStateStoreAdditionalColumnSpec private val customTable = r2dbcSettings.getDurableStateTableWithSchema("CustomEntity") + val (createCustomTable, alterCustomTable) = if (r2dbcSettings.dialectName == "sqlserver") { + val create = + s"IF object_id('$customTable') is null SELECT * into $customTable from durable_state where persistence_id = ''" + val alter = (col: String, colType: String) => { + s"IF COL_LENGTH('$customTable', '$col') IS NULL Alter Table $customTable Add $col $colType" + } + (create, alter) + } else { + val create = s"create table if not exists $customTable as select * from durable_state where persistence_id = ''" + val alter = (col: String, colType: String) => s"alter table $customTable add if not exists $col $colType" + (create, alter) + } + override def typedSystem: ActorSystem[_] = system override def beforeAll(): Unit = { super.beforeAll() Await.result( - r2dbcExecutor.executeDdl("beforeAll create durable_state_test")( - _.createStatement( - s"create table if not exists $customTable as select * from durable_state where persistence_id = ''")), + r2dbcExecutor.executeDdl("beforeAll create durable_state_test")(_.createStatement(createCustomTable)), 20.seconds) Await.result( r2dbcExecutor.executeDdl("beforeAll alter durable_state_test")( - _.createStatement(s"alter table $customTable add if not exists col1 varchar(256)")), + _.createStatement(alterCustomTable("col1", "varchar(256)"))), 20.seconds) Await.result( r2dbcExecutor.executeDdl("beforeAll alter durable_state_test")( - _.createStatement(s"alter table $customTable add if not exists col2 int")), + _.createStatement(alterCustomTable("col2", "int"))), 20.seconds) Await.result( r2dbcExecutor.executeDdl("beforeAll alter durable_state_test")( - _.createStatement(s"alter table $customTable add if not exists col3 int")), + _.createStatement(alterCustomTable("col3", "int"))), 20.seconds) Await.result( r2dbcExecutor.updateOne("beforeAll delete")(_.createStatement(s"delete from $customTable")), diff --git a/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreChangeHandlerSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreChangeHandlerSpec.scala index e70f72e5..f60ce9a2 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreChangeHandlerSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/state/DurableStateStoreChangeHandlerSpec.scala @@ -8,7 +8,6 @@ import scala.concurrent.Await import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.duration._ - import akka.Done import akka.actor.testkit.typed.scaladsl.LogCapturing import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit @@ -21,25 +20,66 @@ import akka.persistence.r2dbc.TestData import akka.persistence.r2dbc.TestDbLifecycle import 
akka.persistence.r2dbc.internal.Sql.Interpolation import akka.persistence.r2dbc.session.scaladsl.R2dbcSession +import akka.persistence.r2dbc.state.DurableStateStoreChangeHandlerSpec.config import akka.persistence.r2dbc.state.scaladsl.ChangeHandler import akka.persistence.r2dbc.state.scaladsl.R2dbcDurableStateStore import akka.persistence.state.DurableStateStoreRegistry import akka.persistence.state.scaladsl.GetObjectResult import com.typesafe.config.Config import com.typesafe.config.ConfigFactory +import io.r2dbc.spi.Statement import org.scalatest.wordspec.AnyWordSpecLike object DurableStateStoreChangeHandlerSpec { + val fallback = TestConfig.config + + val dialect = fallback.getString("akka.persistence.r2dbc.connection-factory.dialect") + val (javaDslcustomEntity, insertStatement, deleteStatement) = if (dialect == "sqlserver") { + val changeHandler = classOf[JavadslChangeHandlerSqlServer].getName + val build = (session: R2dbcSession, upd: UpdatedDurableState[String]) => { + session + .createStatement(sql"insert into changes_test (pid, rev, the_value) values (@pid, @rev, @theValue)") + .bind("@pid", upd.persistenceId) + .bind("@rev", upd.revision) + .bind("@theValue", upd.value) + } + val delete = (session: R2dbcSession, upd: DeletedDurableState[String]) => { + session + .createStatement(sql"insert into changes_test (pid, rev, the_value) values (@pid, @rev, @theValue)") + .bind("@pid", upd.persistenceId) + .bind("@rev", upd.revision) + .bindNull("@theValue", classOf[String]) + } + (changeHandler, build, delete) + } else { + val changeHandler = classOf[JavadslChangeHandler].getName + val build = (session: R2dbcSession, upd: UpdatedDurableState[String]) => { + session + .createStatement(sql"insert into changes_test (pid, rev, the_value) values (?, ?, ?)") + .bind(0, upd.persistenceId) + .bind(1, upd.revision) + .bind(2, upd.value) + } + val delete = (session: R2dbcSession, upd: DeletedDurableState[String]) => { + session + .createStatement(sql"insert into changes_test (pid, rev, the_value) values (?, ?, ?)") + .bind(0, upd.persistenceId) + .bind(1, upd.revision) + .bindNull(2, classOf[String]) + } + (changeHandler, build, delete) + } + val config: Config = ConfigFactory .parseString(s""" akka.persistence.r2dbc.state { change-handler { "CustomEntity" = "${classOf[Handler].getName}" - "JavadslCustomEntity" = "${classOf[JavadslChangeHandler].getName}" + "JavadslCustomEntity" = "$javaDslcustomEntity" } } """) - .withFallback(TestConfig.config) + .withFallback(fallback) class Handler(system: ActorSystem[_]) extends ChangeHandler[String] { private implicit val ec: ExecutionContext = system.executionContext @@ -51,22 +91,12 @@ object DurableStateStoreChangeHandlerSpec { Future.failed(new RuntimeException("BOOM")) else session - .updateOne( - session - .createStatement(sql"insert into changes_test (pid, rev, the_value) values (?, ?, ?)") - .bind(0, upd.persistenceId) - .bind(1, upd.revision) - .bind(2, upd.value)) + .updateOne(insertStatement(session, upd)) .map(_ => Done) case del: DeletedDurableState[String] => session - .updateOne( - session - .createStatement(sql"insert into changes_test (pid, rev, the_value) values (?, ?, ?)") - .bind(0, del.persistenceId) - .bind(1, del.revision) - .bindNull(2, classOf[String])) + .updateOne(deleteStatement(session, del)) .map(_ => Done) } } @@ -81,16 +111,21 @@ class DurableStateStoreChangeHandlerSpec with TestData with LogCapturing { + val dialect = config.getString("akka.persistence.r2dbc.connection-factory.dialect") private val anotherTable = 
"changes_test" + val createTableSql = if (dialect == "sqlserver") { + s"IF object_id('$anotherTable') is null create table $anotherTable (pid varchar(256), rev bigint, the_value varchar(256))" + } else { + s"create table if not exists $anotherTable (pid varchar(256), rev bigint, the_value varchar(256))" + } + override def typedSystem: ActorSystem[_] = system override def beforeAll(): Unit = { super.beforeAll() Await.result( - r2dbcExecutor.executeDdl("beforeAll create durable_state_test")( - _.createStatement( - s"create table if not exists $anotherTable (pid varchar(256), rev bigint, the_value varchar(256))")), + r2dbcExecutor.executeDdl("beforeAll create durable_state_test")(_.createStatement(createTableSql)), 20.seconds) Await.result( r2dbcExecutor.updateOne("beforeAll delete")(_.createStatement(s"delete from $anotherTable")), diff --git a/ddl-scripts/create_tables_sqlserver.sql b/ddl-scripts/create_tables_sqlserver.sql new file mode 100644 index 00000000..746e6f71 --- /dev/null +++ b/ddl-scripts/create_tables_sqlserver.sql @@ -0,0 +1,73 @@ +IF object_id('event_journal') is null + CREATE TABLE event_journal( + slice INT NOT NULL, + entity_type NVARCHAR(255) NOT NULL, + persistence_id NVARCHAR(255) NOT NULL, + seq_nr NUMERIC(10,0) NOT NULL, + db_timestamp datetime2(6) NOT NULL, + event_ser_id INTEGER NOT NULL, + event_ser_manifest NVARCHAR(255) NOT NULL, + event_payload VARBINARY(MAX) NOT NULL, + deleted BIT DEFAULT 0 NOT NULL, + writer NVARCHAR(255) NOT NULL, + adapter_manifest NVARCHAR(255) NOT NULL, + tags NVARCHAR(255), + + meta_ser_id INTEGER, + meta_ser_manifest NVARCHAR(MAX), + meta_payload VARBINARY(MAX), + PRIMARY KEY(persistence_id, seq_nr) + ); + +IF NOT EXISTS(SELECT * FROM sys.indexes WHERE name = 'event_journal_slice_idx' AND object_id = OBJECT_ID('event_journal')) + BEGIN + CREATE INDEX event_journal_slice_idx ON event_journal (slice, entity_type, db_timestamp, seq_nr); + END; + +IF object_id('snapshot') is null + CREATE TABLE snapshot( + slice INT NOT NULL, + entity_type NVARCHAR(255) NOT NULL, + persistence_id NVARCHAR(255) NOT NULL, + seq_nr BIGINT NOT NULL, + db_timestamp datetime2(6), + write_timestamp BIGINT NOT NULL, + ser_id INTEGER NOT NULL, + ser_manifest NVARCHAR(255) NOT NULL, + snapshot VARBINARY(MAX) NOT NULL, + tags NVARCHAR(255), + meta_ser_id INTEGER, + meta_ser_manifest NVARCHAR(255), + meta_payload VARBINARY(MAX), + PRIMARY KEY(persistence_id) + ); + +-- `snapshot_slice_idx` is only needed if the slice based queries are used together with snapshot as starting point +IF NOT EXISTS(SELECT * FROM sys.indexes WHERE name = 'snapshot_slice_idx' AND object_id = OBJECT_ID('snapshot')) + BEGIN + CREATE INDEX snapshot_slice_idx ON snapshot(slice, entity_type, db_timestamp); + END; + +IF object_id('durable_state') is null + CREATE TABLE durable_state ( + slice INT NOT NULL, + entity_type NVARCHAR(255) NOT NULL, + persistence_id NVARCHAR(255) NOT NULL, + revision BIGINT NOT NULL, + db_timestamp datetime2(6) NOT NULL, + + state_ser_id INTEGER NOT NULL, + state_ser_manifest NVARCHAR(255), + state_payload VARBINARY(MAX) NOT NULL, + tags NVARCHAR(255), + + PRIMARY KEY(persistence_id, revision) + ); + +-- `durable_state_slice_idx` is only needed if the slice based queries are used +IF NOT EXISTS(SELECT * FROM sys.indexes WHERE name = 'durable_state_slice_idx' AND object_id = OBJECT_ID('durable_state')) + BEGIN + CREATE INDEX durable_state_slice_idx ON durable_state(slice, entity_type, db_timestamp, revision); + END; + +--DROP TABLE event_journal; diff --git 
a/ddl-scripts/drop_tables_sqlserver.sql b/ddl-scripts/drop_tables_sqlserver.sql new file mode 100644 index 00000000..b50ac77c --- /dev/null +++ b/ddl-scripts/drop_tables_sqlserver.sql @@ -0,0 +1,4 @@ +DROP INDEX event_journal.event_journal_slice_idx; +DROP TABLE IF EXISTS event_journal; +DROP TABLE IF EXISTS snapshot; +DROP TABLE IF EXISTS durable_state; diff --git a/docker/docker-compose-sqlserver.yml b/docker/docker-compose-sqlserver.yml new file mode 100644 index 00000000..76f95fb9 --- /dev/null +++ b/docker/docker-compose-sqlserver.yml @@ -0,0 +1,10 @@ +version: '2.2' +services: + sqlserver: + image: mcr.microsoft.com/mssql/server:2022-latest + container_name: sqlserver-db + environment: + - MSSQL_SA_PASSWORD= + - ACCEPT_EULA=Y + ports: + - 1433:1433 diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 8afeffff..e1040f43 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -26,6 +26,10 @@ object Dependencies { val r2dbcPool = "io.r2dbc" % "r2dbc-pool" % "1.0.1.RELEASE" // ApacheV2 val r2dbcPostgres = "org.postgresql" % "r2dbc-postgresql" % "1.0.3.RELEASE" // ApacheV2 + // we have to stick to this version for now: https://github.com/r2dbc/r2dbc-mssql/issues/276 + // bumping to 1.0.1.RELEASE or later currently requires pool config initial-size=1 and max-size=1 + val r2dbcSqlServer = "io.r2dbc" % "r2dbc-mssql" % "1.0.0.RELEASE" // ApacheV2 + val h2 = "com.h2database" % "h2" % H2Version % Provided // EPL 1.0 val r2dbcH2 = "io.r2dbc" % "r2dbc-h2" % R2dbcH2Version % Provided // ApacheV2 } @@ -56,6 +60,7 @@ object Dependencies { r2dbcPool, r2dbcPostgres, h2, + r2dbcSqlServer, r2dbcH2, TestDeps.akkaPersistenceTck, TestDeps.akkaStreamTestkit,