Skip to content

Commit

Permalink
Dispatch generator DS on JDBC url. Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sake92 committed Aug 16, 2024
1 parent 98f96c1 commit a20b4af
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 189 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
distribution: 'temurin'
java-version: ${{ matrix.java }}
- run: ./mill __.test
#TODO - run: ./mill mill-plugin-itest

publish:
needs: test
Expand Down
18 changes: 6 additions & 12 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,6 @@ object generator extends ScalaModule with CiReleaseModule {
ivy"com.lihaoyi::os-lib:0.10.3",
ivy"org.apache.commons:commons-text:1.12.0"
)

object test extends ScalaTests with TestModule.Munit {
def ivyDeps = Agg(
ivy"org.scalameta::munit:1.0.0",
ivy"com.zaxxer:HikariCP:4.0.3",
ivy"org.postgresql:postgresql:42.5.4",
ivy"org.testcontainers:testcontainers:1.17.6",
ivy"org.testcontainers:postgresql:1.17.6"
)
}
}

/* MILL PLUGIN */
Expand Down Expand Up @@ -116,7 +106,11 @@ object `mill-plugin` extends ScalaModule with CiReleaseModule with ScalafmtModul
def moduleDeps = Seq(generator)

def ivyDeps = Agg(
ivy"org.postgresql:postgresql:42.6.0"
ivy"com.h2database:h2:2.3.232",
ivy"org.postgresql:postgresql:42.6.0",
ivy"mysql:mysql-connector-java:8.0.33",
ivy"org.mariadb.jdbc:mariadb-java-client:3.3.2",
ivy"com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"
)

override def scalacOptions = Seq("-Ywarn-unused", "-deprecation")
Expand All @@ -134,7 +128,7 @@ object `mill-plugin-itest` extends MillIntegrationTestModule {
override def testInvocations: T[Seq[(PathRef, Seq[TestInvocation.Targets])]] =
T {
Seq(
PathRef(testBase / "simple") -> Seq(
PathRef(testBase / "h2") -> Seq(
TestInvocation.Targets(Seq("verify"), noServer = true)
)
)
Expand Down
170 changes: 170 additions & 0 deletions generator/src/ba/sake/squery/generator/DbDefExtractor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package ba.sake.squery.generator

import java.sql.{Array => _, _}
import javax.sql.DataSource
import scala.util._
import scala.util.chaining._
import scala.collection.mutable.ArrayBuffer
import org.apache.commons.lang3.StringUtils
import org.apache.commons.text.CaseUtils

object DbDefExtractor {
def apply(ds: DataSource): DbDefExtractor =
Using.resource(ds.getConnection()) { connection =>
val databaseMetaData = connection.getMetaData()
val dbName = databaseMetaData.getDatabaseProductName().toLowerCase
dbName match {
case "postgres" => new PostgresDefExtractor(ds)
case _ => new JdbcDefExtractor(ds)
}
}
}

abstract class DbDefExtractor(ds: DataSource) {

def extract(): DbDef = Using.resource(ds.getConnection()) { connection =>
val databaseMetaData = connection.getMetaData()
val dbName = databaseMetaData.getDatabaseProductName().toLowerCase
val schemaNames = Using.resource(databaseMetaData.getSchemas()) { rs =>
val buff = ArrayBuffer.empty[String]
while (rs.next()) {
buff += rs.getString("TABLE_SCHEM")
}
buff.toSeq
}
val schemaDefs = schemaNames.map { schemaName =>
val tables = extractTables(connection, schemaName, databaseMetaData)
SchemaDef(name = schemaName, tables = tables)
}
DbDef(
name = dbName,
schemas = schemaDefs
)
}

// (table, column) -> ColumnType
protected def getColumnTypes(
connection: Connection,
schemaName: String,
columnsMetadatas: Seq[ColumnMetadata]
): Map[(String, String), ColumnType]

private def extractTables(
connection: Connection,
schemaName: String,
databaseMetaData: DatabaseMetaData
): Seq[TableDef] = {

val allColumnsMetadatas = extractColumnMetadatas(databaseMetaData, schemaName)
val allColumnTypes = getColumnTypes(connection, schemaName, allColumnsMetadatas)
val allColumnDefs = allColumnsMetadatas.map { cMeta =>
val resolvedType = allColumnTypes((cMeta.table, cMeta.name))
ColumnDef(cMeta, resolvedType)
}

Using.resource(databaseMetaData.getTables(null, schemaName, null, Array("TABLE"))) { tablesRS =>
val tableDefsRes = ArrayBuffer.empty[TableDef]
while (tablesRS.next()) {
val tableName = tablesRS.getString("TABLE_NAME")
val tableColumnDefs = allColumnDefs.filter(_.metadata.table == tableName)
val pkColumns = Using.resource(databaseMetaData.getPrimaryKeys(null, schemaName, tableName)) { pksRS =>
val tableDefsRes = ArrayBuffer.empty[ColumnDef]
while (pksRS.next()) {
val pkColName = pksRS.getString("COLUMN_NAME")
tableDefsRes += tableColumnDefs
.find(_.metadata.name == pkColName)
.getOrElse(throw new RuntimeException(s"PK column not found: ${pkColName}"))
}
tableDefsRes.toSeq
}
tableDefsRes += TableDef(schemaName, tableName, tableColumnDefs, pkColumns)
}
tableDefsRes.toSeq
}
}

private def extractColumnMetadatas(
databaseMetaData: DatabaseMetaData,
schemaName: String
): Seq[ColumnMetadata] = {
val res = ArrayBuffer.empty[ColumnMetadata]
Using.resource(databaseMetaData.getColumns(null, schemaName, null, null)) { resultSet =>
while (resultSet.next()) {
val tableName = resultSet.getString("TABLE_NAME")
val columnName = resultSet.getString("COLUMN_NAME")
val typeName = resultSet.getString("TYPE_NAME")
val jdbcType = resultSet.getInt("DATA_TYPE") // java.sql.Types
val isNullable = resultSet.getString("IS_NULLABLE") == "YES"
val isAutoInc = resultSet.getString("IS_AUTOINCREMENT") == "YES"
val isGenerated = resultSet.getString("IS_GENERATEDCOLUMN") == "YES"
val defaultValue = Option(resultSet.getString("COLUMN_DEF"))
res += ColumnMetadata(
schemaName,
tableName,
columnName,
jdbcType,
isNullable,
isAutoInc,
isGenerated,
defaultValue
)
}
}
res.toSeq
}

// test utils
protected def printAll(resultSet: ResultSet) = {
val metadata = resultSet.getMetaData()
val totalCols = metadata.getColumnCount()
var columnNames = Seq.empty[String]
for (i <- 1 to totalCols) {
columnNames = columnNames.appended(metadata.getColumnName(i))
}

while (resultSet.next()) {
println("+" * 30)
for (i <- 1 to totalCols) {
val value = resultSet.getString(i)
print(s"${columnNames(i - 1)} = ${value}; ")
}
println()
}
}
}

case class DbDef(
name: String,
schemas: Seq[SchemaDef]
)

case class SchemaDef(
name: String,
tables: Seq[TableDef]
)

case class TableDef(schema: String, name: String, columnDefs: Seq[ColumnDef], pkColumns: Seq[ColumnDef])

case class ColumnDef(
metadata: ColumnMetadata,
scalaType: ColumnType
)

sealed abstract class ColumnType
object ColumnType {
case class Predefined(name: String) extends ColumnType
case class Enumeration(name: String, values: Seq[String]) extends ColumnType
case class Unknown(originalName: String) extends ColumnType
}

// raw db data
case class ColumnMetadata(
schema: String,
table: String,
name: String,
jdbcType: Int,
isNullable: Boolean,
isAutoInc: Boolean,
isGenerated: Boolean,
defaultValue: Option[String]
)
48 changes: 48 additions & 0 deletions generator/src/ba/sake/squery/generator/JdbcDefExtractor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package ba.sake.squery.generator

import java.sql.{Array => _, _}
import javax.sql.DataSource
import scala.util._
import scala.util.chaining._
import scala.collection.mutable.ArrayBuffer
import org.apache.commons.lang3.StringUtils
import org.apache.commons.text.CaseUtils

/** General data types extractor, based on JDBC metadata
*
* @param ds
*/
class JdbcDefExtractor(ds: DataSource) extends DbDefExtractor(ds) {

// (table, column) -> ColumnType
override protected def getColumnTypes(
connection: Connection,
schemaName: String,
columnsMetadatas: Seq[ColumnMetadata]
): Map[(String, String), ColumnType] = {
val databaseMetaData = connection.getMetaData()
columnsMetadatas.map { cMeta =>
val tpe = cMeta.jdbcType match {
case Types.BIT => ColumnType.Predefined("Boolean")
case Types.BOOLEAN => ColumnType.Predefined("Boolean")
case Types.TINYINT => ColumnType.Predefined("Byte")
case Types.SMALLINT => ColumnType.Predefined("Short")
case Types.INTEGER => ColumnType.Predefined("Int")
case Types.BIGINT => ColumnType.Predefined("Long")
case Types.DECIMAL => ColumnType.Predefined("Double")
case Types.DOUBLE => ColumnType.Predefined("Double")
case Types.NUMERIC => ColumnType.Predefined("Double")
case Types.NVARCHAR => ColumnType.Predefined("String")
case Types.VARCHAR => ColumnType.Predefined("String")
case Types.DATE => ColumnType.Predefined("LocalDate")
case Types.TIMESTAMP => ColumnType.Predefined("LocalDateTime")
case Types.TIMESTAMP_WITH_TIMEZONE => ColumnType.Predefined("Instant")
case Types.VARBINARY => ColumnType.Predefined("Array[Byte]")
case Types.BINARY => ColumnType.Predefined("Array[Byte]")
case _ => ColumnType.Unknown(cMeta.jdbcType.toString)
}
(cMeta.table, cMeta.name) -> tpe
}.toMap
}

}
Loading

0 comments on commit a20b4af

Please sign in to comment.