Skip to content

Commit

Permalink
#384 Defer Foreign references during creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Tapac committed Sep 20, 2018
1 parent 3c0f2ec commit f410859
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 96 deletions.
32 changes: 1 addition & 31 deletions src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class EntityCache {

val insertedTables = inserts.keys

for (t in sortTablesByReferences(tables)) {
for (t in SchemaUtils.sortTablesByReferences(tables)) {
flushInserts(t as IdTable<*>)
}

Expand Down Expand Up @@ -466,36 +466,6 @@ class EntityCache {
it.expireCache()
}
}

fun sortTablesByReferences(tables: Iterable<Table>) = addDependencies(tables).toCollection(arrayListOf()).run {
if(this.count() <= 1) return this
val canBeReferenced = arrayListOf<Table>()
do {
val (movable, others) = partition {
it.columns.all { it.referee == null || canBeReferenced.contains(it.referee!!.table) || it.referee!!.table == it.table}
}
canBeReferenced.addAll(movable)
this.removeAll(movable)
} while (others.isNotEmpty() && movable.isNotEmpty())
canBeReferenced.addAll(this)
canBeReferenced
}

fun addDependencies(tables: Iterable<Table>): Iterable<Table> {
val workset = HashSet<Table>()

fun checkTable(table: Table) {
if (workset.add(table)) {
for (c in table.columns) {
c.referee?.table?.let { checkTable(it) }
}
}
}

for (t in tables) checkTable(t)

return workset
}
}
}

Expand Down
10 changes: 6 additions & 4 deletions src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ data class ForeignKeyConstraint(val fkName: String, val refereeTable: String, va
companion object {
fun from(column: Column<*>): ForeignKeyConstraint {
assert(column.referee != null && (column.onDelete != null || column.onUpdate != null)) { "$column does not reference anything" }
val s = TransactionManager.current()
return ForeignKeyConstraint("", s.identity(column.referee!!.table),
s.identity(column.referee!!), s.identity(column.table), s.identity(column),
val referee = column.referee!!
val t = TransactionManager.current()
val refName = t.quoteIfNecessary(t.cutIfNecessary("fk_${referee.table.tableName}_${referee.name}_${column.name}"))
return ForeignKeyConstraint(refName, t.identity(referee.table), t.identity(referee),
t.identity(column.table), t.identity(column),
column.onUpdate ?: ReferenceOption.NO_ACTION,
column.onDelete ?: ReferenceOption.NO_ACTION)
}
Expand All @@ -56,7 +58,7 @@ data class ForeignKeyConstraint(val fkName: String, val refereeTable: String, va
}
}

override fun createStatement() = listOf("ALTER TABLE $referencedTable ADD" + if (fkName.isNotBlank()) " CONSTRAINT $fkName" else "" + foreignKeyPart)
override fun createStatement() = listOf("ALTER TABLE $referencedTable ADD" + (if (fkName.isNotBlank()) " CONSTRAINT $fkName" else "") + foreignKeyPart)

override fun dropStatement() = listOf("ALTER TABLE $refereeTable DROP " +
when (currentDialect) {
Expand Down
86 changes: 71 additions & 15 deletions src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt
Original file line number Diff line number Diff line change
@@ -1,33 +1,89 @@
package org.jetbrains.exposed.sql

import org.jetbrains.exposed.dao.EntityCache
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.currentDialect
import org.jetbrains.exposed.sql.vendors.inProperCase
import java.util.*

object SchemaUtils {
fun createStatements(vararg tables: Table): List<String> {
val statements = ArrayList<String>()
if (tables.isEmpty())
return statements
private class TableDepthGraph(val tables: List<Table>) {
val graph = fetchAllTables().associate { t ->
t to t.columns.mapNotNull { c ->
c.referee?.let{ it.table to c.columnType.nullable }
}.toMap()
}

private fun fetchAllTables(): HashSet<Table> {
val result = HashSet<Table>()

fun parseTable(table: Table) {
if (result.add(table)) {
table.columns.forEach {
it.referee?.table?.let(::parseTable)
}
}
}
tables.forEach(::parseTable)
return result
}

fun sorted() : List<Table> {
val visited = mutableSetOf<Table>()
val result = arrayListOf<Table>()

val newTables = ArrayList<Table>()
fun traverse(table: Table) {
if (table !in visited) {
visited += table
graph[table]!!.forEach { t, u ->
if (t !in visited) {
traverse(t)
}
}
result += table
}
}

for (table in EntityCache.sortTablesByReferences(tables.toList())) {
tables.forEach(::traverse)
return result
}

if (table.exists()) continue else newTables.add(table)
fun hasCycle() : Boolean {
val visited = mutableSetOf<Table>()
val recursion = mutableSetOf<Table>()

// create table
statements.addAll(table.ddl)
val sortedTables = sorted()

// create indices
for (index in table.indices) {
statements.addAll(createIndex(index))
fun traverse(table: Table) : Boolean {
if (table in recursion) return true
if (table in visited) return false
recursion += table
visited += table
return if (graph[table]!!.any{ traverse(it.key) }) {
true
} else {
recursion -= table
false
}
}
return sortedTables.any { traverse(it) }
}
}

fun sortTablesByReferences(tables: Iterable<Table>) = TableDepthGraph(tables.toList()).sorted()
fun checkCycle(vararg tables: Table) = TableDepthGraph(tables.toList()).hasCycle()

fun createStatements(vararg tables: Table): List<String> {
if (tables.isEmpty())
return emptyList()

return statements
val toCreate = sortTablesByReferences(tables.toList()).filterNot { it.exists() }
val alters = arrayListOf<String>()
return toCreate.flatMap { table ->
val (create, alter) = table.ddl.partition { it.startsWith("CREATE ") }
val indicesDDL = table.indices.flatMap { createIndex(it) }
alters += alter
create + indicesDDL
} + alters
}

fun createSequence(name: String) = Seq(name).createStatement()
Expand Down Expand Up @@ -185,7 +241,7 @@ object SchemaUtils {
if (tables.isEmpty()) return
val transaction = TransactionManager.current()
transaction.flushCache()
var tablesForDeletion = EntityCache
var tablesForDeletion = SchemaUtils
.sortTablesByReferences(tables.toList())
.reversed()
.filter { it in tables }
Expand Down
40 changes: 25 additions & 15 deletions src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ package org.jetbrains.exposed.sql
import org.jetbrains.exposed.dao.EntityID
import org.jetbrains.exposed.dao.IdTable
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.OracleDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
import org.jetbrains.exposed.sql.vendors.currentDialectIfAvailable
import org.jetbrains.exposed.sql.vendors.inProperCase
import org.jetbrains.exposed.sql.vendors.*
import org.joda.time.DateTime
import java.math.BigDecimal
import java.sql.Blob
Expand Down Expand Up @@ -485,6 +482,8 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
Seq(it).createStatement()
}.orEmpty()

val addForeignKeysInAlterPart = SchemaUtils.checkCycle(this) && currentDialect !is SQLiteDialect

val createTableDDL = buildString {
append("CREATE TABLE ")
if (currentDialect.supportsIfNotExists) {
Expand All @@ -498,24 +497,31 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
append(", $it")
}
}
columns.filter { it.referee != null }.let { references ->
if (references.isNotEmpty()) {
append(references.joinToString(prefix = ", ", separator = ", ") { ForeignKeyConstraint.from(it).foreignKeyPart })

if (!addForeignKeysInAlterPart) {
columns.filter { it.referee != null }.takeIf { it.isNotEmpty() }?.let { references ->
references.joinTo(this, prefix = ", ", separator = ", ") { ForeignKeyConstraint.from(it).foreignKeyPart }
}
}

if (checkConstraints.isNotEmpty()) {
append(
checkConstraints.mapIndexed { index, (name, op) ->
val resolvedName = name.takeIf { it.isNotBlank() } ?: "check_${tableName}_$index"
CheckConstraint.from(this@Table, resolvedName, op).checkPart
}.joinToString(prefix = ",", separator = ",")
)
checkConstraints.asSequence().mapIndexed { index, (name, op) ->
val resolvedName = name.takeIf { it.isNotBlank() } ?: "check_${tableName}_$index"
CheckConstraint.from(this@Table, resolvedName, op).checkPart
}.joinTo(this, prefix = ",", separator = ",")
}

append(")")
}
}
return seqDDL + createTableDDL

val constraintSQL = if (addForeignKeysInAlterPart) {
columns.filter { it.referee != null }.flatMap {
ForeignKeyConstraint.from(it).createStatement()
}
} else emptyList()

return seqDDL + createTableDDL + constraintSQL
}

internal fun primaryKeyConstraint(): String? {
Expand All @@ -538,12 +544,16 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
val dropTableDDL = buildString {
append("DROP TABLE ")
if (currentDialect.supportsIfNotExists) {
append(" IF EXISTS ")
append("IF EXISTS ")
}
append(TransactionManager.current().identity(this@Table))
if (currentDialectIfAvailable is OracleDialect) {
append(" CASCADE CONSTRAINTS")
}

if (currentDialectIfAvailable is PostgreSQLDialect && SchemaUtils.checkCycle(this@Table)) {
append(" CASCADE")
}
}
val seqDDL = autoIncColumn?.autoIncSeqName?.let {
Seq(it).dropStatement()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class DDLTests : DatabaseTestsBase() {
}
}

withTables(excludeSettings = listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.SQLITE), tables = initialTable) {
withTables(excludeSettings = listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.SQLITE), tables = *arrayOf(initialTable)) {
assertEquals("ALTER TABLE ${tableName.inProperCase()} ADD ${"id".inProperCase()} ${t.id.columnType.sqlType()} PRIMARY KEY", t.id.ddl)
assertEquals(1, currentDialect.tableColumns(t)[t]!!.size)
SchemaUtils.createMissingTablesAndColumns(t)
Expand Down Expand Up @@ -489,6 +489,36 @@ class DDLTests : DatabaseTestsBase() {
}
}

object Table1 : IntIdTable() {
val table2 = reference("teamId", Table2, onDelete = ReferenceOption.CASCADE)
}

object Table2 : IntIdTable() {
val table1 = optReference("teamId", Table1, onDelete = ReferenceOption.SET_NULL)
}

@Test fun testCrossReference() {
withTables(Table2, Table1) {
val table2id = Table2.insertAndGetId{}
val table1id = Table1.insertAndGetId {
it[Table1.table2] = table2id
}

Table2.insertAndGetId {
it[Table2.table1] = table1id
}

assertEquals(1, Table1.selectAll().count())
assertEquals(2, Table2.selectAll().count())
if (currentDialect is MysqlDialect) {
exec("SET foreign_key_checks = 0;")
}
if (currentDialect is PostgreSQLDialect) {
exec("set constraints all deferred;")
}
}
}

@Test fun testUUIDColumnType() {
val Node = object: Table("node") {
val uuid = uuid("uuid").primaryKey()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class EntityTests: DatabaseTestsBase() {

@Test
fun tableSelfReferenceTest() {
assertEquals<List<Table>>(listOf(Categories, Boards, Posts), EntityCache.sortTablesByReferences(listOf(Posts, Boards, Categories)))
assertEquals(listOf(Categories, Boards, Posts), SchemaUtils.sortTablesByReferences(listOf(Posts, Boards, Categories)))
}

@Test
Expand Down
Loading

0 comments on commit f410859

Please sign in to comment.