diff --git a/generator/src/ba/sake/squery/generator/DbDefExtractor.scala b/generator/src/ba/sake/squery/generator/DbDefExtractor.scala index dc56fa4..d28e211 100644 --- a/generator/src/ba/sake/squery/generator/DbDefExtractor.scala +++ b/generator/src/ba/sake/squery/generator/DbDefExtractor.scala @@ -36,8 +36,10 @@ abstract class DbDefExtractor(ds: DataSource) { val tables = extractTables(connection, schemaName, databaseMetaData) SchemaDef(name = schemaName, tables = tables) } + val dbType = DbType.fromDatabaseProductName(dbName) DbDef( name = dbName, + tpe = dbType, schemas = schemaDefs ) } @@ -135,6 +137,7 @@ abstract class DbDefExtractor(ds: DataSource) { case class DbDef( name: String, + tpe: DbType, schemas: Seq[SchemaDef] ) diff --git a/generator/src/ba/sake/squery/generator/DbType.scala b/generator/src/ba/sake/squery/generator/DbType.scala new file mode 100644 index 0000000..4b67ad4 --- /dev/null +++ b/generator/src/ba/sake/squery/generator/DbType.scala @@ -0,0 +1,25 @@ +package ba.sake.squery.generator + +sealed abstract class DbType(val squeryPackage: String) { + def supportsReturning: Boolean = false +} + +object DbType { + + def fromDatabaseProductName(dbName: String): DbType = { + if (dbName.contains("h2")) H2 + else if (dbName.contains("postgres")) PostgreSQL + else if (dbName.contains("mysql")) MySQL + else if (dbName.contains("mariadb")) MariaDB + else if (dbName.contains("oracle")) Oracle + else throw new RuntimeException(s"Unknown database type $dbName") + } + + case object H2 extends DbType("h2") + case object PostgreSQL extends DbType("postgres") { + override def supportsReturning: Boolean = true + } + case object MySQL extends DbType("mysql") + case object MariaDB extends DbType("mariadb") + case object Oracle extends DbType("oracle") +} diff --git a/generator/src/ba/sake/squery/generator/SqueryGenerator.scala b/generator/src/ba/sake/squery/generator/SqueryGenerator.scala index 43de05d..7833172 100644 --- a/generator/src/ba/sake/squery/generator/SqueryGenerator.scala +++ b/generator/src/ba/sake/squery/generator/SqueryGenerator.scala @@ -24,7 +24,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene dbDef.schemas.find(_.name == schemaName) match { case Some(schemaDef) => logger.info(s"Started generating schema '${schemaName}'") - val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbName = dbDef.name) + val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbType = dbDef.tpe) val res = s"""|${Preamble} |${imports} @@ -54,7 +54,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene case Some(schemaDef) => logger.info(s"Started generating schema '${schemaConfig.name}' into '${schemaConfig.baseFolder}'") val packagePath = os.RelPath(schemaConfig.basePackage.replaceAll("\\.", "/")) - val (imports, modelDefsScala, daoDefsScala) = generateSchema(schemaDef, dbName = dbDef.name) + val (imports, modelDefsScala, daoDefsScala) = generateSchema(schemaDef, dbType = dbDef.tpe) modelDefsScala.foreach { modelFile => val modelFileWithImports = s"""|${Preamble} @@ -96,7 +96,7 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene // (imports, models, repos) private def generateSchema( schemaDef: SchemaDef, - dbName: String + dbType: DbType ): (String, Seq[GeneratedFile], Seq[GeneratedFile]) = { val enumDefs = schemaDef.tables.flatMap { _.columnDefs.map(_.scalaType).collect { case e: ColumnType.Enumeration => @@ -177,7 +177,6 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene else s"${pkCol.metadata.name} = $${id}" } - def genCountAllQuery: String = s"""| def countAll(): DbAction[Int] = | sql"SELECT COUNT(*) FROM $${${caseClassName}.TableName}".readValue() @@ -200,17 +199,32 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene |""".stripMargin } def genInsert: String = { + // TODO if cols are autoinc + // https://www.jooq.org/doc/latest/manual/sql-building/sql-statements/insert-statement/insert-returning/ + val colsListExpr = tableDef.columnDefs + .map(_.metadata.name) + .mkString(", ") val insertExpr = tableDef.columnDefs .map { colDef => s"$${row.${colDef.metadata.name.safeIdentifier}}" } .mkString(", ") val whereExpr = byIdWhereExpr("id") - s"""| def insert(row: ${caseClassName}): DbAction[Unit] = - | sql""\"INSERT INTO $${${caseClassName}.TableName} VALUES ( - | ${insertExpr} - | )""\".insert() - |""".stripMargin + if (dbType.supportsReturning) + s"""| def insert(row: ${caseClassName}): DbAction[${caseClassName}] = + | sql""\" + | INSERT INTO $${${caseClassName}.TableName}($${${caseClassName}.*}) + | VALUES (${insertExpr}) + | RETURNING $${${caseClassName}.*} + | ""\".insertReturningRow() + |""".stripMargin + else + s"""| def insert(row: ${caseClassName}): DbAction[Unit] = + | sql""\" + | INSERT INTO $${${caseClassName}.TableName}($${${caseClassName}.*}) + | VALUES(${insertExpr}) + | ""\".insert() + |""".stripMargin } def genUpdateById: Option[String] = Option.when(tableDef.hasPk) { val updateExpr = tableDef.nonPkColDefs @@ -239,16 +253,16 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene val optionalSelectQueries = Seq( genFindById, genFindByIdOpt - ).flatten.mkString("\n\n") + ).flatten.mkString("\n") val optionalUpdateQueries = Seq( genUpdateById, genDeleteById - ).flatten.mkString("\n\n") + ).flatten.mkString("\n") val contents = s"""|trait ${daoClassName} { |${genCountAllQuery} - + | |${genFindAllQuery} | |${optionalSelectQueries} @@ -265,19 +279,11 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene GeneratedFile(s"${daoClassName}.scala", contents) } - val squeryDbPackage = - if (dbName.contains("postgres")) "postgres" - else if (dbName.contains("mysql")) "mysql" - else if (dbName.contains("mariadb")) "mariadb" - else if (dbName.contains("oracle")) "oracle" - else if (dbName.contains("h2")) "h2" - else throw new RuntimeException(s"Unknown database type $dbName") - val imports = s"""|import java.time.* |import java.util.UUID |import ba.sake.squery.{*, given} - |import ba.sake.squery.${squeryDbPackage}.given + |import ba.sake.squery.${dbType.squeryPackage}.given |import ba.sake.squery.read.SqlRead |import ba.sake.squery.write.SqlWrite |""".stripMargin