Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Escaping column and table names by Dialect #24

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ interface Dialect {
.joinToString(" ")
return sql + "\n" + limitAndOffset
}

fun escapeName(columnName: String): String = columnName
}

internal fun String.truncate(limit: Int): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ open class MysqlDialect : Dialect {
override fun allocateIds(count: Int, sequence: String, columnName: String) = throw UnsupportedOperationException()

override val supportsFetchingGeneratedKeysByName = false

override fun escapeName(columnName: String): String =
'`' + columnName.replace("`", "``") + '`'

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,8 @@ open class PostgresDialect : Dialect {
"select nextval('$sequence') as $columnName from generate_series(1, $count)"

override val supportsFetchingGeneratedKeysByName = true

override fun escapeName(columnName: String): String =
'"' + columnName.replace("\"", "\"\"") + '"'

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ open class SqliteDialect : Dialect {
override fun allocateIds(count: Int, sequence: String, columnName: String) = throw UnsupportedOperationException()

override val supportsFetchingGeneratedKeysByName = false

override fun escapeName(columnName: String): String =
'"' + columnName.replace("\"", "\"\"") + '"'

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ abstract class AbstractDao<T : Any, ID : Any>(

override val defaultColumns = table.defaultColumns

protected val columns = table.defaultColumns.join()
protected val columns = table.defaultColumns.joinNames()

private val listeners = linkedSetOf<Listener>()

private val escapedTableName = session.dialect.escapeName(table.name)

fun addListener(listener: Listener) {
listeners.add(listener)
}
Expand Down Expand Up @@ -75,12 +77,16 @@ abstract class AbstractDao<T : Any, ID : Any>(
return this.groupBy { it.first }.map { apply(it.key, it.value.map { it.second }) }
}

protected fun Iterable<Column<T, *>>.join(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.map { f(it) }.joinToString(separator)
protected fun Iterable<Column<T, *>>.joinNames(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.joinToString(separator) { session.dialect.escapeName(f(it)) }
}

protected fun Iterable<Column<T, *>>.joinStrings(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.joinToString(separator) { f(it) }
}

protected fun Iterable<Column<T, *>>.equate(separator: String = ", ", f: (Column<T, *>) -> String = nf): String {
return this.map { "${f(it)} = :${f(it)}" }.joinToString(separator)
return this.joinToString(separator) { "${session.dialect.escapeName(f(it))} = :${f(it)}" }
}

protected fun Collection<ID>.copyToSqlArray(): java.sql.Array {
Expand All @@ -103,22 +109,22 @@ abstract class AbstractDao<T : Any, ID : Any>(
override fun findById(id: ID, columns: Set<Column<T, *>>): T? = withTransaction {
val name = "findById"
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.equate(" and ")}"
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${table.idColumns.equate(" and ")}"
}
session.select(sql, table.idMap(session, id, nf), options(name), table.rowMapper(columns)).firstOrNull()
}

override fun findByIdForUpdate(id: ID, columns: Set<Column<T, *>>): T? = withTransaction {
val name = "findByIdForUpdate"
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.equate(" and ")}\nfor update"
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${table.idColumns.equate(" and ")}\nfor update"
}
session.select(sql, table.idMap(session, id, nf), options(name), table.rowMapper(columns)).firstOrNull()
}

override fun findAll(columns: Set<Column<T, *>>): List<T> = withTransaction {
val name = "findAll"
val sql = sql(name to columns) { "select ${columns.join()} \nfrom ${table.name}" }
val sql = sql(name to columns) { "select ${columns.joinNames()} \nfrom $escapedTableName" }
session.select(sql, mapOf(), options(name), table.rowMapper(columns))
}

Expand All @@ -130,7 +136,7 @@ abstract class AbstractDao<T : Any, ID : Any>(

val exampleMap = table.objectMap(session, example, exampleColumns, nf)
val sql = sql(Triple(name, exampleColumns, columns)) {
"select ${columns.join()} \nfrom ${table.name}\nwhere ${exampleColumns.equate(" and ")}"
"select ${columns.joinNames()} \nfrom $escapedTableName\nwhere ${exampleColumns.equate(" and ")}"
}
session.select(sql, exampleMap, options(name), table.rowMapper(columns))
}
Expand Down Expand Up @@ -165,8 +171,10 @@ abstract class AbstractDao<T : Any, ID : Any>(
fun delta(): Pair<String, Map<String, Any?>> {
val differences = difference(oldMap, newMap)
val sql = sql(name to differences) {
val columns = differences.keys.map { "$it = :$it" }.joinToString(", ")
"update ${table.name}\nset $columns \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
val columns = differences.keys.joinToString(", ") { "$it = :$it" }
"update $escapedTableName \n" +
"set $columns \n" +
"where ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
}
val parameters = hashMapOfExpectedSize<String, Any?>(differences.size + table.idColumns.size + 1)
parameters.putAll(differences)
Expand All @@ -177,7 +185,9 @@ abstract class AbstractDao<T : Any, ID : Any>(

fun full(): Pair<String, HashMap<String, Any?>> {
val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
"update $escapedTableName \n" +
"set ${table.dataColumns.equate()} \n" +
"where ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
}
val parameters = hashMapOfExpectedSize<String, Any?>(newMap.size + table.idColumns.size + 1)
parameters.putAll(newMap)
Expand Down Expand Up @@ -208,7 +218,7 @@ abstract class AbstractDao<T : Any, ID : Any>(

override fun delete(id: ID): Int = withTransaction {
val name = "delete"
val sql = sql(name) { "delete from ${table.name} where ${table.idColumns.equate(" and ")}" }
val sql = sql(name) { "delete from $escapedTableName where ${table.idColumns.equate(" and ")}" }
val count = session.update(sql, table.idMap(session, id, nf), options(name))

fireEvent { DeleteEvent(table, id, null) }
Expand All @@ -221,7 +231,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
val new = fireTransformingEvent(newValue) { PreUpdateEvent(table, id(newValue), newValue, null) }

val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
"update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
}
val newMap = table.objectMap(session, new, table.allColumns)

Expand All @@ -247,7 +257,8 @@ abstract class AbstractDao<T : Any, ID : Any>(
}

val columns = if (generateKeys) table.dataColumns else table.allColumns
val sql = sql(name) { "insert into ${table.name}(${columns.join()}) \nvalues (${columns.join { ":${it.name}" }})" }
val sql = sql(name) { "insert into $escapedTableName (${columns.joinNames()}) \n" +
"values (${columns.joinStrings { ":${it.name}" }})" }

val inserted = if (generateKeys) {
val list = session.batchInsert(sql, new.map { table.objectMap(session, it, columns, nf) }, options(name),
Expand Down Expand Up @@ -281,7 +292,8 @@ abstract class AbstractDao<T : Any, ID : Any>(
val generateKeys = isGeneratedKey(new, idStrategy)

val columns = if (generateKeys) table.dataColumns else table.allColumns
val sql = sql(name to columns) { "insert into ${table.name}(${columns.join()}) \nvalues (${columns.join { ":${it.name}" }})" }
val sql = sql(name to columns) { "insert into $escapedTableName (${columns.joinNames()}) \n" +
"values (${columns.joinStrings { ":${it.name}" }})" }
val parameters = table.objectMap(session, new, columns, nf)

val (count, inserted) = if (generateKeys) {
Expand Down Expand Up @@ -314,7 +326,7 @@ abstract class AbstractDao<T : Any, ID : Any>(

val values = if (session.dialect.supportsArrayBasedIn) {
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.first().name} " +
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${session.dialect.escapeName(table.idColumns.first().name)} " +
session.dialect.arrayBasedIn("ids")
}
val array = ids.copyToSqlArray()
Expand All @@ -325,7 +337,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
}
} else {
val sql = sql(name to columns) {
"select ${columns.join()} \nfrom ${table.name} \nwhere ${table.idColumns.first().name} in (:ids)"
"select ${columns.joinNames()} \nfrom $escapedTableName \nwhere ${session.dialect.escapeName(table.idColumns.first().name)} in (:ids)"
}
session.select(sql, mapOf("ids" to ids), options(name), table.rowMapper(columns))
}
Expand Down Expand Up @@ -353,7 +365,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
val updates = new.map { table.objectMap(session, it, table.allColumns) }

val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
"update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")}"
}

val counts = session.batchUpdate(sql, updates, options(name))
Expand Down Expand Up @@ -402,7 +414,7 @@ abstract class AbstractDao<T : Any, ID : Any>(
}

val sql = sql(name) {
"update ${table.name}\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
"update $escapedTableName\nset ${table.dataColumns.equate()} \nwhere ${table.idColumns.equate(" and ")} and $versionCol = :$oldVersionParam"
}

val counts = session.batchUpdate(sql, updates.map { it.first }, options(name))
Expand Down