From 54f12a0c9fa43a8f40d54c81fecd399666a7e4b1 Mon Sep 17 00:00:00 2001 From: Sakib Hadziavdic Date: Wed, 3 Jan 2024 14:11:36 +0100 Subject: [PATCH] Derive SqlRead and SqlWrite for simple enums --- squery/src/ba/sake/squery/read/SqlRead.scala | 101 ++++++++++++++---- .../src/ba/sake/squery/write/SqlWrite.scala | 57 +++++++++- .../src/ba/sake/squery/PostgresSuite.scala | 18 ++-- .../test/src/ba/sake/squery/dataTypes.scala | 6 +- 4 files changed, 152 insertions(+), 30 deletions(-) diff --git a/squery/src/ba/sake/squery/read/SqlRead.scala b/squery/src/ba/sake/squery/read/SqlRead.scala index c0eab35..735ac8c 100644 --- a/squery/src/ba/sake/squery/read/SqlRead.scala +++ b/squery/src/ba/sake/squery/read/SqlRead.scala @@ -1,35 +1,20 @@ -package ba.sake.squery.read +package ba.sake.squery +package read import java.{sql => jsql} import java.time.* import java.util.UUID +import scala.deriving.* +import scala.quoted.* // reads a value from a column trait SqlRead[T]: def readByName(jRes: jsql.ResultSet, colName: String): Option[T] def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[T] -object SqlRead: +object SqlRead { def apply[T](using sqlRead: SqlRead[T]): SqlRead[T] = sqlRead - // TODO derived for simple enums - /* - -import java.sql.ResultSet -import ba.sake.squery.read.SqlRead -import ba.sake.sharaf.petclinic.common.PetType - -given SqlRead[PetType] = new { - private val stringRead = SqlRead[String] - - override def readByName(jRes: ResultSet, colName: String): Option[PetType] = - stringRead.readByName(jRes, colName).map(PetType.valueOf) - - override def readByIdx(jRes: ResultSet, colIdx: Int): Option[PetType] = - stringRead.readByIdx(jRes, colIdx).map(PetType.valueOf) -} - */ - given SqlRead[String] = new { def readByName(jRes: jsql.ResultSet, colName: String): Option[String] = Option(jRes.getString(colName)) @@ -79,7 +64,6 @@ given SqlRead[PetType] = new { def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[OffsetDateTime] = Option(jRes.getObject(colIdx, classOf[OffsetDateTime])) - } given SqlRead[LocalDate] = new { @@ -104,7 +88,6 @@ given SqlRead[PetType] = new { def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] = Option(jRes.getObject(colIdx, classOf[UUID])) - } // this "cannot fail" @@ -114,5 +97,79 @@ given SqlRead[PetType] = new { def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Option[T]] = Some(sr.readByIdx(jRes, colIdx)) + } + /* macro derived instances */ + inline def derived[T]: SqlRead[T] = ${ derivedMacro[T] } + + private def derivedMacro[T: Type](using Quotes): Expr[SqlRead[T]] = { + import quotes.reflect.* + + val mirror: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].getOrElse { + report.errorAndAbort( + s"Cannot derive SqlRead[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton enum" + ) + } + + mirror match + case '{ $m: Mirror.ProductOf[T] } => + report.errorAndAbort("Product types are not supported") + + case '{ + type label <: Tuple; + $m: Mirror.SumOf[T] { type MirroredElemLabels = `label` } + } => + val labels = Expr(Type.valueOfTuple[label].map(_.toList.map(_.toString)).getOrElse(List.empty)) + + val isSingleCasesEnum = isSingletonCasesEnum[T] + if !isSingleCasesEnum then + report.errorAndAbort( + s"Cannot derive SqlRead[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton-cases enum" + ) + + val companion = TypeRepr.of[T].typeSymbol.companionModule.termRef + val valueOfSelect = Select.unique(Ident(companion), "valueOf").symbol + + '{ + new SqlRead[T] { + override def readByName(jRes: jsql.ResultSet, colName: String): Option[T] = + SqlRead[String].readByName(jRes, colName).map { enumString => + try { + ${ + val bla = '{ enumString } + Block(Nil, Apply(Select(Ident(companion), valueOfSelect), List(bla.asTerm))).asExprOf[T] + } + } catch { + case e: IllegalArgumentException => + throw SqueryException( + s"Enum value not found: '${enumString}'. Possible values: ${$labels.map(l => s"'$l'").mkString(", ")}" + ) + } + } + + override def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[T] = + SqlRead[String].readByIdx(jRes, colIdx).map { enumString => + try { + ${ + val bla = '{ enumString } + Block(Nil, Apply(Select(Ident(companion), valueOfSelect), List(bla.asTerm))).asExprOf[T] + } + } catch { + case e: IllegalArgumentException => + throw SqueryException( + s"Enum value not found: '${enumString}'. Possible values: ${$labels.map(l => s"'$l'").mkString(", ")}" + ) + } + } + } + } + + case hmm => report.errorAndAbort("Sum types are not supported") } + + private def isSingletonCasesEnum[T: Type](using Quotes): Boolean = + import quotes.reflect.* + val ts = TypeRepr.of[T].typeSymbol + ts.flags.is(Flags.Enum) && ts.companionClass.methodMember("values").nonEmpty + +} diff --git a/squery/src/ba/sake/squery/write/SqlWrite.scala b/squery/src/ba/sake/squery/write/SqlWrite.scala index ca089fc..aea8682 100644 --- a/squery/src/ba/sake/squery/write/SqlWrite.scala +++ b/squery/src/ba/sake/squery/write/SqlWrite.scala @@ -9,11 +9,13 @@ import java.time.LocalDate import java.time.LocalDateTime import java.time.ZoneId import java.time.OffsetDateTime +import scala.deriving.* +import scala.quoted.* trait SqlWrite[T]: def write(ps: jsql.PreparedStatement, idx: Int, valueOpt: Option[T]): Unit -object SqlWrite: +object SqlWrite { def apply[T](using sqlWrite: SqlWrite[T]): SqlWrite[T] = sqlWrite @@ -134,3 +136,56 @@ object SqlWrite: ): Unit = sw.write(ps, idx, value.flatten) } + + /* macro derived instances */ + inline def derived[T]: SqlWrite[T] = ${ derivedMacro[T] } + + private def derivedMacro[T: Type](using Quotes): Expr[SqlWrite[T]] = { + import quotes.reflect.* + + val mirror: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].getOrElse { + report.errorAndAbort( + s"Cannot derive SqlWrite[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton enum" + ) + } + + mirror match + case '{ $m: Mirror.ProductOf[T] } => + report.errorAndAbort("Product types are not supported") + + case '{ + type label <: Tuple; + $m: Mirror.SumOf[T] { type MirroredElemLabels = `label` } + } => + val labels = Expr(Type.valueOfTuple[label].map(_.toList.map(_.toString)).getOrElse(List.empty)) + + val isSingleCasesEnum = isSingletonCasesEnum[T] + if !isSingleCasesEnum then + report.errorAndAbort( + s"Cannot derive SqlWrite[${Type.show[T]}] automatically because ${Type.show[T]} is not a singleton-cases enum" + ) + + val companion = TypeRepr.of[T].typeSymbol.companionModule.termRef + val valueOfSelect = Select.unique(Ident(companion), "valueOf").symbol + + '{ + new SqlWrite[T] { + def write(ps: jsql.PreparedStatement, idx: Int, valueOpt: Option[T]): Unit = + valueOpt match + case Some(value) => + val index = $m.ordinal(value) + val label = $labels(index) + ps.setObject(idx, label, jsql.Types.OTHER) + case None => ps.setNull(idx, jsql.Types.OTHER) + } + } + + case hmm => report.errorAndAbort("Sum types are not supported") + } + + private def isSingletonCasesEnum[T: Type](using Quotes): Boolean = + import quotes.reflect.* + val ts = TypeRepr.of[T].typeSymbol + ts.flags.is(Flags.Enum) && ts.companionClass.methodMember("values").nonEmpty + +} diff --git a/squery/test/src/ba/sake/squery/PostgresSuite.scala b/squery/test/src/ba/sake/squery/PostgresSuite.scala index 52cc4d3..d132487 100644 --- a/squery/test/src/ba/sake/squery/PostgresSuite.scala +++ b/squery/test/src/ba/sake/squery/PostgresSuite.scala @@ -239,6 +239,10 @@ class PostgresSuite extends munit.FunSuite { test("Data types") { val ctx = initDb() ctx.run { + // enum + sql""" + CREATE TYPE color AS ENUM ('red', 'green', 'blue') + """.update() // note that Instant has NANOseconds precision! // postgres has MICROseconds precision sql""" @@ -249,7 +253,8 @@ class PostgresSuite extends munit.FunSuite { boolean BOOLEAN, string VARCHAR(255), uuid UUID, - tstz TIMESTAMPTZ + tstz TIMESTAMPTZ, + clr color ) """.update() val dt1 = Datatypes( @@ -259,21 +264,22 @@ class PostgresSuite extends munit.FunSuite { Some(true), Some("abc"), Some(UUID.randomUUID), - Some(Instant.now.truncatedTo(ChronoUnit.MICROS)) + Some(Instant.now.truncatedTo(ChronoUnit.MICROS)), + Some(Color.red) ) - val dt2 = Datatypes(None, None, None, None, None, None, None) + val dt2 = Datatypes(None, None, None, None, None, None, None, None) val values = Seq(dt1, dt2) - .map(dt => sql"(${dt.int}, ${dt.long}, ${dt.double}, ${dt.boolean}, ${dt.string}, ${dt.uuid}, ${dt.tstz})") + .map(dt => sql"(${dt.int}, ${dt.long}, ${dt.double}, ${dt.boolean}, ${dt.string}, ${dt.uuid}, ${dt.tstz}, ${dt.clr})") .intersperse(sql",") .reduce(_ ++ _) sql""" - INSERT INTO datatypes(int, long, double, boolean, string, uuid, tstz) + INSERT INTO datatypes(int, long, double, boolean, string, uuid, tstz, clr) VALUES ${values} """.insert() val storedRows = sql""" - SELECT int, long, double, boolean, string, uuid, tstz + SELECT int, long, double, boolean, string, uuid, tstz, clr FROM datatypes """.readRows[Datatypes]() assertEquals( diff --git a/squery/test/src/ba/sake/squery/dataTypes.scala b/squery/test/src/ba/sake/squery/dataTypes.scala index a9f5f4a..bb90201 100644 --- a/squery/test/src/ba/sake/squery/dataTypes.scala +++ b/squery/test/src/ba/sake/squery/dataTypes.scala @@ -23,5 +23,9 @@ case class Datatypes( boolean: Option[Boolean], string: Option[String], uuid: Option[UUID], - tstz: Option[Instant] + tstz: Option[Instant], + clr: Option[Color] ) derives SqlReadRow + +enum Color derives SqlRead, SqlWrite: + case red, green, blue