Skip to content

Commit

Permalink
Merge pull request #1 from sebastian-alfers/refactor-mssql-support-3
Browse files Browse the repository at this point in the history
Separate sql query definition+binding from the actual locic
  • Loading branch information
sebastian-alfers authored Jan 9, 2024
2 parents 7b15676 + c194e4f commit 8a76a98
Show file tree
Hide file tree
Showing 18 changed files with 1,571 additions and 1,466 deletions.
16 changes: 13 additions & 3 deletions core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ package akka.persistence.r2dbc

import akka.annotation.InternalApi
import akka.annotation.InternalStableApi
import akka.persistence.r2dbc.internal.ConnectionFactorySettings
import akka.persistence.r2dbc.internal.PayloadCodec
import akka.persistence.r2dbc.internal.{ ConnectionFactorySettings, PayloadCodec, TimestampCodec }
import akka.util.JavaDurationConverters._
import com.typesafe.config.Config

Expand Down Expand Up @@ -83,6 +82,13 @@ object R2dbcSettings {

val connectionFactorySettings = ConnectionFactorySettings(config.getConfig("connection-factory"))

val timestampCodec: TimestampCodec = {
connectionFactorySettings.dialect.name match {
case "sqlserver" => TimestampCodec.SqlServerCodec
case _ => TimestampCodec.PostgresTimestampCodec
}
}

val querySettings = new QuerySettings(config.getConfig("query"))

val dbTimestampMonotonicIncreasing: Boolean = config.getBoolean("db-timestamp-monotonic-increasing")
Expand All @@ -105,6 +111,7 @@ object R2dbcSettings {
snapshotPayloadCodec,
durableStateTable,
durableStatePayloadCodec,
timestampCodec,
durableStateAssertSingleWriter,
logDbCallsExceeding,
querySettings,
Expand Down Expand Up @@ -139,6 +146,7 @@ final class R2dbcSettings private (
val snapshotPayloadCodec: PayloadCodec,
val durableStateTable: String,
val durableStatePayloadCodec: PayloadCodec,
val timestampCodec: TimestampCodec,
val durableStateAssertSingleWriter: Boolean,
val logDbCallsExceeding: FiniteDuration,
val querySettings: QuerySettings,
Expand All @@ -155,7 +163,7 @@ final class R2dbcSettings private (
val durableStateTableWithSchema: String = schema.map(_ + ".").getOrElse("") + durableStateTable

/**
* One of the supported dialects 'postgres', 'yugabyte' or 'h2'
* One of the supported dialects 'postgres', 'yugabyte', 'sqlserver' or 'h2'
*/
def dialectName: String = _connectionFactorySettings.dialect.name

Expand Down Expand Up @@ -217,6 +225,7 @@ final class R2dbcSettings private (
snapshotPayloadCodec: PayloadCodec = snapshotPayloadCodec,
durableStateTable: String = durableStateTable,
durableStatePayloadCodec: PayloadCodec = durableStatePayloadCodec,
timestampCodec: TimestampCodec = timestampCodec,
durableStateAssertSingleWriter: Boolean = durableStateAssertSingleWriter,
logDbCallsExceeding: FiniteDuration = logDbCallsExceeding,
querySettings: QuerySettings = querySettings,
Expand All @@ -237,6 +246,7 @@ final class R2dbcSettings private (
snapshotPayloadCodec,
durableStateTable,
durableStatePayloadCodec,
timestampCodec,
durableStateAssertSingleWriter,
logDbCallsExceeding,
querySettings,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (C) 2022 - 2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.persistence.r2dbc.internal

import java.nio.charset.StandardCharsets.UTF_8
import akka.annotation.InternalApi
import io.r2dbc.postgresql.codec.Json
import io.r2dbc.spi.Row
import io.r2dbc.spi.Statement

import java.time.{ Instant, LocalDateTime }
import java.util.TimeZone

/**
* INTERNAL API
*/
@InternalApi private[akka] sealed trait TimestampCodec {

def encode(timestamp: Instant): Any
def decode(row: Row, name: String): Instant

protected def instantNow() = InstantFactory.now()

def now[T](): T
}

/**
* INTERNAL API
*/
@InternalApi private[akka] object TimestampCodec {
case object PostgresTimestampCodec extends TimestampCodec {
override def decode(row: Row, name: String): Instant = row.get(name, classOf[Instant])

override def encode(timestamp: Instant): Any = timestamp

override def now[T](): T = instantNow().asInstanceOf[T]
}

case object SqlServerCodec extends TimestampCodec {

// should this come from config?
private val zone = TimeZone.getTimeZone("UTC").toZoneId

override def decode(row: Row, name: String): Instant = {
row
.get(name, classOf[LocalDateTime])
.atZone(zone)
.toInstant
}

override def encode(timestamp: Instant): Any = LocalDateTime.ofInstant(timestamp, zone)

override def now[T](): T = LocalDateTime.ofInstant(instantNow(), zone).asInstanceOf[T]
}


implicit class RichStatement[T](val statement: Statement)(implicit codec: TimestampCodec) extends AnyRef {
def bindTimestamp(name: String, timestamp: Instant): Statement = statement.bind(name, codec.encode(timestamp))
def bindTimestamp(index: Int, timestamp: Instant): Statement = statement.bind(index, codec.encode(timestamp))
}
implicit class RichRow[T](val row: Row)(implicit codec: TimestampCodec) extends AnyRef {
def getTimestamp(rowName: String = "db_timestamp"): Instant = codec.decode(row, rowName)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ package akka.persistence.r2dbc.internal.h2
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration

import io.r2dbc.spi.ConnectionFactory
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import akka.actor.typed.ActorSystem
import akka.annotation.InternalApi
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.Dialect
import akka.persistence.r2dbc.internal.h2.sql.H2DurableStateSql
import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao
import akka.persistence.r2dbc.internal.postgres.sql.BaseDurableStateSql

/**
* INTERNAL API
Expand All @@ -28,11 +28,8 @@ private[r2dbc] final class H2DurableStateDao(
dialect: Dialect)(implicit ec: ExecutionContext, system: ActorSystem[_])
extends PostgresDurableStateDao(settings, connectionFactory, dialect) {

override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao])
override val durableStateSql: BaseDurableStateSql = new H2DurableStateSql(settings)

protected override def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String =
if (behindCurrentTime > Duration.Zero)
s"AND db_timestamp < CURRENT_TIMESTAMP - interval '${behindCurrentTime.toMillis.toDouble / 1000}' second"
else ""
override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2DurableStateDao])

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (C) 2022 - 2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.persistence.r2dbc.internal.h2.sql

import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.DurableStateDao.SerializedStateRow
import akka.persistence.r2dbc.internal.PayloadCodec.RichStatement
import akka.persistence.r2dbc.internal.Sql.Interpolation
import akka.persistence.r2dbc.internal.TimestampCodec.{ RichStatement => TimestampRichStatement }
import akka.persistence.r2dbc.internal.postgres.PostgresDurableStateDao.EvaluatedAdditionalColumnBindings
import akka.persistence.r2dbc.internal.postgres.sql.PostgresDurableStateSql
import akka.persistence.r2dbc.internal.{ PayloadCodec, TimestampCodec }
import akka.persistence.r2dbc.state.scaladsl.AdditionalColumn
import io.r2dbc.spi.Statement

import java.lang
import java.time.Instant
import scala.collection.immutable
import scala.concurrent.duration.{ Duration, FiniteDuration }

class H2DurableStateSql(settings: R2dbcSettings)(implicit
statePayloadCodec: PayloadCodec,
timestampCodec: TimestampCodec)
extends PostgresDurableStateSql(settings) {
protected override def behindCurrentTimeIntervalConditionFor(behindCurrentTime: FiniteDuration): String =
if (behindCurrentTime > Duration.Zero)
s"AND db_timestamp < CURRENT_TIMESTAMP - interval '${behindCurrentTime.toMillis.toDouble / 1000}' second"
else ""

}
Loading

0 comments on commit 8a76a98

Please sign in to comment.