
Commit

Support SQL arrays up to 3 dimensions
sake92 committed Sep 6, 2024
1 parent 02dd8b7 commit 4926a03
Showing 9 changed files with 280 additions and 29 deletions.
14 changes: 8 additions & 6 deletions generator/src/ba/sake/squery/generator/SqueryGenerator.scala
@@ -83,11 +83,13 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
basePackage: String,
fileGen: Boolean
): (Seq[GeneratedFileSource], Seq[GeneratedFileSource]) = {
val enumDefs = schemaDef.tables.flatMap {
_.columnDefs.map(_.scalaType).collect { case e: ColumnType.Enumeration =>
e
val enumDefs = schemaDef.tables
.flatMap {
_.columnDefs.map(_.scalaType).collect { case e: ColumnType.Enumeration =>
e
}
}
}.distinctBy(_.name)
.distinctBy(_.name)
val enumFiles = enumDefs.map { enumDef =>
val enumCaseDefs = Defn.RepeatedEnumCase(
List.empty,
@@ -547,8 +549,8 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
q"import java.util.UUID",
q"import ba.sake.squery.{*, given}",
q"import ..${List(dbSpecificImporter)}",
q"import ba.sake.squery.read.SqlRead",
q"import ba.sake.squery.write.SqlWrite"
q"import ba.sake.squery.read.{*, given}",
q"import ba.sake.squery.write.{*, given}"
)
}
private def generateDaoImports(dbType: DbType, basePackage: String) = {
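The generator change above swaps the narrow `SqlRead`/`SqlWrite` imports for wildcard `{*, given}` imports, presumably so that top-level givens (such as the `SqlTypeName` instances added below in this commit) are visible in generated code. Roughly, a generated model file would now start like this; the package name and the db-specific import are placeholders, not taken from this diff:

```scala
package com.example.db.models // placeholder package

import java.util.UUID
import ba.sake.squery.{*, given}
import ba.sake.squery.postgres.{*, given} // stand-in for the db-specific importer
import ba.sake.squery.read.{*, given}
import ba.sake.squery.write.{*, given}
```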
12 changes: 12 additions & 0 deletions squery/src/ba/sake/squery/SqlNonScalarType.scala
@@ -0,0 +1,12 @@
package ba.sake.squery

// - a marker typeclass for non-scalar types
// i.e. array-type and similar

// - this is to prevent infinite recursion of Array[Array[Array...T
// the SqlWrite, SqlRead typeclasses would break

trait SqlNonScalarType[T]

given [T]: SqlNonScalarType[Array[T]] = new {}
given [T]: SqlNonScalarType[Seq[T]] = new {}
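This marker, combined with the `NotGiven` constraints on the instances added below, is what caps the supported nesting at three dimensions: every `Array`/`Vector` given requires that its element type is not itself array-like, so the compiler cannot keep stacking instances indefinitely. A minimal, self-contained sketch of the pattern (`NonScalar`, `Codec`, `codec1`, `codec2` are illustrative names, not from the library):

```scala
import scala.util.NotGiven

trait NonScalar[T]                      // marker: "this type is array-like"
given [T]: NonScalar[Array[T]] = new {}

trait Codec[T]                          // stand-in for SqlRead / SqlWrite
given Codec[Int] = new {}

// 1D instance: only for element types that are not array-like
given codec1[T: Codec](using NotGiven[NonScalar[T]]): Codec[Array[T]] = new {}
// 2D instance: again requires a non-array innermost type
given codec2[T: Codec](using NotGiven[NonScalar[T]]): Codec[Array[Array[T]]] = new {}

@main def check(): Unit =
  summon[Codec[Array[Int]]]        // resolves via codec1
  summon[Codec[Array[Array[Int]]]] // resolves via codec2; codec1 cannot apply here,
                                   // because NonScalar[Array[Int]] exists
  // summon[Codec[Array[Array[Array[Int]]]]] // does not compile: no 3D instance in this sketch
```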
8 changes: 0 additions & 8 deletions squery/src/ba/sake/squery/postgres/reads.scala
@@ -11,11 +11,3 @@ given SqlRead[UUID] with {
def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] =
Option(jRes.getObject(colIdx, classOf[UUID]))
}

given [T: SqlRead]: SqlRead[Array[T]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[T]] =
Option(jRes.getArray(colName)).map(_.getArray().asInstanceOf[Array[T]])

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[T]] =
Option(jRes.getArray(colIdx)).map(_.getArray().asInstanceOf[Array[T]])
}
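The Postgres-specific `SqlRead[Array[T]]` removed here is replaced by the generic array instances added to `SqlRead.scala` below, which also handle boxing: for primitive element types the driver's `getArray()` returns an array of boxed values, so casting the whole array in one go is not enough. A small illustration of the problem the element-wise cast solves (the column name is hypothetical):

```scala
import java.{sql => jsql}

def readNums(rs: jsql.ResultSet): Array[Int] =
  // For an INTEGER[] column, getArray() yields an Array[java.lang.Integer] (an Object[]) at runtime,
  // so raw.asInstanceOf[Array[Int]] would throw: an Object[] is not an int[].
  val raw = rs.getArray("nums").getArray() // hypothetical column "nums"
  raw.asInstanceOf[Array[AnyRef]].map(_.asInstanceOf[Int]) // unbox element by element, as the new instances do
```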
8 changes: 8 additions & 0 deletions squery/src/ba/sake/squery/postgres/typeNames.scala
@@ -0,0 +1,8 @@
package ba.sake.squery.postgres

import ba.sake.squery.write.SqlTypeName
import java.util.UUID

given SqlTypeName[Array[UUID]] with {
def value: String = "UUID"
}
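`createArrayOf` needs the SQL type name of the array's elements, and the common `SqlTypeName` file below defines no instance for `UUID` (it is database-specific), so the PostgreSQL module pins `Array[UUID]` to the driver's `UUID` element type here. In plain JDBC terms, the write path for such a column amounts to roughly this (table and column names are made up):

```scala
import java.sql.Connection
import java.util.UUID

def insertRelated(con: Connection, ids: Array[UUID]): Unit =
  val ps = con.prepareStatement("INSERT INTO docs(related_ids) VALUES (?)") // hypothetical table
  // "UUID" is the value supplied by SqlTypeName[Array[UUID]] above
  val sqlArray = con.createArrayOf("UUID", ids.map(_.asInstanceOf[AnyRef]))
  ps.setArray(1, sqlArray)
  ps.executeUpdate()
  ps.close()
```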
66 changes: 65 additions & 1 deletion squery/src/ba/sake/squery/read/SqlRead.scala
@@ -6,6 +6,8 @@ import java.time.*
import java.util.UUID
import scala.deriving.*
import scala.quoted.*
import scala.util.NotGiven
import scala.reflect.ClassTag

// reads a value from a column
trait SqlRead[T]:
@@ -44,7 +46,6 @@ object SqlRead {
Option(jRes.getShort(colIdx)).filterNot(_ => jRes.wasNull())
}


given SqlRead[Int] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Int] =
Option(jRes.getInt(colName)).filterNot(_ => jRes.wasNull())
@@ -98,6 +99,35 @@
Option(jRes.getTimestamp(colIdx)).map(_.toLocalDateTime())
}

/* Arrays */
// - general first, then specific ones, for implicits ordering
// - _.map(_.asInstanceOf[T]) because of boxing/unboxing...
given sqlReadArray1[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Array[T]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[T]] =
Option(jRes.getArray(colName)).map(_.getArray().asInstanceOf[Array[T]].map(_.asInstanceOf[T]))

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[T]] =
Option(jRes.getArray(colIdx)).map(_.getArray().asInstanceOf[Array[T]].map(_.asInstanceOf[T]))
}

given sqlReadArray2[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Array[Array[T]]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[Array[T]]] =
Option(jRes.getArray(colName)).map(_.getArray().asInstanceOf[Array[Array[T]]].map(_.map(_.asInstanceOf[T])))

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[Array[T]]] =
Option(jRes.getArray(colIdx)).map(_.getArray().asInstanceOf[Array[Array[T]]].map(_.map(_.asInstanceOf[T])))
}

given sqlReadArray3[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Array[Array[Array[T]]]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[Array[Array[T]]]] =
Option(jRes.getArray(colName))
.map(_.getArray().asInstanceOf[Array[Array[Array[T]]]].map(_.map(_.map(_.asInstanceOf[T]))))

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[Array[Array[T]]]] =
Option(jRes.getArray(colIdx))
.map(_.getArray().asInstanceOf[Array[Array[Array[T]]]].map(_.map(_.map(_.asInstanceOf[T]))))
}

given SqlRead[Array[Byte]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[Byte]] =
Option(jRes.getBytes(colName))
@@ -106,6 +136,40 @@ object SqlRead {
Option(jRes.getBytes(colIdx))
}

// vector utils, nicer to deal with
given sqlReadVector1[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Vector[T]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[T]] =
SqlRead[Array[T]].readByName(jRes, colName).map(_.toVector)

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[T]] =
SqlRead[Array[T]].readByIdx(jRes, colIdx).map(_.toVector)
}

given sqlReadVector2[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Vector[Vector[T]]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[Vector[T]]] =
SqlRead[Array[Array[T]]].readByName(jRes, colName).map(_.toVector.map(_.toVector))

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[Vector[T]]] =
SqlRead[Array[Array[T]]].readByIdx(jRes, colIdx).map(_.toVector.map(_.toVector))
}

given sqlReadVector3[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Vector[Vector[Vector[T]]]]
with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[Vector[Vector[T]]]] =
SqlRead[Array[Array[Array[T]]]].readByName(jRes, colName).map(_.toVector.map(_.toVector.map(_.toVector)))

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[Vector[Vector[T]]]] =
SqlRead[Array[Array[Array[T]]]].readByIdx(jRes, colIdx).map(_.toVector.map(_.toVector.map(_.toVector)))
}

given SqlRead[Vector[Byte]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[Byte]] =
SqlRead[Array[Byte]].readByName(jRes, colName).map(_.toVector)

def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[Byte]] =
SqlRead[Array[Byte]].readByIdx(jRes, colIdx).map(_.toVector)
}

// this "cannot fail"
given [T](using sr: SqlRead[T]): SqlRead[Option[T]] with {
def readByName(jRes: jsql.ResultSet, colName: String): Option[Option[T]] =
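Taken together, these instances let a multi-dimensional column be read straight into nested Scala collections. A hedged usage sketch against plain JDBC follows; the table, column, and SQL types are made up, and only the `SqlRead` lookup reflects the instances above:

```scala
import java.{sql => jsql}
import ba.sake.squery.{*, given}
import ba.sake.squery.read.{*, given}

// Reads a hypothetical "cells INTEGER[][]" column into Vector[Vector[Int]].
def readGrid(con: jsql.Connection): Option[Vector[Vector[Int]]] =
  val rs = con.createStatement().executeQuery("SELECT cells FROM grids LIMIT 1")
  if rs.next() then
    // resolves sqlReadVector2[Int]: reads an Array[Array[Int]] via getArray and converts to nested Vectors
    SqlRead[Vector[Vector[Int]]].readByName(rs, "cells")
  else None
```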
53 changes: 53 additions & 0 deletions squery/src/ba/sake/squery/write/SqlTypeName.scala
@@ -0,0 +1,53 @@
package ba.sake.squery.write

import java.time.*

// used for createArrayOf(sqlTypeName, myArray)
trait SqlTypeName[T]:
def value: String

// for Array[Array[... T]] the value is the inner-most T name
given [T](using stn: SqlTypeName[T]): SqlTypeName[Array[T]] with {
def value: String = stn.value
}
// for Seq[Seq[... T]] the value is the inner-most T name
given [T](using stn: SqlTypeName[T]): SqlTypeName[Seq[T]] with {
def value: String = stn.value
}

given SqlTypeName[String] with {
def value: String = "VARCHAR"
}

given SqlTypeName[Boolean] with {
def value: String = "BOOLEAN"
}
given SqlTypeName[Byte] with {
def value: String = "TINYINT"
}
given SqlTypeName[Short] with {
def value: String = "SMALLINT"
}
given SqlTypeName[Int] with {
def value: String = "INTEGER"
}
given SqlTypeName[Long] with {
def value: String = "BIGINT"
}
given SqlTypeName[Double] with {
def value: String = "REAL"
}

given SqlTypeName[LocalDate] with {
def value: String = "DATE"
}
given SqlTypeName[LocalDateTime] with {
def value: String = "TIMESTAMP"
}
given SqlTypeName[Instant] with {
def value: String = "TIMESTAMPTZ"
}

given SqlTypeName[Array[Byte]] with {
def value: String = "BINARY"
}
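Because nested arrays and seqs delegate to the innermost element's name, `createArrayOf` always receives a scalar SQL type regardless of nesting depth. A small check of that resolution (assuming the givens are brought in with a wildcard `given` import, as the generator now emits):

```scala
import ba.sake.squery.write.{*, given}

@main def showTypeNames(): Unit =
  println(summon[SqlTypeName[Int]].value)               // INTEGER
  println(summon[SqlTypeName[Array[Int]]].value)        // INTEGER (delegates to the element type)
  println(summon[SqlTypeName[Array[Array[Int]]]].value) // INTEGER (still the innermost element)
```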
91 changes: 90 additions & 1 deletion squery/src/ba/sake/squery/write/SqlWrite.scala
@@ -1,4 +1,5 @@
package ba.sake.squery.write
package ba.sake.squery
package write

import java.{sql => jsql}
import java.time.Instant
Expand All @@ -11,6 +12,8 @@ import java.time.ZoneId
import java.time.OffsetDateTime
import scala.deriving.*
import scala.quoted.*
import scala.reflect.ClassTag
import scala.util.NotGiven

trait SqlWrite[T]:
def write(ps: jsql.PreparedStatement, idx: Int, valueOpt: Option[T]): Unit
@@ -129,6 +132,51 @@ object SqlWrite {
case None => ps.setNull(idx, jsql.Types.TIMESTAMP)
}

/* Arrays */
given sqlWriteArray1[T: SqlWrite](using stn: SqlTypeName[T], ng: NotGiven[SqlNonScalarType[T]]): SqlWrite[Array[T]]
with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Array[T]]
): Unit = valueOpt match
case Some(value) =>
val valuesAsAnyRef = value.map(_.asInstanceOf[AnyRef]) // box primitives like Array[Int]
val sqlArray = ps.getConnection().createArrayOf(stn.value, valuesAsAnyRef)
ps.setArray(idx, sqlArray)
case None => ps.setArray(idx, null)
}
given sqlWriteArray2[T: SqlWrite](using
stn: SqlTypeName[T],
ng: NotGiven[SqlNonScalarType[T]]
): SqlWrite[Array[Array[T]]] with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Array[Array[T]]]
): Unit = valueOpt match
case Some(value) =>
val valuesAsAnyRef = value.map(_.map(_.asInstanceOf[AnyRef])) // box primitives like Array[Array[Int]]
val sqlArray = ps.getConnection().createArrayOf(stn.value, valuesAsAnyRef.asInstanceOf[Array[AnyRef]])
ps.setArray(idx, sqlArray)
case None => ps.setArray(idx, null)
}
given sqlWriteArray3[T: SqlWrite](using
stn: SqlTypeName[T],
ng: NotGiven[SqlNonScalarType[T]]
): SqlWrite[Array[Array[Array[T]]]] with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Array[Array[Array[T]]]]
): Unit = valueOpt match
case Some(value) =>
val valuesAsAnyRef =
value.map(_.map(_.map(_.asInstanceOf[AnyRef]))) // box primitives like Array[Array[Array[Int]]]
val sqlArray = ps.getConnection().createArrayOf(stn.value, valuesAsAnyRef.asInstanceOf[Array[AnyRef]])
ps.setArray(idx, sqlArray)
case None => ps.setArray(idx, null)
}
given SqlWrite[Array[Byte]] with {
def write(
ps: jsql.PreparedStatement,
@@ -139,6 +187,47 @@ object SqlWrite {
case None => ps.setNull(idx, jsql.Types.BINARY)
}

given sqlWriteVector1[T: SqlWrite: ClassTag](using
stn: SqlTypeName[T],
ng: NotGiven[SqlNonScalarType[T]]
): SqlWrite[Vector[T]] with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Vector[T]]
): Unit = SqlWrite[Array[T]].write(ps, idx, valueOpt.map(_.toArray))
}

given sqlWriteVector2[T: SqlWrite: ClassTag](using
stn: SqlTypeName[T],
ng: NotGiven[SqlNonScalarType[T]]
): SqlWrite[Vector[Vector[T]]] with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Vector[Vector[T]]]
): Unit = SqlWrite[Array[Array[T]]].write(ps, idx, valueOpt.map(_.toArray.map(_.toArray)))
}

given sqlWriteVector3[T: SqlWrite: ClassTag](using
stn: SqlTypeName[T],
ng: NotGiven[SqlNonScalarType[T]]
): SqlWrite[Vector[Vector[Vector[T]]]] with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Vector[Vector[Vector[T]]]]
): Unit = SqlWrite[Array[Array[Array[T]]]].write(ps, idx, valueOpt.map(_.toArray.map(_.toArray.map(_.toArray))))
}

given SqlWrite[Vector[Byte]] with {
def write(
ps: jsql.PreparedStatement,
idx: Int,
valueOpt: Option[Vector[Byte]]
): Unit = SqlWrite[Array[Byte]].write(ps, idx, valueOpt.map(_.toArray))
}

given [T](using sw: SqlWrite[T]): SqlWrite[Option[T]] with {
def write(
ps: jsql.PreparedStatement,
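The write side mirrors the readers: binding a nested value converts it to nested arrays, boxes the primitives, and hands `createArrayOf` the element type name from `SqlTypeName`. A hedged sketch; the table and column are hypothetical, and a scalar `SqlWrite[Int]` instance is assumed to exist in the library alongside the ones shown above:

```scala
import java.{sql => jsql}
import ba.sake.squery.{*, given}
import ba.sake.squery.write.{*, given}

// Binds a hypothetical "cells INTEGER[][]" parameter using the new Vector instances.
def insertGrid(con: jsql.Connection, cells: Vector[Vector[Int]]): Unit =
  val ps = con.prepareStatement("INSERT INTO grids(cells) VALUES (?)")
  // resolves sqlWriteVector2[Int]: converts to Array[Array[Int]], boxes the elements,
  // and calls createArrayOf("INTEGER", ...) with the name from SqlTypeName[Int]
  SqlWrite[Vector[Vector[Int]]].write(ps, 1, Some(cells))
  ps.executeUpdate()
  ps.close()
```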
3 changes: 0 additions & 3 deletions squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala
@@ -97,9 +97,6 @@ class MariaDbSuite extends munit.FunSuite {
(${customer2.name}, ${customer2.street})
""".insertReturningGenKeys[Int]()


println(customerIds)

customer1 = customer1.copy(id = customerIds(0))
customer2 = customer2.copy(id = customerIds(1))
