Skip to content

Commit

Permalink
Support for insert-returning in generator
Browse files Browse the repository at this point in the history
  • Loading branch information
sake92 committed Aug 17, 2024
1 parent d6cf598 commit 68935b7
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 21 deletions.
3 changes: 3 additions & 0 deletions generator/src/ba/sake/squery/generator/DbDefExtractor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ abstract class DbDefExtractor(ds: DataSource) {
val tables = extractTables(connection, schemaName, databaseMetaData)
SchemaDef(name = schemaName, tables = tables)
}
val dbType = DbType.fromDatabaseProductName(dbName)
DbDef(
name = dbName,
tpe = dbType,
schemas = schemaDefs
)
}
Expand Down Expand Up @@ -135,6 +137,7 @@ abstract class DbDefExtractor(ds: DataSource) {

case class DbDef(
name: String,
tpe: DbType,
schemas: Seq[SchemaDef]
)

Expand Down
25 changes: 25 additions & 0 deletions generator/src/ba/sake/squery/generator/DbType.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package ba.sake.squery.generator

sealed abstract class DbType(val squeryPackage: String) {
def supportsReturning: Boolean = false
}

object DbType {

def fromDatabaseProductName(dbName: String): DbType = {
if (dbName.contains("h2")) H2
else if (dbName.contains("postgres")) PostgreSQL
else if (dbName.contains("mysql")) MySQL
else if (dbName.contains("mariadb")) MariaDB
else if (dbName.contains("oracle")) Oracle
else throw new RuntimeException(s"Unknown database type $dbName")
}

case object H2 extends DbType("h2")
case object PostgreSQL extends DbType("postgres") {
override def supportsReturning: Boolean = true
}
case object MySQL extends DbType("mysql")
case object MariaDB extends DbType("mariadb")
case object Oracle extends DbType("oracle")
}
48 changes: 27 additions & 21 deletions generator/src/ba/sake/squery/generator/SqueryGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
dbDef.schemas.find(_.name == schemaName) match {
case Some(schemaDef) =>
logger.info(s"Started generating schema '${schemaName}'")
val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbName = dbDef.name)
val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbType = dbDef.tpe)
val res =
s"""|${Preamble}
|${imports}
Expand Down Expand Up @@ -54,7 +54,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
case Some(schemaDef) =>
logger.info(s"Started generating schema '${schemaConfig.name}' into '${schemaConfig.baseFolder}'")
val packagePath = os.RelPath(schemaConfig.basePackage.replaceAll("\\.", "/"))
val (imports, modelDefsScala, daoDefsScala) = generateSchema(schemaDef, dbName = dbDef.name)
val (imports, modelDefsScala, daoDefsScala) = generateSchema(schemaDef, dbType = dbDef.tpe)
modelDefsScala.foreach { modelFile =>
val modelFileWithImports =
s"""|${Preamble}
Expand Down Expand Up @@ -96,7 +96,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
// (imports, models, repos)
private def generateSchema(
schemaDef: SchemaDef,
dbName: String
dbType: DbType
): (String, Seq[GeneratedFile], Seq[GeneratedFile]) = {
val enumDefs = schemaDef.tables.flatMap {
_.columnDefs.map(_.scalaType).collect { case e: ColumnType.Enumeration =>
Expand Down Expand Up @@ -177,7 +177,6 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
else s"${pkCol.metadata.name} = $${id}"
}


def genCountAllQuery: String =
s"""| def countAll(): DbAction[Int] =
| sql"SELECT COUNT(*) FROM $${${caseClassName}.TableName}".readValue()
Expand All @@ -200,17 +199,32 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
|""".stripMargin
}
def genInsert: String = {
// TODO if cols are autoinc
// https://www.jooq.org/doc/latest/manual/sql-building/sql-statements/insert-statement/insert-returning/
val colsListExpr = tableDef.columnDefs
.map(_.metadata.name)
.mkString(", ")
val insertExpr = tableDef.columnDefs
.map { colDef =>
s"$${row.${colDef.metadata.name.safeIdentifier}}"
}
.mkString(", ")
val whereExpr = byIdWhereExpr("id")
s"""| def insert(row: ${caseClassName}): DbAction[Unit] =
| sql""\"INSERT INTO $${${caseClassName}.TableName} VALUES (
| ${insertExpr}
| )""\".insert()
|""".stripMargin
if (dbType.supportsReturning)
s"""| def insert(row: ${caseClassName}): DbAction[${caseClassName}] =
| sql""\"
| INSERT INTO $${${caseClassName}.TableName}($${${caseClassName}.*})
| VALUES (${insertExpr})
| RETURNING $${${caseClassName}.*}
| ""\".insertReturningRow()
|""".stripMargin
else
s"""| def insert(row: ${caseClassName}): DbAction[Unit] =
| sql""\"
| INSERT INTO $${${caseClassName}.TableName}($${${caseClassName}.*})
| VALUES(${insertExpr})
| ""\".insert()
|""".stripMargin
}
def genUpdateById: Option[String] = Option.when(tableDef.hasPk) {
val updateExpr = tableDef.nonPkColDefs
Expand Down Expand Up @@ -239,16 +253,16 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
val optionalSelectQueries = Seq(
genFindById,
genFindByIdOpt
).flatten.mkString("\n\n")
).flatten.mkString("\n")
val optionalUpdateQueries = Seq(
genUpdateById,
genDeleteById
).flatten.mkString("\n\n")
).flatten.mkString("\n")

val contents =
s"""|trait ${daoClassName} {
|${genCountAllQuery}

|
|${genFindAllQuery}
|
|${optionalSelectQueries}
Expand All @@ -265,19 +279,11 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene
GeneratedFile(s"${daoClassName}.scala", contents)
}

val squeryDbPackage =
if (dbName.contains("postgres")) "postgres"
else if (dbName.contains("mysql")) "mysql"
else if (dbName.contains("mariadb")) "mariadb"
else if (dbName.contains("oracle")) "oracle"
else if (dbName.contains("h2")) "h2"
else throw new RuntimeException(s"Unknown database type $dbName")

val imports =
s"""|import java.time.*
|import java.util.UUID
|import ba.sake.squery.{*, given}
|import ba.sake.squery.${squeryDbPackage}.given
|import ba.sake.squery.${dbType.squeryPackage}.given
|import ba.sake.squery.read.SqlRead
|import ba.sake.squery.write.SqlWrite
|""".stripMargin
Expand Down

0 comments on commit 68935b7

Please sign in to comment.