Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SqlDialect's defaultReturn return type to be SqlBlock #1357

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ Thank you to all who have contributed!
### Deprecated

### Fixed
- Return type of `partiql-ast`'s `SqlDialect` for `defaultReturn` to be a `SqlBlock` rather than `Nothing`

### Removed

### Security

### Contributors
Thank you to all who have contributed!
- @<your-username>
- @alancai98

## [0.14.2] - 2024-01-25

Expand Down
106 changes: 53 additions & 53 deletions partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
public val PARTIQL = object : SqlDialect() {}
}

override fun defaultReturn(node: AstNode, head: SqlBlock) =
override fun defaultReturn(node: AstNode, head: SqlBlock): SqlBlock =
throw UnsupportedOperationException("Cannot print $node")

// STATEMENTS

override fun visitStatementQuery(node: Statement.Query, head: SqlBlock) = visitExpr(node.expr, head)
override fun visitStatementQuery(node: Statement.Query, head: SqlBlock): SqlBlock = visitExpr(node.expr, head)

// IDENTIFIERS & PATHS

Expand All @@ -55,7 +55,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
* @param node
* @param head
*/
public open fun visitExprWrapped(node: Expr, head: SqlBlock) = when (node) {
public open fun visitExprWrapped(node: Expr, head: SqlBlock): SqlBlock = when (node) {
is Expr.SFW -> {
var h = head
h = h concat "("
Expand All @@ -66,7 +66,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
else -> visitExpr(node, head)
}

override fun visitIdentifierSymbol(node: Identifier.Symbol, head: SqlBlock) = head concat r(node.sql())
override fun visitIdentifierSymbol(node: Identifier.Symbol, head: SqlBlock): SqlBlock = head concat r(node.sql())

override fun visitIdentifierQualified(node: Identifier.Qualified, head: SqlBlock): SqlBlock {
val path = node.steps.fold(node.root.sql()) { p, step -> p + "." + step.sql() }
Expand Down Expand Up @@ -116,93 +116,93 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
}

// cannot write path step outside the context of a path as we don't want it to reflow
override fun visitPathStep(node: Path.Step, head: SqlBlock) = error("path step cannot be written directly")
override fun visitPathStep(node: Path.Step, head: SqlBlock): SqlBlock = error("path step cannot be written directly")

override fun visitPathStepSymbol(node: Path.Step.Symbol, head: SqlBlock) = visitPathStep(node, head)
override fun visitPathStepSymbol(node: Path.Step.Symbol, head: SqlBlock): SqlBlock = visitPathStep(node, head)

override fun visitPathStepIndex(node: Path.Step.Index, head: SqlBlock) = visitPathStep(node, head)
override fun visitPathStepIndex(node: Path.Step.Index, head: SqlBlock): SqlBlock = visitPathStep(node, head)
alancai98 marked this conversation as resolved.
Show resolved Hide resolved

// TYPES

override fun visitTypeNullType(node: Type.NullType, head: SqlBlock) = head concat r("NULL")
override fun visitTypeNullType(node: Type.NullType, head: SqlBlock): SqlBlock = head concat r("NULL")

override fun visitTypeMissing(node: Type.Missing, head: SqlBlock) = head concat r("MISSING")
override fun visitTypeMissing(node: Type.Missing, head: SqlBlock): SqlBlock = head concat r("MISSING")

override fun visitTypeBool(node: Type.Bool, head: SqlBlock) = head concat r("BOOL")
override fun visitTypeBool(node: Type.Bool, head: SqlBlock): SqlBlock = head concat r("BOOL")

override fun visitTypeTinyint(node: Type.Tinyint, head: SqlBlock) = head concat r("TINYINT")
override fun visitTypeTinyint(node: Type.Tinyint, head: SqlBlock): SqlBlock = head concat r("TINYINT")

override fun visitTypeSmallint(node: Type.Smallint, head: SqlBlock) = head concat r("SMALLINT")
override fun visitTypeSmallint(node: Type.Smallint, head: SqlBlock): SqlBlock = head concat r("SMALLINT")

override fun visitTypeInt2(node: Type.Int2, head: SqlBlock) = head concat r("INT2")
override fun visitTypeInt2(node: Type.Int2, head: SqlBlock): SqlBlock = head concat r("INT2")

override fun visitTypeInt4(node: Type.Int4, head: SqlBlock) = head concat r("INT4")
override fun visitTypeInt4(node: Type.Int4, head: SqlBlock): SqlBlock = head concat r("INT4")

override fun visitTypeBigint(node: Type.Bigint, head: SqlBlock) = head concat r("BIGINT")
override fun visitTypeBigint(node: Type.Bigint, head: SqlBlock): SqlBlock = head concat r("BIGINT")

override fun visitTypeInt8(node: Type.Int8, head: SqlBlock) = head concat r("INT8")
override fun visitTypeInt8(node: Type.Int8, head: SqlBlock): SqlBlock = head concat r("INT8")

override fun visitTypeInt(node: Type.Int, head: SqlBlock) = head concat r("INT")
override fun visitTypeInt(node: Type.Int, head: SqlBlock): SqlBlock = head concat r("INT")

override fun visitTypeReal(node: Type.Real, head: SqlBlock) = head concat r("REAL")
override fun visitTypeReal(node: Type.Real, head: SqlBlock): SqlBlock = head concat r("REAL")

override fun visitTypeFloat32(node: Type.Float32, head: SqlBlock) = head concat r("FLOAT32")
override fun visitTypeFloat32(node: Type.Float32, head: SqlBlock): SqlBlock = head concat r("FLOAT32")

override fun visitTypeFloat64(node: Type.Float64, head: SqlBlock) = head concat r("DOUBLE PRECISION")
override fun visitTypeFloat64(node: Type.Float64, head: SqlBlock): SqlBlock = head concat r("DOUBLE PRECISION")

override fun visitTypeDecimal(node: Type.Decimal, head: SqlBlock) =
override fun visitTypeDecimal(node: Type.Decimal, head: SqlBlock): SqlBlock =
head concat type("DECIMAL", node.precision, node.scale)

override fun visitTypeNumeric(node: Type.Numeric, head: SqlBlock) =
override fun visitTypeNumeric(node: Type.Numeric, head: SqlBlock): SqlBlock =
head concat type("NUMERIC", node.precision, node.scale)

override fun visitTypeChar(node: Type.Char, head: SqlBlock) = head concat type("CHAR", node.length)
override fun visitTypeChar(node: Type.Char, head: SqlBlock): SqlBlock = head concat type("CHAR", node.length)

override fun visitTypeVarchar(node: Type.Varchar, head: SqlBlock) = head concat type("VARCHAR", node.length)
override fun visitTypeVarchar(node: Type.Varchar, head: SqlBlock): SqlBlock = head concat type("VARCHAR", node.length)

override fun visitTypeString(node: Type.String, head: SqlBlock) = head concat r("STRING")
override fun visitTypeString(node: Type.String, head: SqlBlock): SqlBlock = head concat r("STRING")

override fun visitTypeSymbol(node: Type.Symbol, head: SqlBlock) = head concat r("SYMBOL")
override fun visitTypeSymbol(node: Type.Symbol, head: SqlBlock): SqlBlock = head concat r("SYMBOL")

override fun visitTypeBit(node: Type.Bit, head: SqlBlock) = head concat type("BIT", node.length)
override fun visitTypeBit(node: Type.Bit, head: SqlBlock): SqlBlock = head concat type("BIT", node.length)

override fun visitTypeBitVarying(node: Type.BitVarying, head: SqlBlock) = head concat type("BINARY", node.length)
override fun visitTypeBitVarying(node: Type.BitVarying, head: SqlBlock): SqlBlock = head concat type("BINARY", node.length)

override fun visitTypeByteString(node: Type.ByteString, head: SqlBlock) = head concat type("BYTE", node.length)
override fun visitTypeByteString(node: Type.ByteString, head: SqlBlock): SqlBlock = head concat type("BYTE", node.length)

override fun visitTypeBlob(node: Type.Blob, head: SqlBlock) = head concat type("BLOB", node.length)
override fun visitTypeBlob(node: Type.Blob, head: SqlBlock): SqlBlock = head concat type("BLOB", node.length)

override fun visitTypeClob(node: Type.Clob, head: SqlBlock) = head concat type("CLOB", node.length)
override fun visitTypeClob(node: Type.Clob, head: SqlBlock): SqlBlock = head concat type("CLOB", node.length)

override fun visitTypeBag(node: Type.Bag, head: SqlBlock) = head concat r("BAG")
override fun visitTypeBag(node: Type.Bag, head: SqlBlock): SqlBlock = head concat r("BAG")

override fun visitTypeList(node: Type.List, head: SqlBlock) = head concat r("LIST")
override fun visitTypeList(node: Type.List, head: SqlBlock): SqlBlock = head concat r("LIST")

override fun visitTypeSexp(node: Type.Sexp, head: SqlBlock) = head concat r("SEXP")
override fun visitTypeSexp(node: Type.Sexp, head: SqlBlock): SqlBlock = head concat r("SEXP")

override fun visitTypeTuple(node: Type.Tuple, head: SqlBlock) = head concat r("TUPLE")
override fun visitTypeTuple(node: Type.Tuple, head: SqlBlock): SqlBlock = head concat r("TUPLE")

override fun visitTypeStruct(node: Type.Struct, head: SqlBlock) = head concat r("STRUCT")
override fun visitTypeStruct(node: Type.Struct, head: SqlBlock): SqlBlock = head concat r("STRUCT")

override fun visitTypeAny(node: Type.Any, head: SqlBlock) = head concat r("ANY")
override fun visitTypeAny(node: Type.Any, head: SqlBlock): SqlBlock = head concat r("ANY")

override fun visitTypeDate(node: Type.Date, head: SqlBlock) = head concat r("DATE")
override fun visitTypeDate(node: Type.Date, head: SqlBlock): SqlBlock = head concat r("DATE")

override fun visitTypeTime(node: Type.Time, head: SqlBlock): SqlBlock = head concat type("TIME", node.precision)

override fun visitTypeTimeWithTz(node: Type.TimeWithTz, head: SqlBlock) =
override fun visitTypeTimeWithTz(node: Type.TimeWithTz, head: SqlBlock): SqlBlock =
head concat type("TIME WITH TIMEZONE", node.precision, gap = true)

override fun visitTypeTimestamp(node: Type.Timestamp, head: SqlBlock) =
override fun visitTypeTimestamp(node: Type.Timestamp, head: SqlBlock): SqlBlock =
head concat type("TIMESTAMP", node.precision)

override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, head: SqlBlock) =
override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, head: SqlBlock): SqlBlock =
head concat type("TIMESTAMP WITH TIMEZONE", node.precision, gap = true)

override fun visitTypeInterval(node: Type.Interval, head: SqlBlock) = head concat type("INTERVAL", node.precision)
override fun visitTypeInterval(node: Type.Interval, head: SqlBlock): SqlBlock = head concat type("INTERVAL", node.precision)

// unsupported
override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = defaultReturn(node, head)
override fun visitTypeCustom(node: Type.Custom, head: SqlBlock): SqlBlock = defaultReturn(node, head)

// Expressions

Expand Down Expand Up @@ -274,7 +274,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return h
}

override fun visitExprSessionAttribute(node: Expr.SessionAttribute, head: SqlBlock) =
override fun visitExprSessionAttribute(node: Expr.SessionAttribute, head: SqlBlock): SqlBlock =
head concat r(node.attribute.name)

override fun visitExprPath(node: Expr.Path, head: SqlBlock): SqlBlock {
Expand All @@ -283,7 +283,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return h
}

override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, head: SqlBlock) =
override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, head: SqlBlock): SqlBlock =
head concat r(".${node.symbol.sql()}")

override fun visitExprPathStepIndex(node: Expr.Path.Step.Index, head: SqlBlock): SqlBlock {
Expand All @@ -296,9 +296,9 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return h
}

override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, head: SqlBlock) = head concat r("[*]")
override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, head: SqlBlock): SqlBlock = head concat r("[*]")

override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, head: SqlBlock) = head concat r(".*")
override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, head: SqlBlock): SqlBlock = head concat r(".*")

override fun visitExprCall(node: Expr.Call, head: SqlBlock): SqlBlock {
var h = head
Expand All @@ -320,11 +320,11 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return h
}

override fun visitExprParameter(node: Expr.Parameter, head: SqlBlock) = head concat r("?")
override fun visitExprParameter(node: Expr.Parameter, head: SqlBlock): SqlBlock = head concat r("?")

override fun visitExprValues(node: Expr.Values, head: SqlBlock) = head concat list("VALUES (") { node.rows }
override fun visitExprValues(node: Expr.Values, head: SqlBlock): SqlBlock = head concat list("VALUES (") { node.rows }

override fun visitExprValuesRow(node: Expr.Values.Row, head: SqlBlock) = head concat list { node.items }
override fun visitExprValuesRow(node: Expr.Values.Row, head: SqlBlock): SqlBlock = head concat list { node.items }

override fun visitExprCollection(node: Expr.Collection, head: SqlBlock): SqlBlock {
val (start, end) = when (node.type) {
Expand All @@ -337,7 +337,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return head concat list(start, end) { node.values }
}

override fun visitExprStruct(node: Expr.Struct, head: SqlBlock) = head concat list("{", "}") { node.fields }
override fun visitExprStruct(node: Expr.Struct, head: SqlBlock): SqlBlock = head concat list("{", "}") { node.fields }

override fun visitExprStructField(node: Expr.Struct.Field, head: SqlBlock): SqlBlock {
var h = head
Expand Down Expand Up @@ -698,7 +698,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {

// LET

override fun visitLet(node: Let, head: SqlBlock) = head concat list("LET ", "") { node.bindings }
override fun visitLet(node: Let, head: SqlBlock): SqlBlock = head concat list("LET ", "") { node.bindings }

override fun visitLetBinding(node: Let.Binding, head: SqlBlock): SqlBlock {
var h = head
Expand Down Expand Up @@ -750,7 +750,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {

// ORDER BY

override fun visitOrderBy(node: OrderBy, head: SqlBlock) = head concat list("ORDER BY ", "") { node.sorts }
override fun visitOrderBy(node: OrderBy, head: SqlBlock): SqlBlock = head concat list("ORDER BY ", "") { node.sorts }

override fun visitSort(node: Sort, head: SqlBlock): SqlBlock {
var h = head
Expand Down
Loading