From 581cf6098648b552c8618ddbf2625a95f8b5042c Mon Sep 17 00:00:00 2001
From: Alan Cai
Date: Fri, 12 Jul 2024 15:06:08 -0700
Subject: [PATCH 1/4] Upgrade to plk 0.14.6

---
 build.gradle.kts                                  |  2 +-
 .../org/partiql/scribe/shell/ShellHighlighter.kt  | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/build.gradle.kts b/build.gradle.kts
index d804cdd..bfe077c 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -25,7 +25,7 @@ object Versions {
     const val jline = "3.21.0"
     const val junit5 = "5.9.3"
     const val picoCli = "4.7.0"
-    const val partiql = "0.14.5"
+    const val partiql = "0.14.6"
 }
 
 object Deps {
diff --git a/src/main/kotlin/org/partiql/scribe/shell/ShellHighlighter.kt b/src/main/kotlin/org/partiql/scribe/shell/ShellHighlighter.kt
index 86f4dd6..539af04 100644
--- a/src/main/kotlin/org/partiql/scribe/shell/ShellHighlighter.kt
+++ b/src/main/kotlin/org/partiql/scribe/shell/ShellHighlighter.kt
@@ -13,11 +13,6 @@
  */
 package org.partiql.scribe.shell
 
-import org.antlr.v4.runtime.BaseErrorListener
-import org.antlr.v4.runtime.CharStreams
-import org.antlr.v4.runtime.CommonTokenStream
-import org.antlr.v4.runtime.RecognitionException
-import org.antlr.v4.runtime.Recognizer
 import org.jline.reader.Highlighter
 import org.jline.reader.LineReader
 import org.jline.utils.AttributedString
@@ -25,6 +20,11 @@ import org.jline.utils.AttributedStringBuilder
 import org.jline.utils.AttributedStyle
 import org.partiql.parser.antlr.PartiQLParser
 import org.partiql.parser.antlr.PartiQLTokens
+import org.partiql.parser.thirdparty.antlr.v4.runtime.BaseErrorListener
+import org.partiql.parser.thirdparty.antlr.v4.runtime.CharStreams
+import org.partiql.parser.thirdparty.antlr.v4.runtime.CommonTokenStream
+import org.partiql.parser.thirdparty.antlr.v4.runtime.RecognitionException
+import org.partiql.parser.thirdparty.antlr.v4.runtime.Recognizer
 import java.nio.charset.StandardCharsets
 import java.util.regex.Pattern
 
@@ -127,12 +127,12 @@
             msg: String?,
             e: RecognitionException?
         ) {
-            if (offendingSymbol != null && offendingSymbol is org.antlr.v4.runtime.Token && offendingSymbol.type != PartiQLParser.EOF) {
+            if (offendingSymbol != null && offendingSymbol is org.partiql.parser.thirdparty.antlr.v4.runtime.Token && offendingSymbol.type != PartiQLParser.EOF) {
                 throw OffendingSymbolException(offendingSymbol)
             }
         }
 
-        class OffendingSymbolException(val offendingSymbol: org.antlr.v4.runtime.Token) : Exception()
+        class OffendingSymbolException(val offendingSymbol: org.partiql.parser.thirdparty.antlr.v4.runtime.Token) : Exception()
     }
 
     private fun getTokenStream(input: String): CommonTokenStream {

From d88e463c439bc10b74b12fd818921634874b1975 Mon Sep 17 00:00:00 2001
From: Alan Cai
Date: Fri, 12 Jul 2024 15:11:42 -0700
Subject: [PATCH 2/4] Fix tests following 0.14.6 upgrade

---
 .../outputs/redshift/operators/between.sql         | 12 ++++++------
 .../resources/outputs/spark/operators/between.sql  | 12 ++++++------
 .../resources/outputs/trino/operators/between.sql  | 14 ++++++--------
 3 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/src/test/resources/outputs/redshift/operators/between.sql b/src/test/resources/outputs/redshift/operators/between.sql
index da7dcc9..7491ffa 100644
--- a/src/test/resources/outputs/redshift/operators/between.sql
+++ b/src/test/resources/outputs/redshift/operators/between.sql
@@ -1,23 +1,23 @@
 --#[between-00]
 -- between(decimal, int32, int32)
-SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN -1 AND 1;
+SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN CAST(-1 AS DECIMAL) AND CAST(1 AS DECIMAL);
 
 --#[between-01]
 -- between(decimal, int64, int64)
-SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN -2147483649 AND 2147483648;
+SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN CAST(-2147483649 AS DECIMAL) AND CAST(2147483648 AS DECIMAL);
 
 --#[between-02]
 -- between(decimal, decimal, decimal)
-SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN -9223372036854775809 AND 9223372036854775808;
+SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN CAST(-9223372036854775809 AS DECIMAL) AND CAST(9223372036854775808 AS DECIMAL);
 
 --#[between-04]
 -- between(decimal(p,s), int32, int32)
-SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."de" BETWEEN -1 AND 1;
+SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE CAST("T_DECIMALS"."de" AS DECIMAL) BETWEEN CAST(-1 AS DECIMAL) AND CAST(1 AS DECIMAL);
 
 --#[between-05]
 -- between(decimal(p,s), int64, int64)
-SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."de" BETWEEN -2147483649 AND 2147483648;
+SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE CAST("T_DECIMALS"."de" AS DECIMAL) BETWEEN CAST(-2147483649 AS DECIMAL) AND CAST(2147483648 AS DECIMAL);
 
 --#[between-06]
 -- between(decimal(p,s), decimal, decimal)
-SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."de" BETWEEN -9223372036854775809 AND 9223372036854775808;
+SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE CAST("T_DECIMALS"."de" AS DECIMAL) BETWEEN CAST(-9223372036854775809 AS DECIMAL) AND CAST(9223372036854775808 AS DECIMAL);
diff --git a/src/test/resources/outputs/spark/operators/between.sql b/src/test/resources/outputs/spark/operators/between.sql
index 7dc4446..d8fedcc 100644
--- a/src/test/resources/outputs/spark/operators/between.sql
+++ b/src/test/resources/outputs/spark/operators/between.sql
@@ -1,23 +1,23 @@
 --#[between-00]
 -- between(decimal, int32, int32)
-SELECT `T_DECIMALS`.`da` AS `da` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`da` BETWEEN -1 AND 1;
+SELECT `T_DECIMALS`.`da` AS `da` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`da` BETWEEN CAST(-1 AS DECIMAL) AND CAST(1 AS DECIMAL);
 
 --#[between-01]
 -- between(decimal, int64, int64)
-SELECT `T_DECIMALS`.`da` AS `da` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`da` BETWEEN -2147483649 AND 2147483648;
+SELECT `T_DECIMALS`.`da` AS `da` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`da` BETWEEN CAST(-2147483649 AS DECIMAL) AND CAST(2147483648 AS DECIMAL);
 
 --#[between-02]
 -- between(decimal, decimal, decimal)
-SELECT `T_DECIMALS`.`da` AS `da` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`da` BETWEEN -9223372036854775809 AND 9223372036854775808;
+SELECT `T_DECIMALS`.`da` AS `da` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`da` BETWEEN CAST(-9223372036854775809 AS DECIMAL) AND CAST(9223372036854775808 AS DECIMAL);
 
 --#[between-04]
 -- between(decimal(p,s), int32, int32)
-SELECT `T_DECIMALS`.`de` AS `de` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`de` BETWEEN -1 AND 1;
+SELECT `T_DECIMALS`.`de` AS `de` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE CAST(`T_DECIMALS`.`de` AS DECIMAL) BETWEEN CAST(-1 AS DECIMAL) AND CAST(1 AS DECIMAL);
 
 --#[between-05]
 -- between(decimal(p,s), int64, int64)
-SELECT `T_DECIMALS`.`de` AS `de` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`de` BETWEEN -2147483649 AND 2147483648;
+SELECT `T_DECIMALS`.`de` AS `de` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE CAST(`T_DECIMALS`.`de` AS DECIMAL) BETWEEN CAST(-2147483649 AS DECIMAL) AND CAST(2147483648 AS DECIMAL);
 
 --#[between-06]
 -- between(decimal(p,s), decimal, decimal)
-SELECT `T_DECIMALS`.`de` AS `de` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE `T_DECIMALS`.`de` BETWEEN -9223372036854775809 AND 9223372036854775808;
+SELECT `T_DECIMALS`.`de` AS `de` FROM `default`.`T_DECIMALS` AS `T_DECIMALS` WHERE CAST(`T_DECIMALS`.`de` AS DECIMAL) BETWEEN CAST(-9223372036854775809 AS DECIMAL) AND CAST(9223372036854775808 AS DECIMAL);
diff --git a/src/test/resources/outputs/trino/operators/between.sql b/src/test/resources/outputs/trino/operators/between.sql
index d0659e1..634f434 100644
--- a/src/test/resources/outputs/trino/operators/between.sql
+++ b/src/test/resources/outputs/trino/operators/between.sql
@@ -1,23 +1,21 @@
 --#[between-00]
 -- between(decimal, int32, int32)
-SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN -1 AND 1;
+SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN CAST(-1 AS DECIMAL) AND CAST(1 AS DECIMAL);
 
 --#[between-01]
 -- between(decimal, int64, int64)
-SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN -2147483649 AND 2147483648;
+SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN CAST(-2147483649 AS DECIMAL) AND CAST(2147483648 AS DECIMAL);
 
 --#[between-02]
 -- between(decimal, decimal, decimal)
-SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN -CAST('9223372036854775809' AS DECIMAL(38,0)) AND CAST('9223372036854775808' AS DECIMAL(38,0));
-
+SELECT "T_DECIMALS"."da" AS "da" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."da" BETWEEN CAST(-CAST('9223372036854775809' AS DECIMAL(38,0)) AS DECIMAL) AND CAST(CAST('9223372036854775808' AS DECIMAL(38,0)) AS DECIMAL);
 --#[between-04]
 -- between(decimal(p,s), int32, int32)
-SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."de" BETWEEN -1 AND 1;
-
+SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE CAST("T_DECIMALS"."de" AS DECIMAL) BETWEEN CAST(-1 AS DECIMAL) AND CAST(1 AS DECIMAL);
 --#[between-05]
 -- between(decimal(p,s), int64, int64)
-SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."de" BETWEEN -2147483649 AND 2147483648;
+SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE CAST("T_DECIMALS"."de" AS DECIMAL) BETWEEN CAST(-2147483649 AS DECIMAL) AND CAST(2147483648 AS DECIMAL);
 
 --#[between-06]
 -- between(decimal(p,s), decimal, decimal)
-SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE "T_DECIMALS"."de" BETWEEN -CAST('9223372036854775809' AS DECIMAL(38,0)) AND CAST('9223372036854775808' AS DECIMAL(38,0));
+SELECT "T_DECIMALS"."de" AS "de" FROM "default"."T_DECIMALS" AS "T_DECIMALS" WHERE CAST("T_DECIMALS"."de" AS DECIMAL) BETWEEN CAST(-CAST('9223372036854775809' AS DECIMAL(38,0)) AS DECIMAL) AND CAST(CAST('9223372036854775808' AS DECIMAL(38,0)) AS DECIMAL);

From 37d2c6374c56384e674f78de2974e005065ff9af Mon Sep 17 00:00:00 2001
From: Alan Cai
Date: Fri, 12 Jul 2024 15:12:39 -0700
Subject: [PATCH 3/4] Add UNION support

---
 .../org/partiql/scribe/sql/RexConverter.kt    | 75 +++++++++++++++----
 .../org/partiql/scribe/sql/SqlDialect.kt      |  7 ++
 .../targets/redshift/RedshiftFeatures.kt      |  1 +
 .../scribe/targets/spark/SparkFeatures.kt     |  1 +
 .../scribe/targets/trino/TrinoFeatures.kt     |  1 +
 .../resources/catalogs/default/SIMPLE_T.ion   | 21 ++++++
 src/test/resources/inputs/basics/setop.sql    | 20 +++++
 .../outputs/partiql/basics/setop.sql          | 20 +++++
 .../outputs/redshift/basics/setop.sql         | 20 +++++
 .../resources/outputs/spark/basics/setop.sql  | 20 +++++
 .../resources/outputs/trino/basics/setop.sql  | 20 +++++
 11 files changed, 193 insertions(+), 13 deletions(-)
 create mode 100644 src/test/resources/catalogs/default/SIMPLE_T.ion
 create mode 100644 src/test/resources/inputs/basics/setop.sql
 create mode 100644 src/test/resources/outputs/partiql/basics/setop.sql
 create mode 100644 src/test/resources/outputs/redshift/basics/setop.sql
 create mode 100644 src/test/resources/outputs/spark/basics/setop.sql
 create mode 100644 src/test/resources/outputs/trino/basics/setop.sql

diff --git a/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt b/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt
index 261d53e..b8d233f 100644
--- a/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt
+++ b/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt
@@ -3,7 +3,9 @@ package org.partiql.scribe.sql
 import org.partiql.ast.Expr
 import org.partiql.ast.Identifier
 import org.partiql.ast.Select
+import org.partiql.ast.SetOp
 import org.partiql.ast.SetQuantifier
+import org.partiql.ast.exprBagOp
 import org.partiql.ast.exprCall
 import org.partiql.ast.exprCase
 import org.partiql.ast.exprCaseBranch
@@ -21,9 +23,11 @@ import org.partiql.ast.selectProject
 import org.partiql.ast.selectProjectItemAll
 import org.partiql.ast.selectProjectItemExpression
 import org.partiql.ast.selectValue
+import org.partiql.ast.setOp
 import org.partiql.plan.PlanNode
 import org.partiql.plan.Rel
 import org.partiql.plan.Rex
+import org.partiql.plan.rexOpSelect
 import org.partiql.plan.visitor.PlanBaseVisitor
 import org.partiql.types.BagType
 import org.partiql.types.ListType
@@ -206,19 +210,64 @@
         return transform.getFunction(name, args)
     }
 
+    private fun planToAstSetQ(setQuantifier: org.partiql.plan.SetQuantifier?): org.partiql.ast.SetQuantifier {
+        return when (setQuantifier) {
+            null, org.partiql.plan.SetQuantifier.DISTINCT -> SetQuantifier.DISTINCT
+            org.partiql.plan.SetQuantifier.ALL -> SetQuantifier.ALL
+        }
+    }
+
+    /**
+     * Create an ast [Expr.BagOp] from two plan [Rel] nodes (coming from a plan SQL set op).
+     */
+    private fun relSetOpToBagOp(lhs: Rel, rhs: Rel, setq: org.partiql.plan.SetQuantifier, opType: SetOp.Type, ctx: StaticType): Expr.BagOp {
+        // Since the args to a SQL set op are both SFW queries, re-create an [Expr.SFW]
+        val lhsRex = rexOpSelect(
+            constructor = Rex(
+                type = lhs.type.schema.first().type,
+                Rex.Op.Var(0)
+            ),
+            rel = lhs
+        )
+        val lhsExpr = visitRexOp(node = lhsRex, ctx = ctx)
+        val rhsRex = rexOpSelect(
+            constructor = Rex(
+                type = rhs.type.schema.first().type,
+                Rex.Op.Var(0)
+            ),
+            rel = rhs
+        )
+        val rhsExpr = visitRexOp(node = rhsRex, ctx = ctx)
+        return exprBagOp(
+            type = setOp(type = opType, setq = planToAstSetQ(setq)),
+            lhs = lhsExpr,
+            rhs = rhsExpr,
+            outer = false
+        )
+    }
+
     override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType): Expr {
-        val relToSql = transform.getRelConverter()
-        val rexToSql = transform.getRexConverter(locals)
-        val sfw = relToSql.apply(node.rel)
-        assert(sfw.select != null) { "SELECT from RelConverter should never be null" }
-        val setq = getSetQuantifier(sfw.select!!)
-        val select = convertSelectValueToSqlSelect(sfw.select, node.constructor, node.rel, setq) ?: convertSelectValue(
-            node.constructor,
-            node.rel,
-            setq
-        ) ?: selectValue(rexToSql.apply(node.constructor), setq)
-        sfw.select = select
-        return sfw.build()
+        val rel = node.rel
+        return when (val op = rel.op) {
+            // SQL sets modeled as a Rel node taking in two Rels
+            is Rel.Op.Except -> relSetOpToBagOp(op.lhs, op.rhs, op.setq, SetOp.Type.EXCEPT, ctx)
+            is Rel.Op.Intersect -> relSetOpToBagOp(op.lhs, op.rhs, op.setq, SetOp.Type.INTERSECT, ctx)
+            is Rel.Op.Union -> relSetOpToBagOp(op.lhs, op.rhs, op.setq, SetOp.Type.UNION, ctx)
+            else -> {
+                val relToSql = transform.getRelConverter()
+                val rexToSql = transform.getRexConverter(locals)
+                val sfw = relToSql.apply(rel)
+                assert(sfw.select != null) { "SELECT from RelConverter should never be null" }
+                val setq = getSetQuantifier(sfw.select!!)
+                val select = convertSelectValueToSqlSelect(sfw.select, node.constructor, node.rel, setq) ?: convertSelectValue(
+                    node.constructor,
+                    node.rel,
+                    setq
+                ) ?: selectValue(rexToSql.apply(node.constructor), setq)
+                sfw.select = select
+                return sfw.build()
+            }
+        }
     }
 
     override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: StaticType): Expr {
@@ -279,7 +328,7 @@
         val newRexConverter = RexConverter(transform, Locals(relProject.input.type.schema))
         val type = constructor.type as? StructType ?: return null
         if (type.constraints.contains(TupleConstraint.Open(false))
-            .not() || type.constraints.contains(TupleConstraint.Ordered).not()
+            .not()
         ) {
             return null
         }
diff --git a/src/main/kotlin/org/partiql/scribe/sql/SqlDialect.kt b/src/main/kotlin/org/partiql/scribe/sql/SqlDialect.kt
index e2f043d..d58f530 100644
--- a/src/main/kotlin/org/partiql/scribe/sql/SqlDialect.kt
+++ b/src/main/kotlin/org/partiql/scribe/sql/SqlDialect.kt
@@ -89,6 +89,13 @@ abstract class SqlDialect : AstBaseVisitor() {
             t = t concat ")"
             t
         }
+        node is Expr.BagOp -> {
+            var t = tail
+            t = t concat "("
+            t = visitExprBagOp(node, t)
+            t = t concat ")"
+            t
+        }
         else -> visitExpr(node, tail)
     }
 
diff --git a/src/main/kotlin/org/partiql/scribe/targets/redshift/RedshiftFeatures.kt b/src/main/kotlin/org/partiql/scribe/targets/redshift/RedshiftFeatures.kt
index 9dc3f4c..7b47dbb 100644
--- a/src/main/kotlin/org/partiql/scribe/targets/redshift/RedshiftFeatures.kt
+++ b/src/main/kotlin/org/partiql/scribe/targets/redshift/RedshiftFeatures.kt
@@ -24,6 +24,7 @@ public open class RedshiftFeatures : SqlFeatures.Defensive() {
         Rel.Op.Exclude::class.java,
         Rel.Op.Exclude.Item::class.java,
         Rel.Op.Exclude.Step.StructField::class.java,
+        Rel.Op.Union::class.java,
         // Do not support Rel.Op.Exclude.Step.CollWildcard -- currently, no efficient way to reconstruct SUPER ARRAYs
         //
         // Rex
diff --git a/src/main/kotlin/org/partiql/scribe/targets/spark/SparkFeatures.kt b/src/main/kotlin/org/partiql/scribe/targets/spark/SparkFeatures.kt
index 4b4d7c8..c3fad08 100644
--- a/src/main/kotlin/org/partiql/scribe/targets/spark/SparkFeatures.kt
+++ b/src/main/kotlin/org/partiql/scribe/targets/spark/SparkFeatures.kt
@@ -22,6 +22,7 @@ public open class SparkFeatures : SqlFeatures.Defensive() {
         Rel.Op.Exclude.Item::class.java,
         Rel.Op.Exclude.Step.StructField::class.java,
         Rel.Op.Exclude.Step.CollWildcard::class.java,
+        Rel.Op.Union::class.java,
         //
         // Rex
         //
diff --git a/src/main/kotlin/org/partiql/scribe/targets/trino/TrinoFeatures.kt b/src/main/kotlin/org/partiql/scribe/targets/trino/TrinoFeatures.kt
index 9edd255..c29fef7 100644
--- a/src/main/kotlin/org/partiql/scribe/targets/trino/TrinoFeatures.kt
+++ b/src/main/kotlin/org/partiql/scribe/targets/trino/TrinoFeatures.kt
@@ -27,6 +27,7 @@ public open class TrinoFeatures : SqlFeatures.Defensive() {
         Rel.Op.Exclude.Item::class.java,
         Rel.Op.Exclude.Step.StructField::class.java,
         Rel.Op.Exclude.Step.CollWildcard::class.java,
+        Rel.Op.Union::class.java,
         //
         // Rex
         //
diff --git a/src/test/resources/catalogs/default/SIMPLE_T.ion b/src/test/resources/catalogs/default/SIMPLE_T.ion
new file mode 100644
index 0000000..260fbee
--- /dev/null
+++ b/src/test/resources/catalogs/default/SIMPLE_T.ion
@@ -0,0 +1,21 @@
+{
+  type: "bag",
+  items: {
+    type: "struct",
+    constraints: [ closed, ordered, unique ],
+    fields: [
+      {
+        name: "a",
+        type: "bool",
+      },
+      {
+        name: "b",
+        type: "int32",
+      },
+      {
+        name: "c",
+        type: "string",
+      }
+    ]
+  }
+}
diff --git a/src/test/resources/inputs/basics/setop.sql b/src/test/resources/inputs/basics/setop.sql
new file mode 100644
index 0000000..88013fe
--- /dev/null
+++ b/src/test/resources/inputs/basics/setop.sql
@@ -0,0 +1,20 @@
+-- SQL set ops
+
+-- SQL UNION
+--#[setop-00]
+SELECT a FROM SIMPLE_T AS t1 UNION SELECT a FROM SIMPLE_T AS t2;
+
+--#[setop-01]
+SELECT a FROM SIMPLE_T AS t1 UNION ALL SELECT a FROM SIMPLE_T AS t2;
+
+--#[setop-02]
+SELECT a FROM SIMPLE_T AS t1 UNION ALL SELECT a FROM SIMPLE_T AS t2 UNION SELECT a FROM SIMPLE_T AS t3;
+
+--#[setop-03]
+SELECT a FROM SIMPLE_T AS t1 UNION ALL (SELECT a FROM SIMPLE_T AS t2 UNION SELECT a FROM SIMPLE_T AS t3);
+
+--#[setop-04]
+SELECT c, b, a FROM SIMPLE_T AS t1 UNION ALL (SELECT c, b, a FROM SIMPLE_T AS t2 UNION SELECT c, b, a FROM SIMPLE_T AS t3);
+
+--#[setop-05]
+SELECT * FROM SIMPLE_T AS t1 UNION ALL SELECT * FROM SIMPLE_T AS t2 UNION SELECT * FROM SIMPLE_T AS t3;
diff --git a/src/test/resources/outputs/partiql/basics/setop.sql b/src/test/resources/outputs/partiql/basics/setop.sql
new file mode 100644
index 0000000..4da2464
--- /dev/null
+++ b/src/test/resources/outputs/partiql/basics/setop.sql
@@ -0,0 +1,20 @@
+-- SQL set ops
+
+-- SQL UNION
+--#[setop-00]
+(SELECT "t1"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION DISTINCT (SELECT "t2"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t2");
+
+--#[setop-01]
+(SELECT "t1"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t2");
+
+--#[setop-02]
+((SELECT "t1"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t2")) UNION DISTINCT (SELECT "t3"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t3");
+
+--#[setop-03]
+(SELECT "t1"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL ((SELECT "t2"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t2") UNION DISTINCT (SELECT "t3"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t3"));
+
+--#[setop-04]
+(SELECT "t1"['c'] AS "c", "t1"['b'] AS "b", "t1"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL ((SELECT "t2"['c'] AS "c", "t2"['b'] AS "b", "t2"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t2") UNION DISTINCT (SELECT "t3"['c'] AS "c", "t3"['b'] AS "b", "t3"['a'] AS "a" FROM "default"."SIMPLE_T" AS "t3"));
+
+--#[setop-05]
+((SELECT "t1".* FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2".* FROM "default"."SIMPLE_T" AS "t2")) UNION DISTINCT (SELECT "t3".* FROM "default"."SIMPLE_T" AS "t3");
diff --git a/src/test/resources/outputs/redshift/basics/setop.sql b/src/test/resources/outputs/redshift/basics/setop.sql
new file mode 100644
index 0000000..08562fa
--- /dev/null
+++ b/src/test/resources/outputs/redshift/basics/setop.sql
@@ -0,0 +1,20 @@
+-- SQL set ops
+
+-- SQL UNION
+--#[setop-00]
+(SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION DISTINCT (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2");
+
+--#[setop-01]
+(SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2");
+
+--#[setop-02]
+((SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2")) UNION DISTINCT (SELECT "t3"."a" AS "a" FROM "default"."SIMPLE_T" AS "t3");
+
+--#[setop-03]
+(SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL ((SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2") UNION DISTINCT (SELECT "t3"."a" AS "a" FROM "default"."SIMPLE_T" AS "t3"));
+
+--#[setop-04]
+(SELECT "t1"."c" AS "c", "t1"."b" AS "b", "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL ((SELECT "t2"."c" AS "c", "t2"."b" AS "b", "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2") UNION DISTINCT (SELECT "t3"."c" AS "c", "t3"."b" AS "b", "t3"."a" AS "a" FROM "default"."SIMPLE_T" AS "t3"));
+
+--#[setop-05]
+((SELECT "t1"."a" AS "a", "t1"."b" AS "b", "t1"."c" AS "c" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"."a" AS "a", "t2"."b" AS "b", "t2"."c" AS "c" FROM "default"."SIMPLE_T" AS "t2")) UNION DISTINCT (SELECT "t3"."a" AS "a", "t3"."b" AS "b", "t3"."c" AS "c" FROM "default"."SIMPLE_T" AS "t3");
diff --git a/src/test/resources/outputs/spark/basics/setop.sql b/src/test/resources/outputs/spark/basics/setop.sql
new file mode 100644
index 0000000..ad31364
--- /dev/null
+++ b/src/test/resources/outputs/spark/basics/setop.sql
@@ -0,0 +1,20 @@
+-- SQL set ops
+
+-- SQL UNION
+--#[setop-00]
+(SELECT `t1`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t1`) UNION DISTINCT (SELECT `t2`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t2`);
+
+--#[setop-01]
+(SELECT `t1`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t1`) UNION ALL (SELECT `t2`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t2`);
+
+--#[setop-02]
+((SELECT `t1`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t1`) UNION ALL (SELECT `t2`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t2`)) UNION DISTINCT (SELECT `t3`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t3`);
+
+--#[setop-03]
+(SELECT `t1`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t1`) UNION ALL ((SELECT `t2`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t2`) UNION DISTINCT (SELECT `t3`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t3`));
+
+--#[setop-04]
+(SELECT `t1`.`c` AS `c`, `t1`.`b` AS `b`, `t1`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t1`) UNION ALL ((SELECT `t2`.`c` AS `c`, `t2`.`b` AS `b`, `t2`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t2`) UNION DISTINCT (SELECT `t3`.`c` AS `c`, `t3`.`b` AS `b`, `t3`.`a` AS `a` FROM `default`.`SIMPLE_T` AS `t3`));
+
+--#[setop-05]
+((SELECT `t1`.`a` AS `a`, `t1`.`b` AS `b`, `t1`.`c` AS `c` FROM `default`.`SIMPLE_T` AS `t1`) UNION ALL (SELECT `t2`.`a` AS `a`, `t2`.`b` AS `b`, `t2`.`c` AS `c` FROM `default`.`SIMPLE_T` AS `t2`)) UNION DISTINCT (SELECT `t3`.`a` AS `a`, `t3`.`b` AS `b`, `t3`.`c` AS `c` FROM `default`.`SIMPLE_T` AS `t3`);
diff --git a/src/test/resources/outputs/trino/basics/setop.sql b/src/test/resources/outputs/trino/basics/setop.sql
new file mode 100644
index 0000000..08562fa
--- /dev/null
+++ b/src/test/resources/outputs/trino/basics/setop.sql
@@ -0,0 +1,20 @@
+-- SQL set ops
+
+-- SQL UNION
+--#[setop-00]
+(SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION DISTINCT (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2");
+
+--#[setop-01]
+(SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2");
+
+--#[setop-02]
+((SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2")) UNION DISTINCT (SELECT "t3"."a" AS "a" FROM "default"."SIMPLE_T" AS "t3");
+
+--#[setop-03]
+(SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL ((SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2") UNION DISTINCT (SELECT "t3"."a" AS "a" FROM "default"."SIMPLE_T" AS "t3"));
+
+--#[setop-04]
+(SELECT "t1"."c" AS "c", "t1"."b" AS "b", "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION ALL ((SELECT "t2"."c" AS "c", "t2"."b" AS "b", "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2") UNION DISTINCT (SELECT "t3"."c" AS "c", "t3"."b" AS "b", "t3"."a" AS "a" FROM "default"."SIMPLE_T" AS "t3"));
+
+--#[setop-05]
+((SELECT "t1"."a" AS "a", "t1"."b" AS "b", "t1"."c" AS "c" FROM "default"."SIMPLE_T" AS "t1") UNION ALL (SELECT "t2"."a" AS "a", "t2"."b" AS "b", "t2"."c" AS "c" FROM "default"."SIMPLE_T" AS "t2")) UNION DISTINCT (SELECT "t3"."a" AS "a", "t3"."b" AS "b", "t3"."c" AS "c" FROM "default"."SIMPLE_T" AS "t3");

From b5434534a9889ab1849e59d0a555394aeb3b9613 Mon Sep 17 00:00:00 2001
From: Alan Cai
Date: Mon, 15 Jul 2024 15:51:25 -0700
Subject: [PATCH 4/4] Add back ordered struct check

---
 .../org/partiql/scribe/sql/RexConverter.kt | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt b/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt
index b8d233f..06c6b50 100644
--- a/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt
+++ b/src/main/kotlin/org/partiql/scribe/sql/RexConverter.kt
@@ -224,7 +224,7 @@
         // Since the args to a SQL set op are both SFW queries, re-create an [Expr.SFW]
         val lhsRex = rexOpSelect(
             constructor = Rex(
-                type = lhs.type.schema.first().type,
+                type = lhs.type.schema.first().type.asOrderedStruct(),
                 Rex.Op.Var(0)
             ),
             rel = lhs
@@ -232,7 +232,7 @@
         val lhsExpr = visitRexOp(node = lhsRex, ctx = ctx)
         val rhsRex = rexOpSelect(
             constructor = Rex(
-                type = rhs.type.schema.first().type,
+                type = rhs.type.schema.first().type.asOrderedStruct(),
                 Rex.Op.Var(0)
             ),
             rel = rhs
@@ -246,6 +246,16 @@ outer = false
         )
     }
 
+    // Adds the [TupleConstraint.Ordered] for [StructType]s
+    private fun StaticType.asOrderedStruct(): StaticType {
+        return when (this) {
+            is StructType -> this.copy(
+                constraints = this.constraints + setOf(TupleConstraint.Ordered)
+            )
+            else -> this
+        }
+    }
+
     override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType): Expr {
         val rel = node.rel
         return when (val op = rel.op) {
@@ -328,7 +338,7 @@
         val newRexConverter = RexConverter(transform, Locals(relProject.input.type.schema))
         val type = constructor.type as? StructType ?: return null
         if (type.constraints.contains(TupleConstraint.Open(false))
-            .not()
+            .not() || type.constraints.contains(TupleConstraint.Ordered).not()
         ) {
             return null
         }
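
For reference, a worked example of what the new UNION lowering produces, taken from the setop tests added in PATCH 3/4 (illustrative only, not part of any patch). Given the SIMPLE_T catalog, the input

    SELECT a FROM SIMPLE_T AS t1 UNION SELECT a FROM SIMPLE_T AS t2;

is expected to transpile for the Redshift/Trino targets to

    (SELECT "t1"."a" AS "a" FROM "default"."SIMPLE_T" AS "t1") UNION DISTINCT (SELECT "t2"."a" AS "a" FROM "default"."SIMPLE_T" AS "t2");

Each side of the set op is re-emitted as a parenthesized SFW query (relSetOpToBagOp wraps each Rel in a rexOpSelect, and SqlDialect parenthesizes Expr.BagOp operands), and the implicit set quantifier is made explicit as DISTINCT via planToAstSetQ.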