Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply SQL 9.3 typing rules for CASE-WHEN #1391

Merged
merged 8 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@

package org.partiql.cli

import AstPrinter
import com.amazon.ion.system.IonSystemBuilder
import com.amazon.ion.system.IonTextWriterBuilder
import org.partiql.cli.pico.PartiQLCommand
import org.partiql.cli.shell.info
import org.partiql.lang.eval.EvaluationSession
import org.partiql.parser.PartiQLParser
import org.partiql.plan.Statement
import org.partiql.plan.debug.PlanPrinter
import org.partiql.planner.PartiQLPlanner
import org.partiql.plugins.local.LocalConnector
import org.partiql.plugins.local.toIon
import picocli.CommandLine
import java.io.PrintStream
import java.nio.file.Paths
Expand Down Expand Up @@ -80,6 +82,16 @@ object Debug {
out.info("-- Plan ----------")
PlanPrinter.append(out, result.statement)

when (val plan = result.statement) {
is Statement.Query -> {
out.info("-- Schema ----------")
val outputSchema = java.lang.StringBuilder()
val ionWriter = IonTextWriterBuilder.minimal().withPrettyPrinting().build(outputSchema)
plan.root.type.toIon().writeTo(ionWriter)
out.info(outputSchema.toString())
}
}

return "OK"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
@file:OptIn(PartiQLValueExperimental::class)

package org.partiql.planner.internal.typer

import org.partiql.types.MissingType
import org.partiql.types.NullType
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.PartiQLValueType.ANY
import org.partiql.value.PartiQLValueType.BAG
import org.partiql.value.PartiQLValueType.BINARY
import org.partiql.value.PartiQLValueType.BLOB
import org.partiql.value.PartiQLValueType.BOOL
import org.partiql.value.PartiQLValueType.BYTE
import org.partiql.value.PartiQLValueType.CHAR
import org.partiql.value.PartiQLValueType.CLOB
import org.partiql.value.PartiQLValueType.DATE
import org.partiql.value.PartiQLValueType.DECIMAL
import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY
import org.partiql.value.PartiQLValueType.FLOAT32
import org.partiql.value.PartiQLValueType.FLOAT64
import org.partiql.value.PartiQLValueType.INT
import org.partiql.value.PartiQLValueType.INT16
import org.partiql.value.PartiQLValueType.INT32
import org.partiql.value.PartiQLValueType.INT64
import org.partiql.value.PartiQLValueType.INT8
import org.partiql.value.PartiQLValueType.INTERVAL
import org.partiql.value.PartiQLValueType.LIST
import org.partiql.value.PartiQLValueType.MISSING
import org.partiql.value.PartiQLValueType.NULL
import org.partiql.value.PartiQLValueType.SEXP
import org.partiql.value.PartiQLValueType.STRING
import org.partiql.value.PartiQLValueType.STRUCT
import org.partiql.value.PartiQLValueType.SYMBOL
import org.partiql.value.PartiQLValueType.TIME
import org.partiql.value.PartiQLValueType.TIMESTAMP

/**
* Graph of super types for quick lookup because we don't have a tree.
*/
internal typealias SuperGraph = Array<Array<PartiQLValueType?>>

/**
* For lack of a better name, this is the "dynamic typer" which implements the typing rules of SQL-99 9.3.
*
* SQL-99 9.3 Data types of results of aggregations (<case-when>, <collection value expression>, <query expression>)
* > https://web.cecs.pdx.edu/~len/sql1999.pdf#page=359
*
* Usage,
* To calculate the type of an "aggregation" create a new instance and "accumulate" each possible type.
* This is a pain with StaticType...
*/
@OptIn(PartiQLValueExperimental::class)
internal class DynamicTyper {

private var supertype = NULL
private var args = mutableListOf<PartiQLValueType>()

private var nullable = false
private var missable = false
private val types = mutableSetOf<StaticType>()

/**
* This primarily unpacks a StaticType because of NULL, MISSING.
*
* - T
* - NULL
* - MISSING
* - (NULL)
* - (MISSING)
* - (T..)
* - (T..|NULL)
* - (T..|MISSING)
* - (T..|NULL|MISSING)
* - (NULL|MISSING)
*
* @param type
*/
fun accumulate(type: StaticType) {
val nonAbsentTypes = mutableSetOf<StaticType>()
for (t in type.flatten().allTypes) {
when (t) {
is NullType -> nullable = true
is MissingType -> missable = true
else -> nonAbsentTypes.add(t)
}
}
when (nonAbsentTypes.size) {
0 -> {
// Ignore in calculating supertype.
args.add(NULL)
}
1 -> {
// Had single type
val single = nonAbsentTypes.first()
val singleRuntime = single.toRuntimeType()
types.add(single)
args.add(singleRuntime)
calculate(singleRuntime)
}
else -> {
// Had a union; use ANY runtime
types.addAll(nonAbsentTypes)
args.add(ANY)
calculate(ANY)
}
}
}

/**
* Returns a pair of the return StaticType and the coercion.
*
* If the list is null, then no mapping is required.
*
* @return
*/
fun mapping(): Pair<StaticType, List<Pair<PartiQLValueType, PartiQLValueType>>?> {
rchowell marked this conversation as resolved.
Show resolved Hide resolved
val modifiers = mutableSetOf<StaticType>()
if (nullable) modifiers.add(StaticType.NULL)
if (missable) modifiers.add(StaticType.MISSING)
// If at top supertype, then return union of all accumulated types
if (supertype == ANY) {
return StaticType.unionOf(types + modifiers) to null
}
// If a collection, then return union of all accumulated types as these coercion rules are not defined by SQL.
if (supertype == STRUCT || supertype == BAG || supertype == LIST || supertype == SEXP) {
return StaticType.unionOf(types + modifiers) to null
}
// Otherwise, return the supertype along with the coercion mapping
val type = supertype.toNonNullStaticType()
val mapping = args.map { it to supertype }
return if (modifiers.isEmpty()) {
type to mapping
} else {
StaticType.unionOf(setOf(type) + modifiers).flatten() to mapping
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If one would keep accumulating MISSING, does not seem like the NULL type will be dropped.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate more?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val typer = DynamicTyper() // at this stage, super Type is NULL
typer.accumulate(MISSING) // no touch on super type, missable is true

// modifier contains missing, 
// type is supertype.toNonNullStaticType() -> NULL
// return is StaticType.unionOf(setOf(type) + modifiers).flatten() to mapping
typer.mapping() // UnionOf( NULL, MISSING) 

}

private fun calculate(type: PartiQLValueType) {
// Don't bother calculating the new supertype, we've already hit the top.
if (supertype == ANY) return
// Lookup and set the new minimum common supertype
supertype = when {
supertype == NULL -> type // initialize
type == ANY -> type
type == NULL || type == MISSING || supertype == type -> return // skip
else -> graph[supertype][type] ?: ANY // lookup, if missing then go to top.
}
}

private operator fun <T> Array<T>.get(t: PartiQLValueType): T = get(t.ordinal)

/**
* !! IMPORTANT !!
*
* This is duplicated from the TypeLattice because that was removed in v1.0.0. I wanted to implement this as
* a standalone component so that it is easy to merge (and later merge with CastTable) into v1.0.0.
*/
companion object {

private operator fun <T> Array<T>.set(t: PartiQLValueType, value: T): Unit = this.set(t.ordinal, value)

@JvmStatic
private val N = PartiQLValueType.values().size

@JvmStatic
private fun edges(vararg edges: Pair<PartiQLValueType, PartiQLValueType>): Array<PartiQLValueType?> {
val arr = arrayOfNulls<PartiQLValueType?>(N)
for (type in edges) {
arr[type.first] = type.second
}
return arr
}

/**
* This table defines the rules in the SQL-99 section 9.3 BUT we don't have type constraints yet.
*
* TODO collection supertypes
* TODO datetime supertypes
*/
@JvmStatic
internal val graph: SuperGraph = run {
val graph = arrayOfNulls<Array<PartiQLValueType?>>(N)
for (type in PartiQLValueType.values()) {
// initialize all with empty edges
graph[type] = arrayOfNulls(N)
}
graph[ANY] = edges()
graph[NULL] = edges()
graph[MISSING] = edges()
graph[BOOL] = edges(
BOOL to BOOL
)
graph[INT8] = edges(
INT8 to INT8,
INT16 to INT16,
INT32 to INT32,
INT64 to INT64,
INT to INT,
DECIMAL to DECIMAL,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[INT16] = edges(
INT8 to INT16,
INT16 to INT16,
INT32 to INT32,
INT64 to INT64,
INT to INT,
DECIMAL to DECIMAL,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[INT32] = edges(
INT8 to INT32,
INT16 to INT32,
INT32 to INT32,
INT64 to INT64,
INT to INT,
DECIMAL to DECIMAL,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[INT64] = edges(
INT8 to INT64,
INT16 to INT64,
INT32 to INT64,
INT64 to INT64,
INT to INT,
DECIMAL to DECIMAL,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[INT] = edges(
INT8 to INT,
INT16 to INT,
INT32 to INT,
INT64 to INT,
INT to INT,
DECIMAL to DECIMAL,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[DECIMAL] = edges(
INT8 to DECIMAL,
INT16 to DECIMAL,
INT32 to DECIMAL,
INT64 to DECIMAL,
INT to DECIMAL,
DECIMAL to DECIMAL,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[DECIMAL_ARBITRARY] = edges(
INT8 to DECIMAL_ARBITRARY,
INT16 to DECIMAL_ARBITRARY,
INT32 to DECIMAL_ARBITRARY,
INT64 to DECIMAL_ARBITRARY,
INT to DECIMAL_ARBITRARY,
DECIMAL to DECIMAL_ARBITRARY,
DECIMAL_ARBITRARY to DECIMAL_ARBITRARY,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[FLOAT32] = edges(
INT8 to FLOAT32,
INT16 to FLOAT32,
INT32 to FLOAT32,
INT64 to FLOAT32,
INT to FLOAT32,
DECIMAL to FLOAT32,
DECIMAL_ARBITRARY to FLOAT32,
FLOAT32 to FLOAT32,
FLOAT64 to FLOAT64,
)
graph[FLOAT64] = edges(
INT8 to FLOAT64,
INT16 to FLOAT64,
INT32 to FLOAT64,
INT64 to FLOAT64,
INT to FLOAT64,
DECIMAL to FLOAT64,
DECIMAL_ARBITRARY to FLOAT64,
FLOAT32 to FLOAT64,
FLOAT64 to FLOAT64,
)
graph[CHAR] = edges(
CHAR to CHAR,
STRING to STRING,
SYMBOL to STRING,
alancai98 marked this conversation as resolved.
Show resolved Hide resolved
CLOB to CLOB,
)
graph[STRING] = edges(
CHAR to STRING,
STRING to STRING,
SYMBOL to STRING,
CLOB to CLOB,
)
graph[SYMBOL] = edges(
CHAR to SYMBOL,
STRING to STRING,
SYMBOL to SYMBOL,
CLOB to CLOB,
)
graph[BINARY] = edges(
BINARY to BINARY,
)
graph[BYTE] = edges(
BYTE to BYTE,
BLOB to BLOB,
)
graph[BLOB] = edges(
BYTE to BLOB,
BLOB to BLOB,
)
graph[DATE] = edges(
DATE to DATE,
)
graph[CLOB] = edges(
CHAR to CLOB,
STRING to CLOB,
SYMBOL to CLOB,
CLOB to CLOB,
)
graph[TIME] = edges(
TIME to TIME,
)
graph[TIMESTAMP] = edges(
TIMESTAMP to TIMESTAMP,
)
rchowell marked this conversation as resolved.
Show resolved Hide resolved
graph[INTERVAL] = edges(
INTERVAL to INTERVAL,
)
graph[LIST] = edges(
LIST to LIST,
SEXP to SEXP,
BAG to BAG,
)
graph[SEXP] = edges(
LIST to SEXP,
SEXP to SEXP,
BAG to BAG,
)
graph[BAG] = edges(
LIST to BAG,
SEXP to BAG,
BAG to BAG,
)
graph[STRUCT] = edges(
STRUCT to STRUCT,
)
graph.requireNoNulls()
}
}
}
Loading
Loading