Skip to content

Commit

Permalink
Derive SqlRead and SqlWrite for simple enums
Browse files Browse the repository at this point in the history
  • Loading branch information
sake92 committed Jan 3, 2024
1 parent ba31b73 commit 54f12a0
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 30 deletions.
101 changes: 79 additions & 22 deletions squery/src/ba/sake/squery/read/SqlRead.scala
Original file line number Diff line number Diff line change
@@ -1,35 +1,20 @@
package ba.sake.squery.read
package ba.sake.squery
package read

import java.{sql => jsql}
import java.time.*
import java.util.UUID
import scala.deriving.*
import scala.quoted.*

// reads a value from a column
trait SqlRead[T]:
def readByName(jRes: jsql.ResultSet, colName: String): Option[T]
def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[T]

object SqlRead:
object SqlRead {
def apply[T](using sqlRead: SqlRead[T]): SqlRead[T] = sqlRead

// TODO derived for simple enums
/*
import java.sql.ResultSet
import ba.sake.squery.read.SqlRead
import ba.sake.sharaf.petclinic.common.PetType
given SqlRead[PetType] = new {
private val stringRead = SqlRead[String]
override def readByName(jRes: ResultSet, colName: String): Option[PetType] =
stringRead.readByName(jRes, colName).map(PetType.valueOf)
override def readByIdx(jRes: ResultSet, colIdx: Int): Option[PetType] =
stringRead.readByIdx(jRes, colIdx).map(PetType.valueOf)
}
*/

given SqlRead[String] = new {
def readByName(jRes: jsql.ResultSet, colName: String): Option[String] =
Option(jRes.getString(colName))
Expand Down Expand Up @@ -79,7 +64,6 @@ given SqlRead[PetType] = new {

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[OffsetDateTime] =
Option(jRes.getObject(colIdx, classOf[OffsetDateTime]))

}

given SqlRead[LocalDate] = new {
Expand All @@ -104,7 +88,6 @@ given SqlRead[PetType] = new {

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] =
Option(jRes.getObject(colIdx, classOf[UUID]))

}

// this "cannot fail"
Expand All @@ -114,5 +97,79 @@ given SqlRead[PetType] = new {

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Option[T]] =
Some(sr.readByIdx(jRes, colIdx))
}

/* macro derived instances */
inline def derived[T]: SqlRead[T] = ${ derivedMacro[T] }

private def derivedMacro[T: Type](using Quotes): Expr[SqlRead[T]] = {
import quotes.reflect.*

val mirror: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].getOrElse {
report.errorAndAbort(
s"Cannot derive SqlRead[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton enum"
)
}

mirror match
case '{ $m: Mirror.ProductOf[T] } =>
report.errorAndAbort("Product types are not supported")

case '{
type label <: Tuple;
$m: Mirror.SumOf[T] { type MirroredElemLabels = `label` }
} =>
val labels = Expr(Type.valueOfTuple[label].map(_.toList.map(_.toString)).getOrElse(List.empty))

val isSingleCasesEnum = isSingletonCasesEnum[T]
if !isSingleCasesEnum then
report.errorAndAbort(
s"Cannot derive SqlRead[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton-cases enum"
)

val companion = TypeRepr.of[T].typeSymbol.companionModule.termRef
val valueOfSelect = Select.unique(Ident(companion), "valueOf").symbol

'{
new SqlRead[T] {
override def readByName(jRes: jsql.ResultSet, colName: String): Option[T] =
SqlRead[String].readByName(jRes, colName).map { enumString =>
try {
${
val bla = '{ enumString }
Block(Nil, Apply(Select(Ident(companion), valueOfSelect), List(bla.asTerm))).asExprOf[T]
}
} catch {
case e: IllegalArgumentException =>
throw SqueryException(
s"Enum value not found: '${enumString}'. Possible values: ${$labels.map(l => s"'$l'").mkString(", ")}"
)
}
}

override def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[T] =
SqlRead[String].readByIdx(jRes, colIdx).map { enumString =>
try {
${
val bla = '{ enumString }
Block(Nil, Apply(Select(Ident(companion), valueOfSelect), List(bla.asTerm))).asExprOf[T]
}
} catch {
case e: IllegalArgumentException =>
throw SqueryException(
s"Enum value not found: '${enumString}'. Possible values: ${$labels.map(l => s"'$l'").mkString(", ")}"
)
}
}
}
}

case hmm => report.errorAndAbort("Sum types are not supported")
}

private def isSingletonCasesEnum[T: Type](using Quotes): Boolean =
import quotes.reflect.*
val ts = TypeRepr.of[T].typeSymbol
ts.flags.is(Flags.Enum) && ts.companionClass.methodMember("values").nonEmpty

}
57 changes: 56 additions & 1 deletion squery/src/ba/sake/squery/write/SqlWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import java.time.LocalDate
import java.time.LocalDateTime
import java.time.ZoneId
import java.time.OffsetDateTime
import scala.deriving.*
import scala.quoted.*

trait SqlWrite[T]:
def write(ps: jsql.PreparedStatement, idx: Int, valueOpt: Option[T]): Unit

object SqlWrite:
object SqlWrite {

def apply[T](using sqlWrite: SqlWrite[T]): SqlWrite[T] = sqlWrite

Expand Down Expand Up @@ -134,3 +136,56 @@ object SqlWrite:
): Unit =
sw.write(ps, idx, value.flatten)
}

/* macro derived instances */
inline def derived[T]: SqlWrite[T] = ${ derivedMacro[T] }

private def derivedMacro[T: Type](using Quotes): Expr[SqlWrite[T]] = {
import quotes.reflect.*

val mirror: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].getOrElse {
report.errorAndAbort(
s"Cannot derive SqlWrite[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton enum"
)
}

mirror match
case '{ $m: Mirror.ProductOf[T] } =>
report.errorAndAbort("Product types are not supported")

case '{
type label <: Tuple;
$m: Mirror.SumOf[T] { type MirroredElemLabels = `label` }
} =>
val labels = Expr(Type.valueOfTuple[label].map(_.toList.map(_.toString)).getOrElse(List.empty))

val isSingleCasesEnum = isSingletonCasesEnum[T]
if !isSingleCasesEnum then
report.errorAndAbort(
s"Cannot derive SqlWrite[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton-cases enum"
)

val companion = TypeRepr.of[T].typeSymbol.companionModule.termRef
val valueOfSelect = Select.unique(Ident(companion), "valueOf").symbol

'{
new SqlWrite[T] {
def write(ps: jsql.PreparedStatement, idx: Int, valueOpt: Option[T]): Unit =
valueOpt match
case Some(value) =>
val index = $m.ordinal(value)
val label = $labels(index)
ps.setObject(idx, label, jsql.Types.OTHER)
case None => ps.setNull(idx, jsql.Types.OTHER)
}
}

case hmm => report.errorAndAbort("Sum types are not supported")
}

private def isSingletonCasesEnum[T: Type](using Quotes): Boolean =
import quotes.reflect.*
val ts = TypeRepr.of[T].typeSymbol
ts.flags.is(Flags.Enum) && ts.companionClass.methodMember("values").nonEmpty

}
18 changes: 12 additions & 6 deletions squery/test/src/ba/sake/squery/PostgresSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ class PostgresSuite extends munit.FunSuite {
test("Data types") {
val ctx = initDb()
ctx.run {
// enum
sql"""
CREATE TYPE color AS ENUM ('red', 'green', 'blue')
""".update()
// note that Instant has NANOseconds precision!
// postgres has MICROseconds precision
sql"""
Expand All @@ -249,7 +253,8 @@ class PostgresSuite extends munit.FunSuite {
boolean BOOLEAN,
string VARCHAR(255),
uuid UUID,
tstz TIMESTAMPTZ
tstz TIMESTAMPTZ,
clr color
)
""".update()
val dt1 = Datatypes(
Expand All @@ -259,21 +264,22 @@ class PostgresSuite extends munit.FunSuite {
Some(true),
Some("abc"),
Some(UUID.randomUUID),
Some(Instant.now.truncatedTo(ChronoUnit.MICROS))
Some(Instant.now.truncatedTo(ChronoUnit.MICROS)),
Some(Color.red)
)
val dt2 = Datatypes(None, None, None, None, None, None, None)
val dt2 = Datatypes(None, None, None, None, None, None, None, None)

val values = Seq(dt1, dt2)
.map(dt => sql"(${dt.int}, ${dt.long}, ${dt.double}, ${dt.boolean}, ${dt.string}, ${dt.uuid}, ${dt.tstz})")
.map(dt => sql"(${dt.int}, ${dt.long}, ${dt.double}, ${dt.boolean}, ${dt.string}, ${dt.uuid}, ${dt.tstz}, ${dt.clr})")
.intersperse(sql",")
.reduce(_ ++ _)
sql"""
INSERT INTO datatypes(int, long, double, boolean, string, uuid, tstz)
INSERT INTO datatypes(int, long, double, boolean, string, uuid, tstz, clr)
VALUES ${values}
""".insert()

val storedRows = sql"""
SELECT int, long, double, boolean, string, uuid, tstz
SELECT int, long, double, boolean, string, uuid, tstz, clr
FROM datatypes
""".readRows[Datatypes]()
assertEquals(
Expand Down
6 changes: 5 additions & 1 deletion squery/test/src/ba/sake/squery/dataTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ case class Datatypes(
boolean: Option[Boolean],
string: Option[String],
uuid: Option[UUID],
tstz: Option[Instant]
tstz: Option[Instant],
clr: Option[Color]
) derives SqlReadRow

enum Color derives SqlRead, SqlWrite:
case red, green, blue

0 comments on commit 54f12a0

Please sign in to comment.