Skip to content

Commit

Permalink
SQLServer many to many relationship update breaks when updating from…
Browse files Browse the repository at this point in the history
… exposed 0.26.2 to 0.27.1 JetBrains#1319
  • Loading branch information
Tapac authored and SchweinchenFuntik committed Oct 23, 2021
1 parent 078bf31 commit 8533af1
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.ResultSet

class BatchDataInconsistentException(message: String) : Exception(message)
Expand All @@ -25,11 +24,13 @@ open class SQLServerBatchInsertStatement(table: Table, ignore: Boolean = false,
}
}

private val columnToReturnValue = table.autoIncColumn?.takeIf { shouldReturnGeneratedValues && it.autoIncColumnType?.nextValExpression == null }

override fun prepareSQL(transaction: Transaction): String {
val values = arguments!!
val sql = if (values.isEmpty()) ""
else {
val output = table.autoIncColumn?.takeIf { shouldReturnGeneratedValues && it.autoIncColumnType?.nextValExpression == null }?.let {
val output = columnToReturnValue?.let {
" OUTPUT inserted.${transaction.identity(it)} AS GENERATED_KEYS"
}.orEmpty()

Expand All @@ -47,6 +48,12 @@ open class SQLServerBatchInsertStatement(table: Table, ignore: Boolean = false,
override fun arguments() = listOfNotNull(super.arguments().flatten().takeIf { data.isNotEmpty() })

override fun PreparedStatementApi.execInsertFunction(): Pair<Int, ResultSet?> {
return arguments!!.size to executeQuery()
val rs = if (columnToReturnValue != null) {
executeQuery()
} else {
executeUpdate()
null
}
return arguments!!.size to rs
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,36 @@ object ViaTestData {
override val primaryKey = PrimaryKey(id)
}

object ConnectionTable : Table() {
val numId = reference("numId", NumbersTable, ReferenceOption.CASCADE)
val stringId = reference("stringId", StringsTable, ReferenceOption.CASCADE)
interface IConnectionTable {
val numId: Column<EntityID<UUID>>
val stringId: Column<EntityID<Long>>
}

object ConnectionTable : Table(), IConnectionTable {
override val numId = reference("numId", NumbersTable, ReferenceOption.CASCADE)
override val stringId = reference("stringId", StringsTable, ReferenceOption.CASCADE)

init {
index(true, numId, stringId)
}
}

object ConnectionAutoIncTable : IntIdTable(), IConnectionTable {
override val numId = reference("numId", NumbersTable, ReferenceOption.CASCADE)
override val stringId = reference("stringId", StringsTable, ReferenceOption.CASCADE)

init {
index(true, numId, stringId)
}
}

val allTables: Array<Table> = arrayOf(NumbersTable, StringsTable, ConnectionTable)
val allTables: Array<Table> = arrayOf(NumbersTable, StringsTable, ConnectionTable, ConnectionAutoIncTable)
}

class VNumber(id: EntityID<UUID>) : UUIDEntity(id) {
var number by ViaTestData.NumbersTable.number
var connectedStrings: SizedIterable<VString> by VString via ViaTestData.ConnectionTable
var connectedAutoStrings: SizedIterable<VString> by VString via ViaTestData.ConnectionAutoIncTable

companion object : UUIDEntityClass<VNumber>(ViaTestData.NumbersTable)
}
Expand All @@ -49,15 +64,29 @@ class VString(id: EntityID<Long>) : Entity<Long>(id) {
}

class ViaTests : DatabaseTestsBase() {

private fun VNumber.testWithBothTables(valuesToSet: List<VString>, body: (ViaTestData.IConnectionTable, List<ResultRow>) -> Unit) {
listOf(ViaTestData.ConnectionTable, ViaTestData.ConnectionAutoIncTable).forEach { t ->
if (t == ViaTestData.ConnectionTable) {
connectedStrings = SizedCollection(valuesToSet)
} else {
connectedAutoStrings = SizedCollection(valuesToSet)
}

val result = t.selectAll().toList()
body(t, result)
}
}

@Test fun testConnection01() {
withTables(*ViaTestData.allTables) {
val n = VNumber.new { number = 10 }
val s = VString.new { text = "aaa" }
n.connectedStrings = SizedCollection(listOf(s))

val row = ViaTestData.ConnectionTable.selectAll().single()
assertEquals(n.id, row[ViaTestData.ConnectionTable.numId])
assertEquals(s.id, row[ViaTestData.ConnectionTable.stringId])
n.testWithBothTables(listOf(s)) { table, result ->
val row = result.single()
assertEquals(n.id, row[table.numId])
assertEquals(s.id, row[table.stringId])
}
}
}

Expand All @@ -68,13 +97,12 @@ class ViaTests : DatabaseTestsBase() {
val s1 = VString.new { text = "aaa" }
val s2 = VString.new { text = "bbb" }

n1.connectedStrings = SizedCollection(listOf(s1, s2))

val row = ViaTestData.ConnectionTable.selectAll().toList()
assertEquals(2, row.count())
assertEquals(n1.id, row[0][ViaTestData.ConnectionTable.numId])
assertEquals(n1.id, row[1][ViaTestData.ConnectionTable.numId])
assertEqualCollections(listOf(s1.id, s2.id), row.map { it[ViaTestData.ConnectionTable.stringId] })
n1.testWithBothTables(listOf(s1, s2)) { table, row ->
assertEquals(2, row.count())
assertEquals(n1.id, row[0][table.numId])
assertEquals(n1.id, row[1][table.numId])
assertEqualCollections(listOf(s1.id, s2.id), row.map { it[table.stringId] })
}
}
}

Expand All @@ -85,23 +113,17 @@ class ViaTests : DatabaseTestsBase() {
val s1 = VString.new { text = "aaa" }
val s2 = VString.new { text = "bbb" }

n1.connectedStrings = SizedCollection(listOf(s1, s2))
n2.connectedStrings = SizedCollection(listOf(s1, s2))

run {
val row = ViaTestData.ConnectionTable.selectAll().toList()
n1.testWithBothTables(listOf(s1, s2)) { _, _ -> }
n2.testWithBothTables(listOf(s1, s2)) { _, row ->
assertEquals(4, row.count())
assertEqualCollections(n1.connectedStrings, listOf(s1, s2))
assertEqualCollections(n2.connectedStrings, listOf(s1, s2))
}

n1.connectedStrings = SizedCollection(emptyList())

run {
val row = ViaTestData.ConnectionTable.selectAll().toList()
n1.testWithBothTables(emptyList()) { table, row ->
assertEquals(2, row.count())
assertEquals(n2.id, row[0][ViaTestData.ConnectionTable.numId])
assertEquals(n2.id, row[1][ViaTestData.ConnectionTable.numId])
assertEquals(n2.id, row[0][table.numId])
assertEquals(n2.id, row[1][table.numId])
assertEqualCollections(n1.connectedStrings, emptyList())
assertEqualCollections(n2.connectedStrings, listOf(s1, s2))
}
Expand All @@ -115,20 +137,14 @@ class ViaTests : DatabaseTestsBase() {
val s1 = VString.new { text = "aaa" }
val s2 = VString.new { text = "bbb" }

n1.connectedStrings = SizedCollection(listOf(s1, s2))
n2.connectedStrings = SizedCollection(listOf(s1, s2))

run {
val row = ViaTestData.ConnectionTable.selectAll().toList()
n1.testWithBothTables(listOf(s1, s2)) { _, _ -> }
n2.testWithBothTables(listOf(s1, s2)) { _, row ->
assertEquals(4, row.count())
assertEqualCollections(n1.connectedStrings, listOf(s1, s2))
assertEqualCollections(n2.connectedStrings, listOf(s1, s2))
}

n1.connectedStrings = SizedCollection(listOf(s1))

run {
val row = ViaTestData.ConnectionTable.selectAll().toList()
n1.testWithBothTables(listOf(s1)) { table, row ->
assertEquals(3, row.count())
assertEqualCollections(n1.connectedStrings, listOf(s1))
assertEqualCollections(n2.connectedStrings, listOf(s1, s2))
Expand Down

0 comments on commit 8533af1

Please sign in to comment.