Skip to content

Commit

Permalink
Merge pull request #1513 from johnedquinn/v1-conformance-datum-distinct
Browse files Browse the repository at this point in the history
Fixes all DISTINCT conformance tests
  • Loading branch information
johnedquinn authored Aug 7, 2024
2 parents 1244c58 + cda3557 commit e21a1eb
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ package org.partiql.eval.internal.operator.rel
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.value.ListValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.listValue
import java.util.TreeSet

internal class RelDistinct(
val input: Operator.Relation
) : RelPeeking() {

// TODO: Add hashcode/equals support for PQLValue. Then we can use Record directly.
// TODO: Add hashcode/equals support for Datum. Then we can use Record directly.
@OptIn(PartiQLValueExperimental::class)
private val seen = mutableSetOf<List<PartiQLValue>>()
private val seen = TreeSet<ListValue<PartiQLValue>>(PartiQLValue.comparator())

override fun openPeeking(env: Environment) {
input.open(env)
Expand All @@ -21,7 +24,7 @@ internal class RelDistinct(
@OptIn(PartiQLValueExperimental::class)
override fun peek(): Record? {
for (next in input) {
val transformed = List(next.values.size) { next.values[it].toPartiQLValue() }
val transformed = listValue(List(next.values.size) { next.values[it].toPartiQLValue() })
if (seen.contains(transformed).not()) {
seen.add(transformed)
return next
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ internal open class PlanningProblemDetails(

data class ExpressionAlwaysReturnsMissing(val reason: String? = null) : PlanningProblemDetails(
severity = ProblemSeverity.ERROR,
messageFormatter = { "Expression always returns null or missing: caused by $reason" }
messageFormatter = { "Expression always returns missing: caused by $reason" }
)

data class InvalidArgumentTypeForFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ internal object RexConverter {
@Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE")
private object ToRex : AstBaseVisitor<Rex, Env>() {

/**
 * Lower-case names of the `COLL_<AGG>` functions recognized by [isCollAgg].
 *
 * [isCollAgg] lower-cases the call's symbol before looking it up here, so every entry must stay lower-case.
 */
private val COLL_AGG_NAMES = setOf(
    "coll_any",
    "coll_avg",
    "coll_count",
    "coll_every",
    "coll_max",
    "coll_min",
    "coll_some",
    "coll_sum",
)

override fun defaultReturn(node: AstNode, context: Env): Rex =
throw IllegalArgumentException("unsupported rex $node")

Expand Down Expand Up @@ -465,11 +476,52 @@ internal object RexConverter {
}
// Args
val args = node.args.map { visitExprCoerce(it, context) }
// Rex

// Check if function is actually coll_<agg>
if (isCollAgg(node)) {
return callToCollAgg(id, node.setq, args)
}

if (node.setq != null) {
error("Currently, only COLL_<AGG> may use set quantifiers.")
}
val op = rexOpCallUnresolved(id, args)
return rex(type, op)
}

/**
 * Tells whether [node] invokes one of the `COLL_<AGG>` functions.
 *
 * Only calls whose function is a plain symbol can match; the symbol is lower-cased before it is
 * looked up in [COLL_AGG_NAMES], so the check ignores the casing used in the query.
 *
 * @return `true` if the call's function symbol names a `COLL_<AGG>` function, `false` otherwise.
 */
private fun isCollAgg(node: Expr.Call): Boolean =
    when (val function = node.function) {
        is org.partiql.ast.Identifier.Symbol -> function.symbol.lowercase() in COLL_AGG_NAMES
        else -> false
    }

/**
 * Rewrites a `COLL_<AGG>` call into an unresolved call to its quantifier-specific function by
 * appending a postfix to the function name. For example:
 * - `COLL_SUM(x)` targets `coll_sum` + `_all`
 * - `COLL_SUM(ALL x)` targets `coll_sum` + `_all`
 * - `COLL_SUM(DISTINCT x)` targets `coll_sum` + `_distinct`
 *
 * The identifier text is used as written; only the postfix is appended. The resulting [Rex] is typed
 * as `ANY`.
 *
 * The caller is expected to have vetted the call with [isCollAgg] beforehand.
 *
 * @param id the (unqualified) identifier of the function being called.
 * @param setQuantifier the set quantifier given in the call; `null` is handled like `ALL`.
 * @param args the already-converted call arguments; exactly one is required.
 * @throws IllegalStateException if [id] is qualified or [args] does not hold exactly one element.
 */
private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier?, args: List<Rex>): Rex {
    check(!id.hasQualifier()) { "Qualified function calls are not currently supported." }
    check(args.size == 1) {
        "Aggregate calls currently only support single arguments. Received ${args.size} arguments."
    }
    // An absent quantifier defaults to ALL semantics.
    val suffix = when (setQuantifier) {
        SetQuantifier.DISTINCT -> "_distinct"
        SetQuantifier.ALL, null -> "_all"
    }
    val baseName = id.getIdentifier().getText()
    val target = Identifier.regular(baseName + suffix)
    return Rex(ANY, Rex.Op.Call.Unresolved(target, listOf(args.single())))
}

private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex {
val type = (STRUCT)
val args = node.args.map { visitExprCoerce(it, context) }.toMutableList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal class CompilerType(
// Note: This is an experimental property.
internal val isMissingValue: Boolean = false
) : PType {
fun getDelegate(): PType = _delegate
override fun getKind(): Kind = _delegate.kind
override fun getFields(): MutableCollection<Field> {
return _delegate.fields.map { field ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ internal class PlanTyper(private val env: Env) {
*
* TODO: Can this be merged with [anyOf]? Should we even allow this?
*/
fun anyOfLiterals(types: Collection<PType>): PType? {
fun anyOfLiterals(types: Collection<CompilerType>): PType? {
// Grab unique
var unique: Collection<PType> = types.toSet()
var unique: Collection<PType> = types.map { it.getDelegate() }.toSet()
if (unique.size == 0) {
return null
} else if (unique.size == 1) {
Expand Down Expand Up @@ -133,7 +133,7 @@ internal class PlanTyper(private val env: Env) {
}

private fun collapseCollection(collections: Iterable<PType>, type: Kind): PType {
val typeParam = anyOfLiterals(collections.map { it.typeParameter })!!
val typeParam = anyOfLiterals(collections.map { it.typeParameter.toCType() })!!
return when (type) {
Kind.LIST -> PType.typeList(typeParam)
Kind.BAG -> PType.typeList(typeParam)
Expand All @@ -145,13 +145,13 @@ internal class PlanTyper(private val env: Env) {
private fun collapseRows(rows: Iterable<PType>): PType {
val firstFields = rows.first().fields!!
val fieldNames = firstFields.map { it.name }
val fieldTypes = firstFields.map { mutableListOf(it.type) }
val fieldTypes = firstFields.map { mutableListOf(it.type.toCType()) }
rows.map { struct ->
val fields = struct.fields!!
if (fields.map { it.name } != fieldNames) {
return PType.typeStruct()
}
fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type) }
fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type.toCType()) }
}
val newFields = fieldTypes.mapIndexed { i, types -> Field.of(fieldNames[i], anyOfLiterals(types)!!) }
return PType.typeRow(newFields)
Expand All @@ -162,7 +162,10 @@ internal class PlanTyper(private val env: Env) {
return anyOf(unique)
}

fun PType.toCType(): CompilerType = CompilerType(this)
fun PType.toCType(): CompilerType = when (this) {
is CompilerType -> this
else -> CompilerType(this)
}

fun List<PType>.toCType(): List<CompilerType> = this.map { it.toCType() }

Expand Down
24 changes: 16 additions & 8 deletions partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ internal object SqlBuiltins {
Fn_CHAR_LENGTH__STRING__INT,
Fn_CHAR_LENGTH__CLOB__INT,
Fn_CHAR_LENGTH__SYMBOL__INT,
Fn_COLL_AGG__BAG__ANY.ANY,
Fn_COLL_AGG__BAG__ANY.AVG,
Fn_COLL_AGG__BAG__ANY.COUNT,
Fn_COLL_AGG__BAG__ANY.EVERY,
Fn_COLL_AGG__BAG__ANY.MAX,
Fn_COLL_AGG__BAG__ANY.MIN,
Fn_COLL_AGG__BAG__ANY.SOME,
Fn_COLL_AGG__BAG__ANY.SUM,
Fn_COLL_AGG__BAG__ANY.ANY_ALL,
Fn_COLL_AGG__BAG__ANY.AVG_ALL,
Fn_COLL_AGG__BAG__ANY.COUNT_ALL,
Fn_COLL_AGG__BAG__ANY.EVERY_ALL,
Fn_COLL_AGG__BAG__ANY.MAX_ALL,
Fn_COLL_AGG__BAG__ANY.MIN_ALL,
Fn_COLL_AGG__BAG__ANY.SOME_ALL,
Fn_COLL_AGG__BAG__ANY.SUM_ALL,
Fn_COLL_AGG__BAG__ANY.ANY_DISTINCT,
Fn_COLL_AGG__BAG__ANY.AVG_DISTINCT,
Fn_COLL_AGG__BAG__ANY.COUNT_DISTINCT,
Fn_COLL_AGG__BAG__ANY.EVERY_DISTINCT,
Fn_COLL_AGG__BAG__ANY.MAX_DISTINCT,
Fn_COLL_AGG__BAG__ANY.MIN_DISTINCT,
Fn_COLL_AGG__BAG__ANY.SOME_DISTINCT,
Fn_COLL_AGG__BAG__ANY.SUM_DISTINCT,
Fn_CONCAT__STRING_STRING__STRING,
Fn_CONCAT__CLOB_CLOB__CLOB,
Fn_CONCAT__SYMBOL_SYMBOL__SYMBOL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

package org.partiql.spi.fn.builtins

import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.FnParameter
import org.partiql.spi.fn.FnSignature
import org.partiql.spi.fn.builtins.internal.Accumulator
import org.partiql.spi.fn.builtins.internal.AccumulatorAnySome
import org.partiql.spi.fn.builtins.internal.AccumulatorAvg
import org.partiql.spi.fn.builtins.internal.AccumulatorCount
import org.partiql.spi.fn.builtins.internal.AccumulatorDistinct
import org.partiql.spi.fn.builtins.internal.AccumulatorEvery
import org.partiql.spi.fn.builtins.internal.AccumulatorMax
import org.partiql.spi.fn.builtins.internal.AccumulatorMin
Expand All @@ -22,9 +22,18 @@ import org.partiql.value.PartiQLValueType
import org.partiql.value.check

@OptIn(PartiQLValueExperimental::class)
internal abstract class Fn_COLL_AGG__BAG__ANY : Fn {
internal abstract class Fn_COLL_AGG__BAG__ANY(
name: String,
private val isDistinct: Boolean,
private val accumulator: () -> Accumulator,
) : Fn {

abstract fun getAccumulator(): Agg.Accumulator
private fun getAccumulator(): Accumulator = when (isDistinct) {
true -> AccumulatorDistinct(accumulator.invoke())
false -> accumulator.invoke()
}

override val signature: FnSignature = createSignature(name)

companion object {
@JvmStatic
Expand All @@ -46,43 +55,35 @@ internal abstract class Fn_COLL_AGG__BAG__ANY : Fn {
return accumulator.value()
}

object SUM : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_sum")
override fun getAccumulator(): Accumulator = AccumulatorSum()
}
object SUM_ALL : Fn_COLL_AGG__BAG__ANY("coll_sum_all", false, ::AccumulatorSum)

object AVG : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_avg")
override fun getAccumulator(): Accumulator = AccumulatorAvg()
}
object SUM_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_sum_distinct", true, ::AccumulatorSum)

object MIN : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_min")
override fun getAccumulator(): Accumulator = AccumulatorMin()
}
object AVG_ALL : Fn_COLL_AGG__BAG__ANY("coll_avg_all", false, ::AccumulatorAvg)

object MAX : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_max")
override fun getAccumulator(): Accumulator = AccumulatorMax()
}
object AVG_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_avg_distinct", true, ::AccumulatorAvg)

object COUNT : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_count")
override fun getAccumulator(): Accumulator = AccumulatorCount()
}
object MIN_ALL : Fn_COLL_AGG__BAG__ANY("coll_min_all", false, ::AccumulatorMin)

object EVERY : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_every")
override fun getAccumulator(): Accumulator = AccumulatorEvery()
}
object MIN_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_min_distinct", true, ::AccumulatorMin)

object ANY : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_any")
override fun getAccumulator(): Accumulator = AccumulatorAnySome()
}
object MAX_ALL : Fn_COLL_AGG__BAG__ANY("coll_max_all", false, ::AccumulatorMax)

object SOME : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_some")
override fun getAccumulator(): Accumulator = AccumulatorAnySome()
}
object MAX_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_max_distinct", true, ::AccumulatorMax)

object COUNT_ALL : Fn_COLL_AGG__BAG__ANY("coll_count_all", false, ::AccumulatorCount)

object COUNT_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_count_distinct", true, ::AccumulatorCount)

object EVERY_ALL : Fn_COLL_AGG__BAG__ANY("coll_every_all", false, ::AccumulatorEvery)

object EVERY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_every_distinct", true, ::AccumulatorEvery)

object ANY_ALL : Fn_COLL_AGG__BAG__ANY("coll_any_all", false, ::AccumulatorAnySome)

object ANY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_any_distinct", true, ::AccumulatorAnySome)

object SOME_ALL : Fn_COLL_AGG__BAG__ANY("coll_some_all", false, ::AccumulatorAnySome)

object SOME_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_some_distinct", true, ::AccumulatorAnySome)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.partiql.spi.fn.builtins.internal

import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import java.util.TreeSet

/**
 * An [Accumulator] decorator that forwards each distinct value to [_delegate] only once.
 *
 * Distinctness is decided by [PartiQLValue.comparator]: a value is forwarded only if no value that
 * compares equal to it has been seen before. The aggregate result is whatever the delegate reports.
 */
@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorDistinct(
    private val _delegate: Accumulator,
) : Accumulator() {

    // TODO: Add support for a datum comparator once the accumulator passes datums instead of PartiQL values.
    private val seen = TreeSet(PartiQLValue.comparator())

    /**
     * Records [value] and passes it on to the delegate unless an equal value was already recorded.
     */
    override fun nextValue(value: PartiQLValue) {
        // TreeSet.add reports whether the element was newly inserted, so a single tree traversal
        // both checks for and records the value (no separate contains() call needed).
        if (seen.add(value)) {
            _delegate.nextValue(value)
        }
    }

    /**
     * @return the delegate's result, computed over the distinct values forwarded so far.
     */
    override fun value(): PartiQLValue = _delegate.value()
}
Loading

0 comments on commit e21a1eb

Please sign in to comment.