Skip to content

Commit

Permalink
feat: SQL Server / mssql dialect support
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastian-alfers committed Dec 18, 2023
1 parent 8035859 commit 7f7454e
Show file tree
Hide file tree
Showing 29 changed files with 2,379 additions and 55 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,40 @@ jobs:
cp .jvmopts-ci .jvmopts
sbt -Dconfig.resource=application-h2.conf test
test-sqlserver:
name: Run test with SQL Server
runs-on: ubuntu-22.04
if: github.repository == 'akka/akka-persistence-r2dbc'
steps:
- name: Checkout
uses: actions/[email protected]
with:
fetch-depth: 0

- name: Checkout GitHub merge
if: github.event.pull_request
run: |-
git fetch origin pull/${{ github.event.pull_request.number }}/merge:scratch
git checkout scratch
- name: Cache Coursier cache
uses: coursier/[email protected]

- name: Set up JDK 11
uses: coursier/[email protected]
with:
jvm: temurin:1.11.0

- name: Start DB
run: |-
docker compose -f docker/docker-compose-sqlserver.yml up --wait
docker exec -i sqlserver-db /opt/mssql-tools/bin/sqlcmd -S localhost -U SA -P '<YourStrong@Passw0rd>' -d master < ddl-scripts/create_tables_sqlserver.sql
- name: sbt test
run: |-
cp .jvmopts-ci .jvmopts
sbt -Dconfig.resource=application-sqlserver.conf test
test-docs:
name: Docs
runs-on: ubuntu-22.04
Expand Down
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ lazy val docs = project
Preprocess / siteSubdirName := s"api/akka-persistence-r2dbc/${projectInfoVersion.value}",
Preprocess / sourceDirectory := (LocalRootProject / ScalaUnidoc / unidoc / target).value,
Paradox / siteSubdirName := s"docs/akka-persistence-r2dbc/${projectInfoVersion.value}",
paradoxGroups := Map("Language" -> Seq("Java", "Scala"), "Dialect" -> Seq("Postgres", "Yugabyte", "H2")),
paradoxGroups := Map(
"Language" -> Seq("Java", "Scala"),
"Dialect" -> Seq("SQL Server", "Postgres", "Yugabyte", "H2")),
Compile / paradoxProperties ++= Map(
"project.url" -> "https://doc.akka.io/docs/akka-persistence-r2dbc/current/",
"canonical.base_url" -> "https://doc.akka.io/docs/akka-persistence-r2dbc/current",
Expand Down
29 changes: 28 additions & 1 deletion core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,33 @@ akka.persistence.r2dbc {
// #connection-settings-h2
}

# Defaults for SQL Server
sqlserver = ${akka.persistence.r2dbc.default-connection-pool}
sqlserver {
dialect = "sqlserver"
driver = "mssql"

// #connection-settings-sqlserver
# the connection can be configured with a url, eg: "r2dbc:sqlserver://<host>:1433/<database>"
url = ""

# The connection options to be used. Ignored if 'url' is non-empty
host = "localhost"

port = 1433
database = "master"
user = "SA"
password = "<YourStrong@Passw0rd>"

# Maximum time to create a new connection.
connect-timeout = 3 seconds

# Used to encode tags to and from db. Tags must not contain this separator.
tag-separator = ","

// #connection-settings-sqlserver
}

# Assign the connection factory for the dialect you want to use, then override specific fields
# connection-factory = ${akka.persistence.r2dbc.postgres}
# connection-factory {
Expand All @@ -368,7 +395,7 @@ akka.persistence.r2dbc {
# updates of the same persistenceId there might be a performance gain to
# set this to `on`. Note that many databases use the system clock and that can
# move backwards when the system clock is adjusted.
# Ignored for H2
# Ignored for H2 and sqlserver
db-timestamp-monotonic-increasing = off

# Enable this to generate timestamps from the Akka client side instead of using database timestamps.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import akka.actor.typed.ActorSystem
import akka.annotation.InternalApi
import akka.persistence.r2dbc.ConnectionPoolSettings
import akka.persistence.r2dbc.internal.h2.H2Dialect
import akka.persistence.r2dbc.internal.sqlserver.SqlServerDialect
import akka.persistence.r2dbc.internal.postgres.PostgresDialect
import akka.persistence.r2dbc.internal.postgres.YugabyteDialect
import akka.util.Helpers.toRootLowerCase
Expand All @@ -24,12 +25,13 @@ private[r2dbc] object ConnectionFactorySettings {

def apply(config: Config): ConnectionFactorySettings = {
val dialect: Dialect = toRootLowerCase(config.getString("dialect")) match {
case "yugabyte" => YugabyteDialect: Dialect
case "postgres" => PostgresDialect: Dialect
case "h2" => H2Dialect: Dialect
case "yugabyte" => YugabyteDialect: Dialect
case "postgres" => PostgresDialect: Dialect
case "h2" => H2Dialect: Dialect
case "sqlserver" => SqlServerDialect: Dialect
case other =>
throw new IllegalArgumentException(
s"Unknown dialect [$other]. Supported dialects are [postgres, yugabyte, h2].")
s"Unknown dialect [$other]. Supported dialects are [postgres, yugabyte, h2, sqlserver].")
}

// pool settings are common to all dialects but defined inline in the connection factory block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,20 @@ import io.r2dbc.spi.Statement
def bindPayload(index: Int, payload: Array[Byte]): Statement =
statement.bind(index, codec.encode(payload))

def bindPayload(name: String, payload: Array[Byte]): Statement =
statement.bind(name, codec.encode(payload))

def bindPayloadOption(index: Int, payloadOption: Option[Array[Byte]]): Statement =
payloadOption match {
case Some(payload) => bindPayload(index, payload)
case None => bindPayload(index, codec.nonePayload)
}

def bindPayloadOption(name: String, payloadOption: Option[Array[Byte]]): Statement =
payloadOption match {
case Some(payload) => bindPayload(name, payload)
case None => bindPayload(name, codec.nonePayload)
}
}
implicit class RichRow(val row: Row)(implicit codec: PayloadCodec) extends AnyRef {
def getPayload(name: String): Array[Byte] = codec.decode(row.get(name, codec.payloadClass))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (C) 2022 - 2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.persistence.r2dbc.internal.sqlserver

import akka.actor.typed.ActorSystem
import akka.annotation.InternalApi
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal._
import akka.util.JavaDurationConverters.JavaDurationOps
import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactories
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.ConnectionFactoryOptions

import java.time.{ Duration => JDuration }
import scala.concurrent.duration.FiniteDuration

/**
* INTERNAL API
*/
@InternalApi
private[r2dbc] object SqlServerDialect extends Dialect {

private[r2dbc] final class SqlServerConnectionFactorySettings(config: Config) {
val urlOption: Option[String] =
Option(config.getString("url"))
.filter(_.trim.nonEmpty)

val driver: String = config.getString("driver")
val host: String = config.getString("host")
val port: Int = config.getInt("port")
val user: String = config.getString("user")
val password: String = config.getString("password")
val database: String = config.getString("database")
val connectTimeout: FiniteDuration = config.getDuration("connect-timeout").asScala

}

override def name: String = "sqlserver"

override def adaptSettings(settings: R2dbcSettings): R2dbcSettings = {
val res = settings
// app timestamp is db timestamp because sqlserver does not provide a transaction timestamp
.withUseAppTimestamp(true)
// saw flaky tests where the Instant.now was smaller then the db timestamp AFTER the insert
.withDbTimestampMonotonicIncreasing(false)
res
}

override def createConnectionFactory(config: Config): ConnectionFactory = {

val settings = new SqlServerConnectionFactorySettings(config)
val builder =
settings.urlOption match {
case Some(url) =>
ConnectionFactoryOptions
.builder()
.from(ConnectionFactoryOptions.parse(url))
case _ =>
ConnectionFactoryOptions
.builder()
.option(ConnectionFactoryOptions.DRIVER, settings.driver)
.option(ConnectionFactoryOptions.HOST, settings.host)
.option(ConnectionFactoryOptions.PORT, Integer.valueOf(settings.port))
.option(ConnectionFactoryOptions.USER, settings.user)
.option(ConnectionFactoryOptions.PASSWORD, settings.password)
.option(ConnectionFactoryOptions.DATABASE, settings.database)
.option(ConnectionFactoryOptions.CONNECT_TIMEOUT, JDuration.ofMillis(settings.connectTimeout.toMillis))
}
ConnectionFactories.get(builder.build())
}

override def createJournalDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): JournalDao =
new SqlServerJournalDao(settings, connectionFactory)(system.executionContext, system)

override def createQueryDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): QueryDao =
new SqlServerQueryDao(settings, connectionFactory)(system.executionContext, system)

override def createSnapshotDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): SnapshotDao =
new SqlServerSnapshotDao(settings, connectionFactory)(system.executionContext, system)

override def createDurableStateDao(settings: R2dbcSettings, connectionFactory: ConnectionFactory)(implicit
system: ActorSystem[_]): DurableStateDao =
new SqlServerDurableStateDao(settings, connectionFactory)(system.executionContext, system)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (C) 2022 - 2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.persistence.r2dbc.internal.sqlserver

import akka.annotation.InternalApi
import akka.persistence.r2dbc.internal.InstantFactory
import com.typesafe.config.Config
import io.r2dbc.spi.Row

import java.time.Instant
import java.time.LocalDateTime
import java.util.TimeZone

/**
* INTERNAL API
*/
@InternalApi
private[r2dbc] object SqlServerDialectHelper {
def apply(config: Config) = new SqlServerDialectHelper(config)
}

/**
* INTERNAL API
*/
@InternalApi
private[r2dbc] class SqlServerDialectHelper(config: Config) {

private val tagSeparator = config.getString("tag-separator")

require(tagSeparator.length == 1, s"Tag separator '$tagSeparator' must be a single character.")

def tagsToDb(tags: Set[String]): String = {
if (tags.exists(_.contains(tagSeparator))) {
throw new IllegalArgumentException(
s"A tag in [$tags] contains the character '$tagSeparator' which is reserved. Please change `akka.persistence.r2dbc.sqlserver.tag-separator` to a character that is not contained by any of your tags.")
}
tags.mkString(tagSeparator)
}

def tagsFromDb(row: Row): Set[String] = row.get("tags", classOf[String]) match {
case null => Set.empty[String]
case entries => entries.split(tagSeparator).toSet
}

private val zone = TimeZone.getTimeZone("UTC").toZoneId

def nowInstant(): Instant = InstantFactory.now()

def nowLocalDateTime(): LocalDateTime = LocalDateTime.ofInstant(nowInstant(), zone)

def toDbTimestamp(timestamp: Instant): LocalDateTime =
LocalDateTime.ofInstant(timestamp, zone)

def fromDbTimestamp(time: LocalDateTime): Instant = time
.atZone(zone)
.toInstant

}
Loading

0 comments on commit 7f7454e

Please sign in to comment.