From abfc58d70e1fc78d14f3d40aab4d6ba5fb08d2b5 Mon Sep 17 00:00:00 2001 From: yliuuuu <107505258+yliuuuu@users.noreply.github.com> Date: Thu, 25 Apr 2024 14:47:10 -0700 Subject: [PATCH] Support parsing for attribute and tuple level constraint (#1442) * Parsing for attribute and tuple constraint Co-authored-by: Alan Cai --- .../org/partiql/ast/helpers/ToLegacyAst.kt | 51 +-- .../src/main/resources/partiql_ast.ion | 91 +++--- .../lang/syntax/impl/PartiQLPigVisitor.kt | 2 +- .../lang/syntax/PartiQLParserDDLTest.kt | 42 ++- partiql-parser/src/main/antlr/PartiQL.g4 | 38 ++- .../parser/internal/PartiQLParserDefault.kt | 90 ++++-- .../parser/internal/PartiQLParserDDLTests.kt | 294 ++++++++++++++++-- 7 files changed, 485 insertions(+), 123 deletions(-) diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index 0dd1a879ff..0fb163cfe8 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -16,7 +16,9 @@ import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionSymbol import com.amazon.ionelement.api.metaContainerOf import org.partiql.ast.AstNode +import org.partiql.ast.Constraint import org.partiql.ast.DatetimeField +import org.partiql.ast.DdlOp import org.partiql.ast.Exclude import org.partiql.ast.Expr import org.partiql.ast.From @@ -121,24 +123,23 @@ private class AstTranslator(val metas: Map) : AstBaseVisi domain(statement, type, format, metas) } - override fun visitStatementDDL(node: Statement.DDL, ctx: Ctx) = super.visit(node, ctx) as PartiqlAst.Statement.Ddl + override fun visitStatementDDL(node: Statement.DDL, ctx: Ctx) = when (val op = node.op) { + is DdlOp.CreateIndex -> visitDdlOpCreateIndex(op, ctx) + is DdlOp.CreateTable -> visitDdlOpCreateTable(op, ctx) + is DdlOp.DropIndex -> visitDdlOpDropIndex(op, ctx) + is DdlOp.DropTable -> visitDdlOpDropTable(op, ctx) + } - override fun visitStatementDDLCreateTable( - node: Statement.DDL.CreateTable, - ctx: Ctx, - ) = translate(node) { metas -> + override fun visitDdlOpCreateTable(node: DdlOp.CreateTable, ctx: Ctx) = translate(node) { metas -> if (node.name !is Identifier.Symbol) { error("The legacy AST does not support qualified identifiers as table names") } - val tableName = (node.name as Identifier.Symbol).symbol + val tableName = node.name.symbol val def = node.definition?.let { visitTableDefinition(it, ctx) } ddl(createTable(tableName, def), metas) } - override fun visitStatementDDLCreateIndex( - node: Statement.DDL.CreateIndex, - ctx: Ctx, - ) = translate(node) { metas -> + override fun visitDdlOpCreateIndex(node: DdlOp.CreateIndex, ctx: Ctx) = translate(node) { metas -> if (node.index != null) { error("The legacy AST does not support index names") } @@ -150,7 +151,7 @@ private class AstTranslator(val metas: Map) : AstBaseVisi ddl(createIndex(tableName, fields), metas) } - override fun visitStatementDDLDropTable(node: Statement.DDL.DropTable, ctx: Ctx) = translate(node) { metas -> + override fun visitDdlOpDropTable(node: DdlOp.DropTable, ctx: Ctx) = translate(node) { metas -> if (node.table !is Identifier.Symbol) { error("The legacy AST does not support qualified identifiers as table names") } @@ -159,7 +160,7 @@ private class AstTranslator(val metas: Map) : AstBaseVisi ddl(dropTable(tableName), metas) } - override fun visitStatementDDLDropIndex(node: Statement.DDL.DropIndex, ctx: Ctx) = translate(node) { metas -> + override fun visitDdlOpDropIndex(node: DdlOp.DropIndex, ctx: Ctx) = translate(node) { metas -> if (node.index !is Identifier.Symbol) { error("The legacy AST does not support qualified identifiers as index names") } @@ -174,28 +175,28 @@ private class AstTranslator(val metas: Map) : AstBaseVisi } override fun visitTableDefinition(node: TableDefinition, ctx: Ctx) = translate(node) { metas -> - val parts = node.columns.translate(ctx) + val parts = node.attributes.translate(ctx) + if (node.constraints.isNotEmpty()) { + error("The legacy AST does not support table level constraint declaration") + } tableDef(parts, metas) } - override fun visitTableDefinitionColumn(node: TableDefinition.Column, ctx: Ctx) = translate(node) { metas -> - val name = node.name + override fun visitTableDefinitionAttribute(node: TableDefinition.Attribute, ctx: Ctx) = translate(node) { metas -> + // Legacy AST treat table name as a case-sensitive string + val name = node.name.symbol val type = visitType(node.type, ctx) val constraints = node.constraints.translate(ctx) columnDeclaration(name, type, constraints, metas) } - override fun visitTableDefinitionColumnConstraint( - node: TableDefinition.Column.Constraint, - ctx: Ctx, - ) = translate(node) { metas -> + override fun visitConstraint(node: Constraint, ctx: Ctx) = translate(node) { val name = node.name - val def = when (node.body) { - is TableDefinition.Column.Constraint.Body.Check -> { - throw IllegalArgumentException("PIG AST does not support CHECK () constraint") - } - is TableDefinition.Column.Constraint.Body.NotNull -> columnNotnull() - is TableDefinition.Column.Constraint.Body.Nullable -> columnNull() + val def = when (node.definition) { + is Constraint.Definition.Check -> throw IllegalArgumentException("PIG AST does not support CHECK () constraint") + is Constraint.Definition.NotNull -> columnNotnull() + is Constraint.Definition.Nullable -> columnNull() + is Constraint.Definition.Unique -> throw IllegalArgumentException("PIG AST does not support Unique/Primary Key constraint") } columnConstraint(name, def, metas) } diff --git a/partiql-ast/src/main/resources/partiql_ast.ion b/partiql-ast/src/main/resources/partiql_ast.ion index 67c48663c8..e30da1faec 100644 --- a/partiql-ast/src/main/resources/partiql_ast.ion +++ b/partiql-ast/src/main/resources/partiql_ast.ion @@ -106,32 +106,9 @@ statement::[ ], // Data Definition Language - d_d_l::[ - - // CREATE TABLE [] - create_table::{ - name: identifier, - definition: optional::table_definition, - }, - - // CREATE INDEX [] ON ( [, ]...) - create_index::{ - index: optional::identifier, - table: identifier, - fields: list::[path], - }, - - // DROP TABLE - drop_table::{ - table: identifier, - }, - - // DROP INDEX ON - drop_index::{ - index: identifier, // [0] - table: identifier, // [1] - }, - ], + d_d_l::{ + op: ddl_op + }, // EXEC [.*] exec::{ @@ -151,6 +128,32 @@ statement::[ }, ] +ddl_op::[ + // CREATE TABLE [] + create_table::{ + name: identifier, + definition: optional::table_definition, + }, + + // CREATE INDEX [] ON ( [, ]...) + create_index::{ + index: optional::identifier, + table: identifier, + fields: list::[path], + }, + + // DROP TABLE + drop_table::{ + table: identifier, + }, + + // DROP INDEX ON + drop_index::{ + index: identifier, // [0] + table: identifier, // [1] + }, +] + // PartiQL Type AST nodes // // Several of these are the same "type", but have various syntax rules we wish to capture. @@ -781,27 +784,29 @@ returning::{ ], } -// ` *` -// `( CONSTRAINT )? ` table_definition::{ - columns: list::[column], + attributes: list::[attribute], + // table level constraints + constraints: list::[constraint], _: [ - column::{ - name: string, + attribute::{ + name: '.identifier.symbol', type: '.type', constraints: list::[constraint], - _: [ - // TODO improve modeling language to avoid these wrapped unions - // Also, prefer not to nest more than twice - constraint::{ - name: optional::string, - body: [ - nullable::{}, - not_null::{}, - check::{ expr: expr }, - ], - }, - ], + } + ], +} + +constraint::{ + name: optional::string, + definition: [ + nullable::{}, + not_null::{}, + check::{ expr: expr }, + unique::{ + // for attribute level constraint, we can set this attribute to null + attributes: optional::list::['.identifier.symbol'], + is_primary_key: bool, }, ], } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt index 98a864e91e..8a799a8493 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt @@ -268,7 +268,7 @@ internal class PartiQLPigVisitor( } override fun visitColumnConstraint(ctx: PartiQLParser.ColumnConstraintContext) = PartiqlAst.build { - val name = ctx.columnConstraintName()?.let { visitSymbolPrimitive(it.symbolPrimitive()).name.text } + val name = ctx.constraintName()?.let { visitSymbolPrimitive(it.symbolPrimitive()).name.text } val def = visit(ctx.columnConstraintDef()) as PartiqlAst.ColumnConstraintDef columnConstraint(name, def) } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt index 8bdba4c8f2..4c61031c70 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserDDLTest.kt @@ -35,7 +35,47 @@ internal class PartiQLParserDDLTest : PartiQLParserTestBase() { query = "DROP Table foo.bar", code = ErrorCode.PARSE_UNEXPECTED_TOKEN, context = mapOf(), - ) + ), + ParserErrorTestCase( + description = "PIG Parser does not support Unique Constraints in CREATE TABLE", + query = """ + CREATE TABLE tbl ( + a INT2 UNIQUE + ) + """.trimIndent(), + code = ErrorCode.PARSE_UNEXPECTED_TOKEN, + context = mapOf(), + ), + ParserErrorTestCase( + description = "PIG Parser does not support Primary Key Constraint in CREATE TABLE", + query = """ + CREATE TABLE tbl ( + a INT2 PRIMARY KEY + ) + """.trimIndent(), + code = ErrorCode.PARSE_UNEXPECTED_TOKEN, + context = mapOf(), + ), + ParserErrorTestCase( + description = "PIG Parser does not support CHECK Constraint in CREATE TABLE", + query = """ + CREATE TABLE tbl ( + a INT2 CHECK(a > 0) + ) + """.trimIndent(), + code = ErrorCode.PARSE_UNEXPECTED_TOKEN, + context = mapOf(), + ), + ParserErrorTestCase( + description = "PIG Parser does not support table constraint in CREATE TABLE", + query = """ + CREATE TABLE tbl ( + check (a > 0) + ) + """.trimIndent(), + code = ErrorCode.PARSE_UNEXPECTED_TOKEN, + context = mapOf(), + ), ) } } diff --git a/partiql-parser/src/main/antlr/PartiQL.g4 b/partiql-parser/src/main/antlr/PartiQL.g4 index 7d05a22ab0..1d37acfcdb 100644 --- a/partiql-parser/src/main/antlr/PartiQL.g4 +++ b/partiql-parser/src/main/antlr/PartiQL.g4 @@ -75,9 +75,8 @@ execCommand qualifiedName : (qualifier+=symbolPrimitive PERIOD)* name=symbolPrimitive; tableName : symbolPrimitive; -tableConstraintName : symbolPrimitive; columnName : symbolPrimitive; -columnConstraintName : symbolPrimitive; +constraintName : symbolPrimitive; ddl : createCommand @@ -100,17 +99,43 @@ tableDef tableDefPart : columnName type columnConstraint* # ColumnDeclaration + | ( CONSTRAINT constraintName )? tableConstraintDef # TableConstrDeclaration + ; + +tableConstraintDef + : checkConstraintDef # TableConstrCheck + | uniqueConstraintDef # TableConstrUnique ; columnConstraint - : ( CONSTRAINT columnConstraintName )? columnConstraintDef + : ( CONSTRAINT constraintName )? columnConstraintDef ; columnConstraintDef - : NOT NULL # ColConstrNotNull - | NULL # ColConstrNull + : NOT NULL # ColConstrNotNull + | NULL # ColConstrNull + | uniqueSpec # ColConstrUnique + | checkConstraintDef # ColConstrCheck + ; + +checkConstraintDef + : CHECK PAREN_LEFT searchCondition PAREN_RIGHT + ; + +uniqueSpec + : PRIMARY KEY # PrimaryKey + | UNIQUE # Unique + ; + +uniqueConstraintDef + : uniqueSpec PAREN_LEFT columnName (COMMA columnName)* PAREN_RIGHT ; +// ::= | OR +// we cannot do exactly that for the way expression precedence is structured in the grammar file. +// but we at least can eliminate SFW query here. +searchCondition : exprOr; + /** * * DATA MANIPULATION LANGUAGE (DML) @@ -192,9 +217,6 @@ conflictTarget : PAREN_LEFT symbolPrimitive (COMMA symbolPrimitive)* PAREN_RIGHT | ON CONSTRAINT constraintName; -constraintName - : symbolPrimitive; - conflictAction : DO NOTHING | DO REPLACE doReplace diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index 481d08acd7..82e3d395e9 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -33,7 +33,9 @@ import org.antlr.v4.runtime.atn.PredictionMode import org.antlr.v4.runtime.misc.ParseCancellationException import org.antlr.v4.runtime.tree.TerminalNode import org.partiql.ast.AstNode +import org.partiql.ast.Constraint import org.partiql.ast.DatetimeField +import org.partiql.ast.DdlOp import org.partiql.ast.Exclude import org.partiql.ast.Expr import org.partiql.ast.From @@ -49,8 +51,16 @@ import org.partiql.ast.SetOp import org.partiql.ast.SetQuantifier import org.partiql.ast.Sort import org.partiql.ast.Statement -import org.partiql.ast.TableDefinition import org.partiql.ast.Type +import org.partiql.ast.constraint +import org.partiql.ast.constraintDefinitionCheck +import org.partiql.ast.constraintDefinitionNotNull +import org.partiql.ast.constraintDefinitionNullable +import org.partiql.ast.constraintDefinitionUnique +import org.partiql.ast.ddlOpCreateIndex +import org.partiql.ast.ddlOpCreateTable +import org.partiql.ast.ddlOpDropIndex +import org.partiql.ast.ddlOpDropTable import org.partiql.ast.exclude import org.partiql.ast.excludeItem import org.partiql.ast.excludeStepCollIndex @@ -144,10 +154,7 @@ import org.partiql.ast.selectStar import org.partiql.ast.selectValue import org.partiql.ast.setOp import org.partiql.ast.sort -import org.partiql.ast.statementDDLCreateIndex -import org.partiql.ast.statementDDLCreateTable -import org.partiql.ast.statementDDLDropIndex -import org.partiql.ast.statementDDLDropTable +import org.partiql.ast.statementDDL import org.partiql.ast.statementDMLBatchLegacy import org.partiql.ast.statementDMLBatchLegacyOpDelete import org.partiql.ast.statementDMLBatchLegacyOpInsert @@ -168,10 +175,7 @@ import org.partiql.ast.statementExplain import org.partiql.ast.statementExplainTargetDomain import org.partiql.ast.statementQuery import org.partiql.ast.tableDefinition -import org.partiql.ast.tableDefinitionColumn -import org.partiql.ast.tableDefinitionColumnConstraint -import org.partiql.ast.tableDefinitionColumnConstraintBodyNotNull -import org.partiql.ast.tableDefinitionColumnConstraintBodyNullable +import org.partiql.ast.tableDefinitionAttribute import org.partiql.ast.typeAny import org.partiql.ast.typeBag import org.partiql.ast.typeBlob @@ -588,23 +592,25 @@ internal class PartiQLParserDefault : PartiQLParser { * */ - override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) + override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = translate(ctx) { + statementDDL(visitAs(ctx.ddl())) + } override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { val table = visitQualifiedName(ctx.qualifiedName()) - statementDDLDropTable(table) + ddlOpDropTable(table) } override fun visitDropIndex(ctx: GeneratedParser.DropIndexContext) = translate(ctx) { val table = visitSymbolPrimitive(ctx.on) val index = visitSymbolPrimitive(ctx.target) - statementDDLDropIndex(index, table) + ddlOpDropIndex(index, table) } override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { val table = visitQualifiedName(ctx.qualifiedName()) val definition = ctx.tableDef()?.let { visitTableDef(it) } - statementDDLCreateTable(table, definition) + ddlOpCreateTable(table, definition) } override fun visitCreateIndex(ctx: GeneratedParser.CreateIndexContext) = translate(ctx) { @@ -612,7 +618,7 @@ internal class PartiQLParserDefault : PartiQLParser { val name: Identifier? = null val table = visitSymbolPrimitive(ctx.symbolPrimitive()) val fields = ctx.pathSimple().map { path -> visitPathSimple(path) } - statementDDLCreateIndex(name, table, fields) + ddlOpCreateIndex(name, table, fields) } override fun visitTableDef(ctx: GeneratedParser.TableDefContext) = translate(ctx) { @@ -620,30 +626,60 @@ internal class PartiQLParserDefault : PartiQLParser { val columns = ctx.tableDefPart().filterIsInstance().map { visitColumnDeclaration(it) } - tableDefinition(columns) + + val tblConstr = ctx.tableDefPart().filterIsInstance().map { + visitTableConstrDeclaration(it) + } + + tableDefinition(columns, tblConstr) } override fun visitColumnDeclaration(ctx: GeneratedParser.ColumnDeclarationContext) = translate(ctx) { - val name = symbolToString(ctx.columnName().symbolPrimitive()) + val name = visitAs (ctx.columnName().symbolPrimitive()) val type = visit(ctx.type()) as Type - val constraints = ctx.columnConstraint().map { - visitColumnConstraint(it) + val constraints = ctx.columnConstraint().map { constrCtx -> + val identifier = constrCtx.constraintName()?.let { symbolToString(it.symbolPrimitive()) } + val body = visit(constrCtx.columnConstraintDef()) as Constraint.Definition + constraint(identifier, body) } - tableDefinitionColumn(name, type, constraints) - } - - override fun visitColumnConstraint(ctx: GeneratedParser.ColumnConstraintContext) = translate(ctx) { - val identifier = ctx.columnConstraintName()?.let { symbolToString(it.symbolPrimitive()) } - val body = visit(ctx.columnConstraintDef()) as TableDefinition.Column.Constraint.Body - tableDefinitionColumnConstraint(identifier, body) + tableDefinitionAttribute(name, type, constraints) } override fun visitColConstrNotNull(ctx: GeneratedParser.ColConstrNotNullContext) = translate(ctx) { - tableDefinitionColumnConstraintBodyNotNull() + constraintDefinitionNotNull() } override fun visitColConstrNull(ctx: GeneratedParser.ColConstrNullContext) = translate(ctx) { - tableDefinitionColumnConstraintBodyNullable() + constraintDefinitionNullable() + } + + override fun visitColConstrUnique(ctx: GeneratedParser.ColConstrUniqueContext) = translate(ctx) { + when (ctx.uniqueSpec()) { + is GeneratedParser.PrimaryKeyContext -> constraintDefinitionUnique(null, true) + is GeneratedParser.UniqueContext -> constraintDefinitionUnique(null, false) + else -> throw error(ctx, "Expect UNIQUE or PRIMARY KEY") + } + } + + override fun visitCheckConstraintDef(ctx: GeneratedParser.CheckConstraintDefContext) = translate(ctx) { + val searchCondition = visitAs(ctx.searchCondition()) + constraintDefinitionCheck(searchCondition) + } + + override fun visitUniqueConstraintDef(ctx: GeneratedParser.UniqueConstraintDefContext) = translate(ctx) { + val isPrimaryKey = when (ctx.uniqueSpec()) { + is GeneratedParser.PrimaryKeyContext -> true + is GeneratedParser.UniqueContext -> false + else -> throw error(ctx, "Expect UNIQUE or PRIMARY KEY") + } + val columns = ctx.columnName().map { visitAs (it.symbolPrimitive()) } + constraintDefinitionUnique(columns, isPrimaryKey) + } + + override fun visitTableConstrDeclaration(ctx: GeneratedParser.TableConstrDeclarationContext) = translate(ctx) { + val identifier = ctx.constraintName()?.let { symbolToString(it.symbolPrimitive()) } + val body = visit(ctx.tableConstraintDef()) as Constraint.Definition + constraint(identifier, body) } /** diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt index 3fbb0321a4..ee80ca2a2d 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt @@ -1,16 +1,32 @@ package org.partiql.parser.internal +import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import org.partiql.ast.AstNode +import org.partiql.ast.DdlOp +import org.partiql.ast.Expr import org.partiql.ast.Identifier +import org.partiql.ast.Type +import org.partiql.ast.constraint +import org.partiql.ast.constraintDefinitionCheck +import org.partiql.ast.constraintDefinitionNotNull +import org.partiql.ast.constraintDefinitionUnique +import org.partiql.ast.ddlOpCreateTable +import org.partiql.ast.ddlOpDropTable +import org.partiql.ast.exprBinary +import org.partiql.ast.exprLit +import org.partiql.ast.exprVar import org.partiql.ast.identifierQualified import org.partiql.ast.identifierSymbol -import org.partiql.ast.statementDDLCreateTable -import org.partiql.ast.statementDDLDropTable +import org.partiql.ast.statementDDL +import org.partiql.ast.tableDefinition +import org.partiql.ast.tableDefinitionAttribute +import org.partiql.parser.PartiQLParserException +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.int32Value import java.util.stream.Stream import kotlin.test.assertEquals @@ -21,21 +37,35 @@ class PartiQLParserDDLTests { data class SuccessTestCase( val description: String? = null, val query: String, - val node: AstNode + val expectedOp: DdlOp ) - @ArgumentsSource(TestProvider::class) + data class ErrorTestCase( + val description: String? = null, + val query: String, + ) + + @ArgumentsSource(SuccessTestProvider::class) @ParameterizedTest - fun errorTests(tc: SuccessTestCase) = assertExpression(tc.query, tc.node) + fun successTests(tc: SuccessTestCase) = assertExpression(tc.query, tc.expectedOp) - class TestProvider : ArgumentsProvider { + @ArgumentsSource(ErrorTestProvider::class) + @ParameterizedTest + fun errorTests(tc: ErrorTestCase) = assertIssue(tc.query) + + class SuccessTestProvider : ArgumentsProvider { + @OptIn(PartiQLValueExperimental::class) val createTableTests = listOf( + // + // Qualified Identifier as Table Name + // + SuccessTestCase( "CREATE TABLE with unqualified case insensitive name", "CREATE TABLE foo", - statementDDLCreateTable( + ddlOpCreateTable( identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - null + null, ) ), // Support Case Sensitive identifier as table name @@ -44,7 +74,7 @@ class PartiQLParserDDLTests { SuccessTestCase( "CREATE TABLE with unqualified case sensitive name", "CREATE TABLE \"foo\"", - statementDDLCreateTable( + ddlOpCreateTable( identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), null ) @@ -52,7 +82,7 @@ class PartiQLParserDDLTests { SuccessTestCase( "CREATE TABLE with qualified case insensitive name", "CREATE TABLE myCatalog.mySchema.foo", - statementDDLCreateTable( + ddlOpCreateTable( identifierQualified( identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), listOf( @@ -66,7 +96,7 @@ class PartiQLParserDDLTests { SuccessTestCase( "CREATE TABLE with qualified name with mixed case sensitivity", "CREATE TABLE myCatalog.\"mySchema\".foo", - statementDDLCreateTable( + ddlOpCreateTable( identifierQualified( identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), listOf( @@ -77,27 +107,232 @@ class PartiQLParserDDLTests { null ) ), + + // + // Column Constraints + // + SuccessTestCase( + "CREATE TABLE with Column NOT NULL Constraint", + """ + CREATE TABLE tbl ( + a INT2 NOT NULL + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + listOf( + tableDefinitionAttribute( + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), + Type.Int2(), + listOf(constraint(null, constraintDefinitionNotNull())), + ) + ), + emptyList() + ) + ) + ), + + SuccessTestCase( + "CREATE TABLE with Column Unique Constraint", + """ + CREATE TABLE tbl ( + a INT2 UNIQUE + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + listOf( + tableDefinitionAttribute( + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), + Type.Int2(), + listOf(constraint(null, constraintDefinitionUnique(null, false))), + ) + ), + emptyList() + ), + ) + ), + + SuccessTestCase( + "CREATE TABLE with Column Primary Key Constraint", + """ + CREATE TABLE tbl ( + a INT2 PRIMARY KEY + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + listOf( + tableDefinitionAttribute( + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), + Type.Int2(), + listOf(constraint(null, constraintDefinitionUnique(null, true))), + ) + ), + emptyList() + ), + ) + ), + + SuccessTestCase( + "CREATE TABLE with Column CHECK Constraint", + """ + CREATE TABLE tbl ( + a INT2 CHECK (a > 0) + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + listOf( + tableDefinitionAttribute( + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), + Type.Int2(), + listOf( + constraint( + null, + constraintDefinitionCheck( + exprBinary( + Expr.Binary.Op.GT, + exprVar(identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), Expr.Var.Scope.DEFAULT), + exprLit(int32Value(0)) + ) + ) + ) + ), + ) + ), + emptyList() + ), + ) + ), + + SuccessTestCase( + "CREATE TABLE with Table Unique Constraint", + """ + CREATE TABLE tbl ( + UNIQUE (a, b) + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + emptyList(), + listOf( + constraint( + null, + constraintDefinitionUnique( + listOf( + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("b", Identifier.CaseSensitivity.INSENSITIVE), + ), + false + ) + ) + ) + ), + ) + ), + + SuccessTestCase( + "CREATE TABLE with Table Primary Key Constraint", + """ + CREATE TABLE tbl ( + PRIMARY KEY (a, b) + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + emptyList(), + listOf( + constraint( + null, + constraintDefinitionUnique( + listOf( + identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("b", Identifier.CaseSensitivity.INSENSITIVE), + ), + true + ) + ) + ) + ), + ) + ), + + SuccessTestCase( + "CREATE TABLE with Table CHECK Constraint", + """ + CREATE TABLE tbl ( + CHECK (a > 0) + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + emptyList(), + listOf( + constraint( + null, + constraintDefinitionCheck( + exprBinary( + Expr.Binary.Op.GT, + exprVar(identifierSymbol("a", Identifier.CaseSensitivity.INSENSITIVE), Expr.Var.Scope.DEFAULT), + exprLit(int32Value(0)) + ) + ) + ) + ) + ), + ) + ), + + SuccessTestCase( + "CREATE TABLE with CASE SENSITIVE Identifier as column name", + """ + CREATE TABLE tbl ( + "a" INT2 + ) + """.trimIndent(), + ddlOpCreateTable( + identifierSymbol("tbl", Identifier.CaseSensitivity.INSENSITIVE), + tableDefinition( + listOf( + tableDefinitionAttribute( + identifierSymbol("a", Identifier.CaseSensitivity.SENSITIVE), + Type.Int2(), + emptyList(), + ) + ), + emptyList() + ), + ) + ), ) val dropTableTests = listOf( SuccessTestCase( "DROP TABLE with unqualified case insensitive name", "DROP TABLE foo", - statementDDLDropTable( + ddlOpDropTable( identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), ) ), SuccessTestCase( "DROP TABLE with unqualified case sensitive name", "DROP TABLE \"foo\"", - statementDDLDropTable( + ddlOpDropTable( identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), ) ), SuccessTestCase( "DROP TABLE with qualified case insensitive name", "DROP TABLE myCatalog.mySchema.foo", - statementDDLDropTable( + ddlOpDropTable( identifierQualified( identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), listOf( @@ -110,7 +345,7 @@ class PartiQLParserDDLTests { SuccessTestCase( "DROP TABLE with qualified name with mixed case sensitivity", "DROP TABLE myCatalog.\"mySchema\".foo", - statementDDLDropTable( + ddlOpDropTable( identifierQualified( identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), listOf( @@ -126,9 +361,32 @@ class PartiQLParserDDLTests { (createTableTests + dropTableTests).map { Arguments.of(it) }.stream() } - private fun assertExpression(input: String, expected: AstNode) { + class ErrorTestProvider : ArgumentsProvider { + + val errorTestCases = listOf( + ErrorTestCase( + "Create Table Illegal Check Expression", + """ + CREATE TABLE TBL( + CHECK (SELECT a FROM foo) + ) + """.trimIndent() + ) + ) + override fun provideArguments(p0: ExtensionContext?): Stream = + errorTestCases.map { Arguments.of(it) }.stream() + } + + private fun assertExpression(input: String, expected: DdlOp) { val result = parser.parse(input) val actual = result.root - assertEquals(expected, actual) + assertEquals(statementDDL(expected), actual) + } + + // For now, just assert throw + private fun assertIssue(input: String) { + assertThrows { + parser.parse(input) + } } }