Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes all DISTINCT conformance tests #1513

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ package org.partiql.eval.internal.operator.rel
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.value.ListValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.listValue
import java.util.TreeSet

internal class RelDistinct(
val input: Operator.Relation
) : RelPeeking() {

// TODO: Add hashcode/equals support for PQLValue. Then we can use Record directly.
// TODO: Add hashcode/equals support for Datum. Then we can use Record directly.
@OptIn(PartiQLValueExperimental::class)
private val seen = mutableSetOf<List<PartiQLValue>>()
private val seen = TreeSet<ListValue<PartiQLValue>>(PartiQLValue.comparator())
alancai98 marked this conversation as resolved.
Show resolved Hide resolved

override fun openPeeking(env: Environment) {
input.open(env)
Expand All @@ -21,7 +24,7 @@ internal class RelDistinct(
@OptIn(PartiQLValueExperimental::class)
override fun peek(): Record? {
for (next in input) {
val transformed = List(next.values.size) { next.values[it].toPartiQLValue() }
val transformed = listValue(List(next.values.size) { next.values[it].toPartiQLValue() })
if (seen.contains(transformed).not()) {
seen.add(transformed)
return next
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ internal open class PlanningProblemDetails(

data class ExpressionAlwaysReturnsMissing(val reason: String? = null) : PlanningProblemDetails(
severity = ProblemSeverity.ERROR,
messageFormatter = { "Expression always returns null or missing: caused by $reason" }
messageFormatter = { "Expression always returns missing: caused by $reason" }
)

data class InvalidArgumentTypeForFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ internal object RexConverter {
@Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE")
private object ToRex : AstBaseVisitor<Rex, Env>() {

private val COLL_AGG_NAMES = setOf(
"coll_any",
"coll_avg",
"coll_count",
"coll_every",
"coll_max",
"coll_min",
"coll_some",
"coll_sum",
)

override fun defaultReturn(node: AstNode, context: Env): Rex =
throw IllegalArgumentException("unsupported rex $node")

Expand Down Expand Up @@ -465,11 +476,52 @@ internal object RexConverter {
}
// Args
val args = node.args.map { visitExprCoerce(it, context) }
// Rex

// Check if function is actually coll_<agg>
if (isCollAgg(node)) {
return callToCollAgg(id, node.setq, args)
}

if (node.setq != null) {
error("Currently, only COLL_<AGG> may use set quantifiers.")
}
val op = rexOpCallUnresolved(id, args)
return rex(type, op)
}

/**
* @return whether call is `COLL_<AGG>`.
*/
private fun isCollAgg(node: Expr.Call): Boolean {
val id = node.function as? org.partiql.ast.Identifier.Symbol ?: return false
return COLL_AGG_NAMES.contains(id.symbol.lowercase())
}

/**
* Converts COLL_<AGG> to the relevant function calls. For example:
* - `COLL_SUM(x)` becomes `coll_sum_all(x)`
* - `COLL_SUM(ALL x)` becomes `coll_sum_all(x)`
* - `COLL_SUM(DISTINCT x)` becomes `coll_sum_distinct(x)`
*
* It is assumed that the [id] has already been vetted by [isCollAgg].
*/
private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier?, args: List<Rex>): Rex {
if (id.hasQualifier()) {
error("Qualified function calls are not currently supported.")
}
if (args.size != 1) {
error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.")
}
val postfix = when (setQuantifier) {
SetQuantifier.DISTINCT -> "_distinct"
SetQuantifier.ALL -> "_all"
null -> "_all"
}
val newId = Identifier.regular(id.getIdentifier().getText() + postfix)
val op = Rex.Op.Call.Unresolved(newId, listOf(args[0]))
return Rex(ANY, op)
}

private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex {
val type = (STRUCT)
val args = node.args.map { visitExprCoerce(it, context) }.toMutableList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal class CompilerType(
// Note: This is an experimental property.
internal val isMissingValue: Boolean = false
) : PType {
fun getDelegate(): PType = _delegate
override fun getKind(): Kind = _delegate.kind
override fun getFields(): MutableCollection<Field> {
return _delegate.fields.map { field ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ internal class PlanTyper(private val env: Env) {
*
* TODO: Can this be merged with [anyOf]? Should we even allow this?
*/
fun anyOfLiterals(types: Collection<PType>): PType? {
fun anyOfLiterals(types: Collection<CompilerType>): PType? {
// Grab unique
var unique: Collection<PType> = types.toSet()
var unique: Collection<PType> = types.map { it.getDelegate() }.toSet()
if (unique.size == 0) {
return null
} else if (unique.size == 1) {
Expand Down Expand Up @@ -133,7 +133,7 @@ internal class PlanTyper(private val env: Env) {
}

private fun collapseCollection(collections: Iterable<PType>, type: Kind): PType {
val typeParam = anyOfLiterals(collections.map { it.typeParameter })!!
val typeParam = anyOfLiterals(collections.map { it.typeParameter.toCType() })!!
return when (type) {
Kind.LIST -> PType.typeList(typeParam)
Kind.BAG -> PType.typeList(typeParam)
Expand All @@ -145,13 +145,13 @@ internal class PlanTyper(private val env: Env) {
private fun collapseRows(rows: Iterable<PType>): PType {
val firstFields = rows.first().fields!!
val fieldNames = firstFields.map { it.name }
val fieldTypes = firstFields.map { mutableListOf(it.type) }
val fieldTypes = firstFields.map { mutableListOf(it.type.toCType()) }
rows.map { struct ->
val fields = struct.fields!!
if (fields.map { it.name } != fieldNames) {
return PType.typeStruct()
}
fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type) }
fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type.toCType()) }
}
val newFields = fieldTypes.mapIndexed { i, types -> Field.of(fieldNames[i], anyOfLiterals(types)!!) }
return PType.typeRow(newFields)
Expand All @@ -162,7 +162,10 @@ internal class PlanTyper(private val env: Env) {
return anyOf(unique)
}

fun PType.toCType(): CompilerType = CompilerType(this)
fun PType.toCType(): CompilerType = when (this) {
is CompilerType -> this
else -> CompilerType(this)
}

fun List<PType>.toCType(): List<CompilerType> = this.map { it.toCType() }

Expand Down
24 changes: 16 additions & 8 deletions partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ internal object SqlBuiltins {
Fn_CHAR_LENGTH__STRING__INT,
Fn_CHAR_LENGTH__CLOB__INT,
Fn_CHAR_LENGTH__SYMBOL__INT,
Fn_COLL_AGG__BAG__ANY.ANY,
Fn_COLL_AGG__BAG__ANY.AVG,
Fn_COLL_AGG__BAG__ANY.COUNT,
Fn_COLL_AGG__BAG__ANY.EVERY,
Fn_COLL_AGG__BAG__ANY.MAX,
Fn_COLL_AGG__BAG__ANY.MIN,
Fn_COLL_AGG__BAG__ANY.SOME,
Fn_COLL_AGG__BAG__ANY.SUM,
Fn_COLL_AGG__BAG__ANY.ANY_ALL,
Fn_COLL_AGG__BAG__ANY.AVG_ALL,
Fn_COLL_AGG__BAG__ANY.COUNT_ALL,
Fn_COLL_AGG__BAG__ANY.EVERY_ALL,
Fn_COLL_AGG__BAG__ANY.MAX_ALL,
Fn_COLL_AGG__BAG__ANY.MIN_ALL,
Fn_COLL_AGG__BAG__ANY.SOME_ALL,
Fn_COLL_AGG__BAG__ANY.SUM_ALL,
Fn_COLL_AGG__BAG__ANY.ANY_DISTINCT,
Fn_COLL_AGG__BAG__ANY.AVG_DISTINCT,
Fn_COLL_AGG__BAG__ANY.COUNT_DISTINCT,
Fn_COLL_AGG__BAG__ANY.EVERY_DISTINCT,
Fn_COLL_AGG__BAG__ANY.MAX_DISTINCT,
Fn_COLL_AGG__BAG__ANY.MIN_DISTINCT,
Fn_COLL_AGG__BAG__ANY.SOME_DISTINCT,
Fn_COLL_AGG__BAG__ANY.SUM_DISTINCT,
Fn_CONCAT__STRING_STRING__STRING,
Fn_CONCAT__CLOB_CLOB__CLOB,
Fn_CONCAT__SYMBOL_SYMBOL__SYMBOL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

package org.partiql.spi.fn.builtins

import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.FnParameter
import org.partiql.spi.fn.FnSignature
import org.partiql.spi.fn.builtins.internal.Accumulator
import org.partiql.spi.fn.builtins.internal.AccumulatorAnySome
import org.partiql.spi.fn.builtins.internal.AccumulatorAvg
import org.partiql.spi.fn.builtins.internal.AccumulatorCount
import org.partiql.spi.fn.builtins.internal.AccumulatorDistinct
import org.partiql.spi.fn.builtins.internal.AccumulatorEvery
import org.partiql.spi.fn.builtins.internal.AccumulatorMax
import org.partiql.spi.fn.builtins.internal.AccumulatorMin
Expand All @@ -22,9 +22,18 @@ import org.partiql.value.PartiQLValueType
import org.partiql.value.check

@OptIn(PartiQLValueExperimental::class)
internal abstract class Fn_COLL_AGG__BAG__ANY : Fn {
internal abstract class Fn_COLL_AGG__BAG__ANY(
name: String,
private val isDistinct: Boolean,
private val accumulator: () -> Accumulator,
) : Fn {

abstract fun getAccumulator(): Agg.Accumulator
private fun getAccumulator(): Accumulator = when (isDistinct) {
true -> AccumulatorDistinct(accumulator.invoke())
false -> accumulator.invoke()
}

override val signature: FnSignature = createSignature(name)
alancai98 marked this conversation as resolved.
Show resolved Hide resolved

companion object {
@JvmStatic
Expand All @@ -46,43 +55,35 @@ internal abstract class Fn_COLL_AGG__BAG__ANY : Fn {
return accumulator.value()
}

object SUM : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_sum")
override fun getAccumulator(): Accumulator = AccumulatorSum()
}
object SUM_ALL : Fn_COLL_AGG__BAG__ANY("coll_sum_all", false, ::AccumulatorSum)

object AVG : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_avg")
override fun getAccumulator(): Accumulator = AccumulatorAvg()
}
object SUM_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_sum_distinct", true, ::AccumulatorSum)

object MIN : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_min")
override fun getAccumulator(): Accumulator = AccumulatorMin()
}
object AVG_ALL : Fn_COLL_AGG__BAG__ANY("coll_avg_all", false, ::AccumulatorAvg)

object MAX : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_max")
override fun getAccumulator(): Accumulator = AccumulatorMax()
}
object AVG_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_avg_distinct", true, ::AccumulatorAvg)

object COUNT : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_count")
override fun getAccumulator(): Accumulator = AccumulatorCount()
}
object MIN_ALL : Fn_COLL_AGG__BAG__ANY("coll_min_all", false, ::AccumulatorMin)

object EVERY : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_every")
override fun getAccumulator(): Accumulator = AccumulatorEvery()
}
object MIN_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_min_distinct", true, ::AccumulatorMin)

object ANY : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_any")
override fun getAccumulator(): Accumulator = AccumulatorAnySome()
}
object MAX_ALL : Fn_COLL_AGG__BAG__ANY("coll_max_all", false, ::AccumulatorMax)

object SOME : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_some")
override fun getAccumulator(): Accumulator = AccumulatorAnySome()
}
object MAX_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_max_distinct", true, ::AccumulatorMax)

object COUNT_ALL : Fn_COLL_AGG__BAG__ANY("coll_count_all", false, ::AccumulatorCount)

object COUNT_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_count_distinct", true, ::AccumulatorCount)

object EVERY_ALL : Fn_COLL_AGG__BAG__ANY("coll_every_all", false, ::AccumulatorEvery)

object EVERY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_every_distinct", true, ::AccumulatorEvery)

object ANY_ALL : Fn_COLL_AGG__BAG__ANY("coll_any_all", false, ::AccumulatorAnySome)

object ANY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_any_distinct", true, ::AccumulatorAnySome)

object SOME_ALL : Fn_COLL_AGG__BAG__ANY("coll_some_all", false, ::AccumulatorAnySome)

object SOME_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_some_distinct", true, ::AccumulatorAnySome)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.partiql.spi.fn.builtins.internal

import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import java.util.TreeSet

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorDistinct(
private val _delegate: Accumulator,
) : Accumulator() {

// TODO: Add support for a datum comparator once the accumulator passes datums instead of PartiQL values.
@OptIn(PartiQLValueExperimental::class)
private val seen = TreeSet(PartiQLValue.comparator())

@OptIn(PartiQLValueExperimental::class)
override fun nextValue(value: PartiQLValue) {
if (!seen.contains(value)) {
seen.add(value)
_delegate.nextValue(value)
}
}

@OptIn(PartiQLValueExperimental::class)
override fun value(): PartiQLValue {
return _delegate.value()
}
}
Loading
Loading