diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml
index 3bb6a13..c2b58c8 100644
--- a/.github/workflows/ci_cd.yml
+++ b/.github/workflows/ci_cd.yml
@@ -23,6 +23,7 @@ jobs:
           distribution: 'temurin'
           java-version: ${{ matrix.java }}
       - run: ./mill __.test
+      #TODO - run: ./mill mill-plugin-itest
 
   publish:
     needs: test
diff --git a/build.sc b/build.sc
index 9f85f9a..5730a51 100644
--- a/build.sc
+++ b/build.sc
@@ -75,16 +75,6 @@ object generator extends ScalaModule with CiReleaseModule {
     ivy"com.lihaoyi::os-lib:0.10.3",
     ivy"org.apache.commons:commons-text:1.12.0"
   )
-
-  object test extends ScalaTests with TestModule.Munit {
-    def ivyDeps = Agg(
-      ivy"org.scalameta::munit:1.0.0",
-      ivy"com.zaxxer:HikariCP:4.0.3",
-      ivy"org.postgresql:postgresql:42.5.4",
-      ivy"org.testcontainers:testcontainers:1.17.6",
-      ivy"org.testcontainers:postgresql:1.17.6"
-    )
-  }
 }
 
 /* MILL PLUGIN */
@@ -116,7 +106,11 @@ object `mill-plugin` extends ScalaModule with CiReleaseModule with ScalafmtModul
   def moduleDeps = Seq(generator)
 
   def ivyDeps = Agg(
-    ivy"org.postgresql:postgresql:42.6.0"
+    ivy"com.h2database:h2:2.3.232",
+    ivy"org.postgresql:postgresql:42.6.0",
+    ivy"mysql:mysql-connector-java:8.0.33",
+    ivy"org.mariadb.jdbc:mariadb-java-client:3.3.2",
+    ivy"com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"
   )
 
   override def scalacOptions = Seq("-Ywarn-unused", "-deprecation")
@@ -134,7 +128,7 @@ object `mill-plugin-itest` extends MillIntegrationTestModule {
 
   override def testInvocations: T[Seq[(PathRef, Seq[TestInvocation.Targets])]] = T {
     Seq(
-      PathRef(testBase / "simple") -> Seq(
+      PathRef(testBase / "h2") -> Seq(
        TestInvocation.Targets(Seq("verify"), noServer = true)
      )
    )
diff --git a/generator/src/ba/sake/squery/generator/DbDefExtractor.scala b/generator/src/ba/sake/squery/generator/DbDefExtractor.scala
new file mode 100644
index 0000000..91a5f7a
--- /dev/null
+++ b/generator/src/ba/sake/squery/generator/DbDefExtractor.scala
@@ -0,0 +1,170 @@
+package ba.sake.squery.generator
+
+import java.sql.{Array => _, _}
+import javax.sql.DataSource
+import scala.util._
+import scala.util.chaining._
+import scala.collection.mutable.ArrayBuffer
+import org.apache.commons.lang3.StringUtils
+import org.apache.commons.text.CaseUtils
+
+object DbDefExtractor {
+  def apply(ds: DataSource): DbDefExtractor =
+    Using.resource(ds.getConnection()) { connection =>
+      val databaseMetaData = connection.getMetaData()
+      val dbName = databaseMetaData.getDatabaseProductName().toLowerCase
+      dbName match {
+        case "postgresql" => new PostgresDefExtractor(ds)
+        case _            => new JdbcDefExtractor(ds)
+      }
+    }
+}
+
+abstract class DbDefExtractor(ds: DataSource) {
+
+  def extract(): DbDef = Using.resource(ds.getConnection()) { connection =>
+    val databaseMetaData = connection.getMetaData()
+    val dbName = databaseMetaData.getDatabaseProductName().toLowerCase
+    val schemaNames = Using.resource(databaseMetaData.getSchemas()) { rs =>
+      val buff = ArrayBuffer.empty[String]
+      while (rs.next()) {
+        buff += rs.getString("TABLE_SCHEM")
+      }
+      buff.toSeq
+    }
+    val schemaDefs = schemaNames.map { schemaName =>
+      val tables = extractTables(connection, schemaName, databaseMetaData)
+      SchemaDef(name = schemaName, tables = tables)
+    }
+    DbDef(
+      name = dbName,
+      schemas = schemaDefs
+    )
+  }
+
+  // (table, column) -> ColumnType
+  protected def getColumnTypes(
+      connection: Connection,
+      schemaName: String,
+      columnsMetadatas: Seq[ColumnMetadata]
+  ): Map[(String, String), ColumnType]
+
+  private def extractTables(
+      connection: Connection,
+      schemaName: String,
+      databaseMetaData: DatabaseMetaData
+  ): Seq[TableDef] = {
+
+    val allColumnsMetadatas = extractColumnMetadatas(databaseMetaData, schemaName)
+    val allColumnTypes = getColumnTypes(connection, schemaName, allColumnsMetadatas)
+    val allColumnDefs = allColumnsMetadatas.map { cMeta =>
+      val resolvedType = allColumnTypes((cMeta.table, cMeta.name))
+      ColumnDef(cMeta, resolvedType)
+    }
+
+    Using.resource(databaseMetaData.getTables(null, schemaName, null, Array("TABLE"))) { tablesRS =>
+      val tableDefsRes = ArrayBuffer.empty[TableDef]
+      while (tablesRS.next()) {
+        val tableName = tablesRS.getString("TABLE_NAME")
+        val tableColumnDefs = allColumnDefs.filter(_.metadata.table == tableName)
+        val pkColumns = Using.resource(databaseMetaData.getPrimaryKeys(null, schemaName, tableName)) { pksRS =>
+          val pkColumnsRes = ArrayBuffer.empty[ColumnDef]
+          while (pksRS.next()) {
+            val pkColName = pksRS.getString("COLUMN_NAME")
+            pkColumnsRes += tableColumnDefs
+              .find(_.metadata.name == pkColName)
+              .getOrElse(throw new RuntimeException(s"PK column not found: ${pkColName}"))
+          }
+          pkColumnsRes.toSeq
+        }
+        tableDefsRes += TableDef(schemaName, tableName, tableColumnDefs, pkColumns)
+      }
+      tableDefsRes.toSeq
+    }
+  }
+
+  private def extractColumnMetadatas(
+      databaseMetaData: DatabaseMetaData,
+      schemaName: String
+  ): Seq[ColumnMetadata] = {
+    val res = ArrayBuffer.empty[ColumnMetadata]
+    Using.resource(databaseMetaData.getColumns(null, schemaName, null, null)) { resultSet =>
+      while (resultSet.next()) {
+        val tableName = resultSet.getString("TABLE_NAME")
+        val columnName = resultSet.getString("COLUMN_NAME")
+        val typeName = resultSet.getString("TYPE_NAME")
+        val jdbcType = resultSet.getInt("DATA_TYPE") // java.sql.Types
+        val isNullable = resultSet.getString("IS_NULLABLE") == "YES"
+        val isAutoInc = resultSet.getString("IS_AUTOINCREMENT") == "YES"
+        val isGenerated = resultSet.getString("IS_GENERATEDCOLUMN") == "YES"
+        val defaultValue = Option(resultSet.getString("COLUMN_DEF"))
+        res += ColumnMetadata(
+          schemaName,
+          tableName,
+          columnName,
+          jdbcType,
+          isNullable,
+          isAutoInc,
+          isGenerated,
+          defaultValue
+        )
+      }
+    }
+    res.toSeq
+  }
+
+  // test utils
+  protected def printAll(resultSet: ResultSet) = {
+    val metadata = resultSet.getMetaData()
+    val totalCols = metadata.getColumnCount()
+    var columnNames = Seq.empty[String]
+    for (i <- 1 to totalCols) {
+      columnNames = columnNames.appended(metadata.getColumnName(i))
+    }
+
+    while (resultSet.next()) {
+      println("+" * 30)
+      for (i <- 1 to totalCols) {
+        val value = resultSet.getString(i)
+        print(s"${columnNames(i - 1)} = ${value}; ")
+      }
+      println()
+    }
+  }
+}
+
+case class DbDef(
+    name: String,
+    schemas: Seq[SchemaDef]
+)
+
+case class SchemaDef(
+    name: String,
+    tables: Seq[TableDef]
+)
+
+case class TableDef(schema: String, name: String, columnDefs: Seq[ColumnDef], pkColumns: Seq[ColumnDef])
+
+case class ColumnDef(
+    metadata: ColumnMetadata,
+    scalaType: ColumnType
+)
+
+sealed abstract class ColumnType
+object ColumnType {
+  case class Predefined(name: String) extends ColumnType
+  case class Enumeration(name: String, values: Seq[String]) extends ColumnType
+  case class Unknown(originalName: String) extends ColumnType
+}
+
+// raw db data
+case class ColumnMetadata(
+    schema: String,
+    table: String,
+    name: String,
+    jdbcType: Int,
+    isNullable: Boolean,
+    isAutoInc: Boolean,
+    isGenerated: Boolean,
+    defaultValue: Option[String]
+)
diff --git a/generator/src/ba/sake/squery/generator/JdbcDefExtractor.scala b/generator/src/ba/sake/squery/generator/JdbcDefExtractor.scala
new file mode 100644
index 0000000..21a2e3d
--- /dev/null
+++ b/generator/src/ba/sake/squery/generator/JdbcDefExtractor.scala
@@ -0,0 +1,48 @@
+package ba.sake.squery.generator
+
+import java.sql.{Array => _, _}
+import javax.sql.DataSource
+import scala.util._
+import scala.util.chaining._
+import scala.collection.mutable.ArrayBuffer
+import org.apache.commons.lang3.StringUtils
+import org.apache.commons.text.CaseUtils
+
+/** General data types extractor, based on JDBC metadata
+  *
+  * @param ds data source to introspect
+  */
+class JdbcDefExtractor(ds: DataSource) extends DbDefExtractor(ds) {
+
+  // (table, column) -> ColumnType
+  override protected def getColumnTypes(
+      connection: Connection,
+      schemaName: String,
+      columnsMetadatas: Seq[ColumnMetadata]
+  ): Map[(String, String), ColumnType] = {
+    val databaseMetaData = connection.getMetaData()
+    columnsMetadatas.map { cMeta =>
+      val tpe = cMeta.jdbcType match {
+        case Types.BIT                     => ColumnType.Predefined("Boolean")
+        case Types.BOOLEAN                 => ColumnType.Predefined("Boolean")
+        case Types.TINYINT                 => ColumnType.Predefined("Byte")
+        case Types.SMALLINT                => ColumnType.Predefined("Short")
+        case Types.INTEGER                 => ColumnType.Predefined("Int")
+        case Types.BIGINT                  => ColumnType.Predefined("Long")
+        case Types.DECIMAL                 => ColumnType.Predefined("Double")
+        case Types.DOUBLE                  => ColumnType.Predefined("Double")
+        case Types.NUMERIC                 => ColumnType.Predefined("Double")
+        case Types.NVARCHAR                => ColumnType.Predefined("String")
+        case Types.VARCHAR                 => ColumnType.Predefined("String")
+        case Types.DATE                    => ColumnType.Predefined("LocalDate")
+        case Types.TIMESTAMP               => ColumnType.Predefined("LocalDateTime")
+        case Types.TIMESTAMP_WITH_TIMEZONE => ColumnType.Predefined("Instant")
+        case Types.VARBINARY               => ColumnType.Predefined("Array[Byte]")
+        case Types.BINARY                  => ColumnType.Predefined("Array[Byte]")
+        case _                             => ColumnType.Unknown(cMeta.jdbcType.toString)
+      }
+      (cMeta.table, cMeta.name) -> tpe
+    }.toMap
+  }
+
+}
diff --git a/generator/src/ba/sake/squery/DbMetadataExtractor.scala b/generator/src/ba/sake/squery/generator/PostgresDefExtractor.scala
similarity index 53%
rename from generator/src/ba/sake/squery/DbMetadataExtractor.scala
rename to generator/src/ba/sake/squery/generator/PostgresDefExtractor.scala
index 721cc50..ca3c0ea 100644
--- a/generator/src/ba/sake/squery/DbMetadataExtractor.scala
+++ b/generator/src/ba/sake/squery/generator/PostgresDefExtractor.scala
@@ -9,77 +9,14 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.commons.text.CaseUtils
 
 // https://stackoverflow.com/a/16624964/4496364
-class DbMetadataExtractor(ds: DataSource) {
+class PostgresDefExtractor(ds: DataSource) extends DbDefExtractor(ds) {
 
-  def extract(): DbMetadata = Using.resource(ds.getConnection()) { connection =>
-
-    val databaseMetaData = connection.getMetaData()
-    val dbName = databaseMetaData.getDatabaseProductName().toLowerCase
-
-    val schemaNames = Using.resource(databaseMetaData.getSchemas()) { rs =>
-      val buff = ArrayBuffer.empty[String]
-      while (rs.next()) {
-        buff += rs.getString("TABLE_SCHEM")
-      }
-      buff.toSeq
-    }
-    val schemas = schemaNames.map { schemaName =>
-      val tables = extractTables(connection, schemaName, databaseMetaData)
-      SchemaDef(name = schemaName, tables = tables)
-    }
-    DbMetadata(
-      name = dbName,
-      schemas = schemas
-    )
-  }
-
-  private def extractTables(
+  // (table, column) -> ColumnType
+  override protected def getColumnTypes(
       connection: Connection,
       schemaName: String,
-      databaseMetaData: DatabaseMetaData
-  ): Seq[TableDef] = {
-    var readTablesCount = 0
-    val columnsMetadata = getColumnsMetadata(connection, schemaName)
-    Using.resource(databaseMetaData.getTables(null, schemaName, null, Array("TABLE"))) { resultSet =>
-      val res = ArrayBuffer.empty[TableDef]
-      while (resultSet.next()) {
-        val schema = resultSet.getString("TABLE_SCHEM")
-        val tableName = resultSet.getString("TABLE_NAME")
-        val columnDefs = generateColumnDefs(databaseMetaData, columnsMetadata, schema, tableName)
-        res += TableDef(schema, tableName, columnDefs)
-        readTablesCount += 1
-      }
-      res.toSeq
-    }
-  }
-
-// TODO getPrimaryKeys
-  private def generateColumnDefs(
-      databaseMetaData: DatabaseMetaData,
-      columnsMetadata: Map[(String, String), ColumnType],
-      schemaName: String,
-      tableName: String
-  ): Seq[ColumnDef] = {
-
-    val res = ArrayBuffer.empty[ColumnDef]
-    Using.resource(databaseMetaData.getColumns(null, schemaName, tableName, null)) { resultSet =>
-      while (resultSet.next()) {
-        val columnName = resultSet.getString("COLUMN_NAME")
-        val typeName = resultSet.getString("TYPE_NAME")
-        val jdbcType = resultSet.getInt("DATA_TYPE") // java.sql.Types
-        val isNullable = resultSet.getString("IS_NULLABLE") == "YES"
-        val isAutoInc = resultSet.getString("IS_AUTOINCREMENT") == "YES"
-        val isGenerated = resultSet.getString("IS_GENERATEDCOLUMN") == "YES"
-        val defaultValue = Option(resultSet.getString("COLUMN_DEF"))
-        val resolvedType = columnsMetadata((tableName, columnName))
-        res += ColumnDef(columnName, resolvedType, isNullable, isAutoInc, isGenerated, defaultValue)
-      }
-    }
-    res.toSeq
-  }
-
-  // (table, column) -> ColumnType
-  private def getColumnsMetadata(connection: Connection, schemaName: String): Map[(String, String), ColumnType] = {
+      columnsMetadatas: Seq[ColumnMetadata]
+  ): Map[(String, String), ColumnType] = {
     val query = s"""
       SELECT ns.nspname AS schema_name,
         tbl.relname AS table_name,
@@ -159,7 +96,7 @@ class DbMetadataExtractor(ds: DataSource) {
       }
     }
 
-  def resolveEnumType(connection: Connection, typeName: String): Try[ColumnType.Enumeration] =
+  private def resolveEnumType(connection: Connection, typeName: String): Try[ColumnType.Enumeration] =
     Using(connection.createStatement()) { stmt =>
       val resultSet = stmt.executeQuery(s"select unnest(enum_range(null, null::${typeName}))")
      val enumValues = ArrayBuffer.empty[String]
@@ -170,50 +107,4 @@ class DbMetadataExtractor(ds: DataSource) {
       else ColumnType.Enumeration(typeName, enumValues.toSeq)
     }
 
-  // test utils
-  private def printAll(resultSet: ResultSet) = {
-    val metadata = resultSet.getMetaData()
-    val totalCols = metadata.getColumnCount()
-    var columnNames = Seq.empty[String]
-    for (i <- 1 to totalCols) {
-      columnNames = columnNames.appended(metadata.getColumnName(i))
-    }
-
-    while (resultSet.next()) {
-      println("+" * 30)
-      for (i <- 1 to totalCols) {
-        val value = resultSet.getString(i)
-        print(s"${columnNames(i - 1)} = ${value}; ")
-      }
-      println()
-    }
-  }
-}
-
-case class DbMetadata(
-    name: String,
-    schemas: Seq[SchemaDef]
-)
-
-case class SchemaDef(
-    name: String,
-    tables: Seq[TableDef]
-)
-
-case class TableDef(schema: String, name: String, columnDefs: Seq[ColumnDef])
-
-case class ColumnDef(
-    name: String,
-    scalaType: ColumnType, // scala type
-    isNullable: Boolean,
-    isAutoInc: Boolean,
-    isGenerated: Boolean,
-    defaultValue: Option[String]
-)
-
-sealed abstract class ColumnType
-object ColumnType {
-  case class Predefined(name: String) extends ColumnType
-  case class Enumeration(name: String, values: Seq[String]) extends ColumnType
-  case class Unknown(originalName: String) extends ColumnType
 }
diff --git a/generator/src/ba/sake/squery/SqueryGenerator.scala b/generator/src/ba/sake/squery/generator/SqueryGenerator.scala
similarity index 67%
rename from generator/src/ba/sake/squery/SqueryGenerator.scala
rename to generator/src/ba/sake/squery/generator/SqueryGenerator.scala
index 861a4ca..ba122f4 100644
--- a/generator/src/ba/sake/squery/SqueryGenerator.scala
+++ b/generator/src/ba/sake/squery/generator/SqueryGenerator.scala
@@ -1,7 +1,6 @@
 package ba.sake.squery.generator
 
 import java.sql.{Array => _, _}
-
 import scala.util._
 import org.apache.commons.text.CaseUtils
 import com.typesafe.scalalogging.Logger
@@ -12,13 +11,13 @@ class SqueryGenerator(config: SqueryGeneratorConfig = SqueryGeneratorConfig.Defa
 
   private val Preamble = "/* DO NOT EDIT MANUALLY! Automatically generated by squery generator */"
 
-  def generateString(dbMeta: DbMetadata, schemaNames: Seq[String]): String =
+  def generateString(dbDef: DbDef, schemaNames: Seq[String]): String =
     schemaNames
       .map { schemaName =>
-        dbMeta.schemas.find(_.name == schemaName) match {
+        dbDef.schemas.find(_.name == schemaName) match {
           case Some(schemaDef) =>
             logger.info(s"Started generating schema '${schemaName}'")
-            val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbName = dbMeta.name)
+            val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbName = dbDef.name)
             val res =
               s"""|${Preamble}
                   |${imports}
@@ -35,40 +34,41 @@ class SqueryGenerator(config: SqueryGeneratorConfig = SqueryGeneratorConfig.Defa
       }
       .mkString("\n")
 
-  def generateFiles(dbMeta: DbMetadata, schemaConfigs: Seq[SchemaConfig]): Unit =
+  def generateFiles(dbDef: DbDef, schemaConfigs: Seq[SchemaConfig]): Unit =
     schemaConfigs.foreach { schemaConfig =>
-      dbMeta.schemas.find(_.name == schemaConfig.name) match {
+      dbDef.schemas.find(_.name == schemaConfig.name) match {
         case Some(schemaDef) =>
           logger.info(s"Started generating schema '${schemaConfig.name}' into '${schemaConfig.baseFolder}'")
           val packagePath = os.RelPath(schemaConfig.basePackage.replaceAll("\\.", "/"))
-          val (imports, enumDefsScala, tableDefsScala) = generateSchema(schemaDef, dbName = dbMeta.name)
-          enumDefsScala.foreach { enumFile =>
-            val enumDefWithImports =
+          val (imports, modelDefsScala, daoDefsScala) = generateSchema(schemaDef, dbName = dbDef.name)
+          modelDefsScala.foreach { modelFile =>
+            val modelFileWithImports =
               s"""|${Preamble}
                   |package ${schemaConfig.basePackage}.models
                   |
                   |${imports}
                   |
-                  |${enumFile.content}
+                  |${modelFile.content}
                   |""".stripMargin
             os.write.over(
-              schemaConfig.baseFolder / packagePath / "models" / enumFile.baseName,
-              enumDefWithImports,
+              schemaConfig.baseFolder / packagePath / "models" / modelFile.baseName,
+              modelFileWithImports,
               createFolders = true
             )
           }
-          tableDefsScala.foreach { tableDef =>
-            val tableDefWithImports =
+          daoDefsScala.foreach { daoFile =>
+            val daoFileWithImports =
               s"""|${Preamble}
-                  |package ${schemaConfig.basePackage}.models
+                  |package ${schemaConfig.basePackage}.daos
                   |
                   |${imports}
+                  |import ${schemaConfig.basePackage}.models.*
                   |
-                  |${tableDef.content}
+                  |${daoFile.content}
                   |""".stripMargin
             os.write.over(
-              schemaConfig.baseFolder / packagePath / "models" / tableDef.baseName,
-              tableDefWithImports,
+              schemaConfig.baseFolder / packagePath / "daos" / daoFile.baseName,
+              daoFileWithImports,
               createFolders = true
             )
           }
@@ -78,6 +78,7 @@ class SqueryGenerator(config: SqueryGeneratorConfig = SqueryGeneratorConfig.Defa
       }
     }
 
+  // (imports, models, repos)
   private def generateSchema(
       schemaDef: SchemaDef,
       dbName: String
@@ -92,40 +93,58 @@ class SqueryGenerator(config: SqueryGeneratorConfig = SqueryGeneratorConfig.Defa
         s"  case ${enumDefCaseValue.safeIdentifier}"
       }
       val enumName = transformName(enumDef.name, config.typeNameMapper)
-      val contents = s"""|enum ${enumName} derives SqlRead, SqlWrite:
-                         |${enumCaseDefs.mkString("\n")}
-                         |""".stripMargin
+      val contents =
+        s"""|enum ${enumName} derives SqlRead, SqlWrite:
+            |${enumCaseDefs.mkString("\n")}
+            |""".stripMargin
       GeneratedFile(s"${enumName}.scala", contents)
     }
 
     val tableFiles = schemaDef.tables.map { tableDef =>
       val columnDefsScala = tableDef.columnDefs.map { columnDef =>
         val safeTypeName = getSafeTypeName(columnDef.scalaType, config.typeNameMapper)
-        val tpe = if (columnDef.isNullable) s"Option[${safeTypeName}]" else safeTypeName
-        s"  ${columnDef.name.safeIdentifier}: ${tpe}"
+        val tpe = if (columnDef.metadata.isNullable) s"Option[${safeTypeName}]" else safeTypeName
+        s"  ${columnDef.metadata.name.safeIdentifier}: ${tpe}"
       }
       val columnNamesScala = tableDef.columnDefs.map { columnDef =>
-        s"""  inline val ${columnDef.name.safeIdentifier} = "${columnDef.name.safeIdentifier}""""
+        s"""  inline val ${columnDef.metadata.name.safeIdentifier} = "${columnDef.metadata.name.safeIdentifier}""""
       }
       val prefixedColumnNamesScala = tableDef.columnDefs.map { columnDef =>
-        s"""prefix + ${columnDef.name.safeIdentifier}"""
+        s"""prefix + ${columnDef.metadata.name.safeIdentifier}"""
       }
       val caseClassName = transformName(tableDef.name, config.typeNameMapper) + config.rowTypeSuffix
-      val contents = s"""|case class ${caseClassName}(
-                         |${columnDefsScala.mkString(",\n")}
-                         |) derives SqlReadRow
-                         |
-                         |object ${caseClassName} {
-                         |${columnNamesScala.mkString("\n")}
-                         |
-                         |  inline val * = prefixed("")
-                         |
-                         |  transparent inline def prefixed(inline prefix: String) =
-                         |    ${prefixedColumnNamesScala.mkString(""" + ", " + """)}
-                         |}
-                         |""".stripMargin
+      val contents =
+        s"""|case class ${caseClassName}(
+            |${columnDefsScala.mkString(",\n")}
+            |) derives SqlReadRow
+            |
+            |object ${caseClassName} {
+            |${columnNamesScala.mkString("\n")}
+            |
+            |  inline val TableName = "${tableDef.schema}.${tableDef.name}"
+            |
+            |  inline val * = prefixed("")
+            |
+            |  transparent inline def prefixed(inline prefix: String) =
+            |    ${prefixedColumnNamesScala.mkString(""" + ", " + """)}
+            |}
+            |""".stripMargin
       GeneratedFile(s"${caseClassName}.scala", contents)
+    }
+
+    val daoFiles = schemaDef.tables.map { tableDef =>
+      val caseClassName = transformName(tableDef.name, config.typeNameMapper) + config.rowTypeSuffix
+      val daoClassName = transformName(tableDef.name, config.typeNameMapper) + "CrudDao"
+      val contents =
+        s"""|trait ${daoClassName} {
+            |  def countAll(): DbAction[Int] = sql"SELECT COUNT(*) FROM $${${caseClassName}.TableName}".readValue()
+            |}
+            |
+            |object ${daoClassName} extends ${daoClassName} {
+            |
+            |}
+            |""".stripMargin
+      GeneratedFile(s"${daoClassName}.scala", contents)
     }
 
     val squeryDbPackage =
@@ -139,13 +158,13 @@ class SqueryGenerator(config: SqueryGeneratorConfig = SqueryGeneratorConfig.Defa
     val imports =
       s"""|import java.time.*
           |import java.util.UUID
-          |import ba.sake.squery.*
+          |import ba.sake.squery.{*, given}
          |import ba.sake.squery.${squeryDbPackage}.given
          |import ba.sake.squery.read.SqlRead
          |import ba.sake.squery.write.SqlWrite
          |""".stripMargin
 
-    (imports, enumFiles, tableFiles)
+    (imports, enumFiles ++ tableFiles, daoFiles)
   }
 
   private def transformName(str: String, nameMapper: NameMapper): String =
diff --git a/mill-plugin/src/ba/sake/squery/generator/mill/SqueryGeneratorModule.scala b/mill-plugin/src/ba/sake/squery/generator/mill/SqueryGeneratorModule.scala
index 96c842f..b0b937b 100644
--- a/mill-plugin/src/ba/sake/squery/generator/mill/SqueryGeneratorModule.scala
+++ b/mill-plugin/src/ba/sake/squery/generator/mill/SqueryGeneratorModule.scala
@@ -3,40 +3,62 @@ package ba.sake.squery.generator.mill
 import mill._
 import mill.scalalib._
 import _root_.ba.sake.squery.generator._
-import org.postgresql.ds.PGSimpleDataSource
 
 trait SqueryGeneratorModule extends JavaModule {
 
-  def squeryServer: T[String] = T("localhost")
-  def squeryPort: T[Int]
-
+  def squeryJdbcUrl: T[String]
   def squeryUsername: T[String]
   def squeryPassword: T[String]
-  def squeryDatabase: T[String]
-
-  /** List of (schema, basePackage)
-    */
+
+  /** List of (schema, basePackage) */
   def squerySchemas: T[Seq[(String, String)]]
 
   def squeryTargetDir: T[os.Path] = T(millSourcePath / "src")
 
   def squeryGenerate(): Command[Unit] = T.command {
-    println("Started generating Squery sources...")
-
-    // TODO parametrize db type
-    val ds = new PGSimpleDataSource()
-    ds.setUser(squeryUsername())
-    ds.setPassword(squeryPassword())
-    ds.setDatabaseName(squeryDatabase())
-    ds.setServerNames(Array(squeryServer()))
-    ds.setPortNumbers(Array(squeryPort()))
-
-    val extractor = new DbMetadataExtractor(ds)
-    val dbMeta = extractor.extract()
+    println("Starting to generate Squery sources...")
+
+    val jdbcUrl = squeryJdbcUrl()
+    val username = squeryUsername()
+    val password = squeryPassword()
+    val dataSource: javax.sql.DataSource =
+      if (jdbcUrl.startsWith("jdbc:h2:")) {
+        val ds = new org.h2.jdbcx.JdbcDataSource()
+        ds.setURL(jdbcUrl)
+        ds.setUser(username)
+        ds.setPassword(password)
+        ds
+      } else if (jdbcUrl.startsWith("jdbc:postgresql:")) {
+        val ds = new org.postgresql.ds.PGSimpleDataSource()
+        ds.setURL(jdbcUrl)
+        ds.setUser(username)
+        ds.setPassword(password)
+        ds
+      } else if (jdbcUrl.startsWith("jdbc:mysql:")) {
+        val ds = new com.mysql.cj.jdbc.MysqlDataSource()
+        ds.setURL(jdbcUrl)
+        ds.setUser(username)
+        ds.setPassword(password)
+        ds
+      } else if (jdbcUrl.startsWith("jdbc:mariadb:")) {
+        val ds = new org.mariadb.jdbc.MariaDbDataSource()
+        ds.setUrl(jdbcUrl)
+        ds.setUser(username)
+        ds.setPassword(password)
+        ds
+      } else if (jdbcUrl.startsWith("jdbc:oracle:")) {
+        val ds = new oracle.jdbc.pool.OracleDataSource()
+        ds.setURL(jdbcUrl)
+        ds.setUser(username)
+        ds.setPassword(password)
+        ds
+      } else throw new RuntimeException(s"Unsupported database ${jdbcUrl}")
+
+    val extractor = DbDefExtractor(dataSource)
+    val dbDef = extractor.extract()
 
     val generator = new SqueryGenerator()
     generator.generateFiles(
-      dbMeta,
+      dbDef,
       squerySchemas().map { case (schemaName, basePackage) =>
         SchemaConfig(
           name = schemaName,
@@ -46,7 +68,7 @@ trait SqueryGeneratorModule extends JavaModule {
       }
     )
 
-    println("Finished generating Squery sources...")
+    println("Finished generating Squery sources")
   }
 }
diff --git a/squery/src/ba/sake/squery/read/SqlRead.scala b/squery/src/ba/sake/squery/read/SqlRead.scala
index 84d1100..63c6f9a 100644
--- a/squery/src/ba/sake/squery/read/SqlRead.scala
+++ b/squery/src/ba/sake/squery/read/SqlRead.scala
@@ -29,6 +29,14 @@
       Option(jRes.getBoolean(colIdx)).filterNot(_ => jRes.wasNull())
   }
 
+  given SqlRead[Byte] with {
+    def readByName(jRes: jsql.ResultSet, colName: String): Option[Byte] =
+      Option(jRes.getByte(colName)).filterNot(_ => jRes.wasNull())
+
+    def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Byte] =
+      Option(jRes.getByte(colIdx)).filterNot(_ => jRes.wasNull())
+  }
+
   given SqlRead[Short] with {
     def readByName(jRes: jsql.ResultSet, colName: String): Option[Short] =
       Option(jRes.getShort(colName)).filterNot(_ => jRes.wasNull())
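
Usage sketch: with the plugin now keyed on the JDBC URL prefix instead of hard-coded Postgres settings, a consumer's build.sc could wire it up as below. This is a minimal sketch; the plugin coordinates/version, JDBC URL, credentials, schema name, and base package are illustrative assumptions, not values taken from this diff.

import $ivy.`ba.sake::mill-squery-generator::0.1.0` // hypothetical coordinates/version
import mill._, mill.scalalib._
import ba.sake.squery.generator.mill.SqueryGeneratorModule

object app extends ScalaModule with SqueryGeneratorModule {
  def scalaVersion = "3.4.1"

  // The URL prefix selects the DataSource implementation (h2 / postgresql / mysql / mariadb / oracle).
  def squeryJdbcUrl = T("jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1")
  def squeryUsername = T("sa")
  def squeryPassword = T("")

  // (schema, basePackage) pairs: models are written to <basePackage>.models, DAOs to <basePackage>.daos.
  def squerySchemas = T(Seq("PUBLIC" -> "com.example.db"))
}

// Run with: ./mill app.squeryGenerate
// Sources land under app/src, the squeryTargetDir default (millSourcePath / "src").

Given the templates above, a hypothetical public.customer table with a non-null id and a nullable name column would come out roughly as follows, assuming config.rowTypeSuffix resolves to "Row" (the table and columns are invented for illustration):

// models/CustomerRow.scala (sketch of generated output)
case class CustomerRow(
  id: Int,
  name: Option[String]
) derives SqlReadRow

object CustomerRow {
  inline val id = "id"
  inline val name = "name"

  inline val TableName = "public.customer"

  inline val * = prefixed("")

  transparent inline def prefixed(inline prefix: String) =
    prefix + id + ", " + prefix + name
}

// daos/CustomerCrudDao.scala (sketch of generated output)
trait CustomerCrudDao {
  def countAll(): DbAction[Int] = sql"SELECT COUNT(*) FROM ${CustomerRow.TableName}".readValue()
}

object CustomerCrudDao extends CustomerCrudDao {

}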