Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add operator node to AST and parser #1499

Merged
merged 2 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 172 additions & 119 deletions partiql-ast/api/partiql-ast.api

Large diffs are not rendered by default.

118 changes: 67 additions & 51 deletions partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -336,70 +336,86 @@ private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisi
return aggregates.contains(this)
}

override fun visitExprUnary(node: Expr.Unary, ctx: Ctx) = translate(node) { metas ->
val arg = visitExpr(node.expr, ctx)
when (node.op) {
Expr.Unary.Op.NOT -> not(arg, metas)
Expr.Unary.Op.POS -> {
when {
arg !is PartiqlAst.Expr.Lit -> pos(arg)
arg.value is IntElement -> arg
arg.value is FloatElement -> arg
arg.value is DecimalElement -> arg
else -> pos(arg)
override fun visitExprOperator(node: Expr.Operator, ctx: Ctx) = translate(node) { metas ->
val lhs = node.lhs?.let { visitExpr(it, ctx) }
val rhs = visitExpr(node.rhs, ctx)
if (lhs == null) {
when (node.symbol) {
"+" -> {
when {
rhs !is PartiqlAst.Expr.Lit -> pos(rhs)
rhs.value is IntElement -> rhs
rhs.value is FloatElement -> rhs
rhs.value is DecimalElement -> rhs
else -> pos(rhs)
}
}
}
Expr.Unary.Op.NEG -> {
when {
arg !is PartiqlAst.Expr.Lit -> neg(arg, metas)
arg.value is IntElement -> {
val intValue = when (arg.value.integerSize) {
IntElementSize.LONG -> ionInt(-arg.value.longValue)
IntElementSize.BIG_INTEGER -> when (arg.value.bigIntegerValue) {
Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE)
else -> ionInt(arg.value.bigIntegerValue * BigInteger.valueOf(-1L))
"-" -> {
when {
rhs !is PartiqlAst.Expr.Lit -> neg(rhs, metas)
rhs.value is IntElement -> {
val intValue = when (rhs.value.integerSize) {
IntElementSize.LONG -> ionInt(-rhs.value.longValue)
IntElementSize.BIG_INTEGER -> when (rhs.value.bigIntegerValue) {
Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE)
else -> ionInt(rhs.value.bigIntegerValue * BigInteger.valueOf(-1L))
}
}
rhs.copy(
value = intValue.asAnyElement(),
metas = metas,
)
}
arg.copy(
value = intValue.asAnyElement(),
rhs.value is FloatElement -> rhs.copy(
value = ionFloat(-(rhs.value.doubleValue)).asAnyElement(),
metas = metas,
)
rhs.value is DecimalElement -> rhs.copy(
value = ionDecimal(Decimal.valueOf(-(rhs.value.decimalValue))).asAnyElement(),
metas = metas,
)
else -> neg(rhs, metas)
}
arg.value is FloatElement -> arg.copy(
value = ionFloat(-(arg.value.doubleValue)).asAnyElement(),
metas = metas,
)
arg.value is DecimalElement -> arg.copy(
value = ionDecimal(Decimal.valueOf(-(arg.value.decimalValue))).asAnyElement(),
metas = metas,
)
else -> neg(arg, metas)
}
else -> error("unsupported unary expr operator $node")
}
} else {
val operands = listOf(lhs, rhs)
when (node.symbol) {
"+" -> plus(operands, metas)
"-" -> minus(operands, metas)
"*" -> times(operands, metas)
"/" -> divide(operands, metas)
"%" -> modulo(operands, metas)
"||" -> concat(operands, metas)
"=" -> eq(operands, metas)
"<>" -> ne(operands, metas)
"!=" -> ne(operands, metas)
">" -> gt(operands, metas)
">=" -> gte(operands, metas)
"<" -> lt(operands, metas)
"<=" -> lte(operands, metas)
"&" -> bitwiseAnd(operands, metas)
else -> error("unsupported binary expr operator $node")
}
}
}

override fun visitExprBinary(node: Expr.Binary, ctx: Ctx) = translate(node) { metas ->
override fun visitExprAnd(node: Expr.And, ctx: Ctx) = translate(node) { metas ->
val lhs = visitExpr(node.lhs, ctx)
val rhs = visitExpr(node.rhs, ctx)
val operands = listOf(lhs, rhs)
when (node.op) {
Expr.Binary.Op.PLUS -> plus(operands, metas)
Expr.Binary.Op.MINUS -> minus(operands, metas)
Expr.Binary.Op.TIMES -> times(operands, metas)
Expr.Binary.Op.DIVIDE -> divide(operands, metas)
Expr.Binary.Op.MODULO -> modulo(operands, metas)
Expr.Binary.Op.CONCAT -> concat(operands, metas)
Expr.Binary.Op.AND -> and(operands, metas)
Expr.Binary.Op.OR -> or(operands, metas)
Expr.Binary.Op.EQ -> eq(operands, metas)
Expr.Binary.Op.NE -> ne(operands, metas)
Expr.Binary.Op.GT -> gt(operands, metas)
Expr.Binary.Op.GTE -> gte(operands, metas)
Expr.Binary.Op.LT -> lt(operands, metas)
Expr.Binary.Op.LTE -> lte(operands, metas)
Expr.Binary.Op.BITWISE_AND -> bitwiseAnd(operands, metas)
}
and(lhs, rhs)
}

override fun visitExprOr(node: Expr.Or, ctx: Ctx) = translate(node) { metas ->
val lhs = visitExpr(node.lhs, ctx)
val rhs = visitExpr(node.rhs, ctx)
or(lhs, rhs)
}

override fun visitExprNot(node: Expr.Not, ctx: Ctx) = translate(node) { metas ->
val rhs = visitExpr(node.value, ctx)
not(rhs)
}

override fun visitExprPath(node: Expr.Path, ctx: Ctx) = translate(node) { metas ->
Expand Down
57 changes: 30 additions & 27 deletions partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt
Original file line number Diff line number Diff line change
Expand Up @@ -229,44 +229,47 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return head concat r("`$value`")
}

override fun visitExprUnary(node: Expr.Unary, head: SqlBlock): SqlBlock {
val op = when (node.op) {
Expr.Unary.Op.NOT -> "NOT ("
Expr.Unary.Op.POS -> "+("
Expr.Unary.Op.NEG -> "-("
override fun visitExprOperator(node: Expr.Operator, head: SqlBlock): SqlBlock {
val lhs = node.lhs
return if (lhs != null) {
var h = head
h = visitExprWrapped(node.lhs, h)
h = h concat r(" ${node.symbol} ")
h = visitExprWrapped(node.rhs, h)
h
} else {
var h = head
h = h concat r(node.symbol + "(")
h = visitExprWrapped(node.rhs, h)
h = h concat r(")")
return h
}
}

override fun visitExprAnd(node: Expr.And, head: SqlBlock): SqlBlock {
var h = head
h = h concat r(op)
h = visitExprWrapped(node.expr, h)
h = h concat r(")")
h = visitExprWrapped(node.lhs, h)
h = h concat r(" AND ")
h = visitExprWrapped(node.rhs, h)
return h
}

override fun visitExprBinary(node: Expr.Binary, head: SqlBlock): SqlBlock {
val op = when (node.op) {
Expr.Binary.Op.PLUS -> "+"
Expr.Binary.Op.MINUS -> "-"
Expr.Binary.Op.TIMES -> "*"
Expr.Binary.Op.DIVIDE -> "/"
Expr.Binary.Op.MODULO -> "%"
Expr.Binary.Op.CONCAT -> "||"
Expr.Binary.Op.AND -> "AND"
Expr.Binary.Op.OR -> "OR"
Expr.Binary.Op.EQ -> "="
Expr.Binary.Op.NE -> "<>"
Expr.Binary.Op.GT -> ">"
Expr.Binary.Op.GTE -> ">="
Expr.Binary.Op.LT -> "<"
Expr.Binary.Op.LTE -> "<="
Expr.Binary.Op.BITWISE_AND -> "&"
}
override fun visitExprOr(node: Expr.Or, head: SqlBlock): SqlBlock {
var h = head
h = visitExprWrapped(node.lhs, h)
h = h concat r(" $op ")
h = h concat r(" OR ")
h = visitExprWrapped(node.rhs, h)
return h
}

override fun visitExprNot(node: Expr.Not, head: SqlBlock): SqlBlock {
var h = head
h = h concat r("NOT (")
h = visitExprWrapped(node.value, h)
h = h concat r(")")
return h
}

override fun visitExprVar(node: Expr.Var, head: SqlBlock): SqlBlock {
var h = head
// Prepend @
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,44 +255,47 @@ internal abstract class InternalSqlDialect : AstBaseVisitor<InternalSqlBlock, In
return tail concat "`$value`"
}

override fun visitExprUnary(node: Expr.Unary, tail: InternalSqlBlock): InternalSqlBlock {
val op = when (node.op) {
Expr.Unary.Op.NOT -> "NOT ("
Expr.Unary.Op.POS -> "+("
Expr.Unary.Op.NEG -> "-("
override fun visitExprOperator(node: Expr.Operator, tail: InternalSqlBlock): InternalSqlBlock {
val lhs = node.lhs
return if (lhs != null) {
var t = tail
t = visitExprWrapped(node.lhs, t)
t = t concat " ${node.symbol} "
t = visitExprWrapped(node.rhs, t)
t
} else {
var t = tail
t = t concat node.symbol + "("
t = visitExprWrapped(node.rhs, t)
t = t concat ")"
return t
}
}

override fun visitExprAnd(node: Expr.And, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = t concat op
t = visitExprWrapped(node.expr, t)
t = t concat ")"
t = visitExprWrapped(node.lhs, t)
t = t concat " AND "
t = visitExprWrapped(node.rhs, t)
return t
}

override fun visitExprBinary(node: Expr.Binary, tail: InternalSqlBlock): InternalSqlBlock {
val op = when (node.op) {
Expr.Binary.Op.PLUS -> "+"
Expr.Binary.Op.MINUS -> "-"
Expr.Binary.Op.TIMES -> "*"
Expr.Binary.Op.DIVIDE -> "/"
Expr.Binary.Op.MODULO -> "%"
Expr.Binary.Op.CONCAT -> "||"
Expr.Binary.Op.AND -> "AND"
Expr.Binary.Op.OR -> "OR"
Expr.Binary.Op.EQ -> "="
Expr.Binary.Op.NE -> "<>"
Expr.Binary.Op.GT -> ">"
Expr.Binary.Op.GTE -> ">="
Expr.Binary.Op.LT -> "<"
Expr.Binary.Op.LTE -> "<="
Expr.Binary.Op.BITWISE_AND -> "&"
}
override fun visitExprOr(node: Expr.Or, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = visitExprWrapped(node.lhs, t)
t = t concat " $op "
t = t concat " OR "
t = visitExprWrapped(node.rhs, t)
return t
}

override fun visitExprNot(node: Expr.Not, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = t concat "NOT ("
t = visitExprWrapped(node.value, t)
t = t concat ")"
return t
}

override fun visitExprVar(node: Expr.Var, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
// Prepend @
Expand Down
29 changes: 18 additions & 11 deletions partiql-ast/src/main/resources/partiql_ast.ion
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,26 @@ expr::[
index: int,
},

// Unary Operators
unary::{
op: [ NOT, POS, NEG ],
expr: expr,
// Operator expr node
operator::{
symbol: string,
lhs: optional::expr,
rhs: expr
},

// Binary Operators
binary::{
op: [
PLUS, MINUS, TIMES, DIVIDE, MODULO, CONCAT, BITWISE_AND,
AND, OR,
EQ, NE, GT, GTE, LT, LTE,
],
// SQL special form `NOT`
not::{
value: expr,
},

// SQL special form `AND`
and::{
lhs: expr,
rhs: expr,
},

// SQL special form `OR`
or::{
lhs: expr,
rhs: expr,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,27 +286,26 @@ class ToLegacyAstTest {
@JvmStatic
fun operators() = listOf(
expect("(not (lit null))") {
exprUnary {
op = Expr.Unary.Op.NOT
expr = NULL
exprNot {
value = NULL
}
},
expect("(pos (lit null))") {
exprUnary {
op = Expr.Unary.Op.POS
expr = NULL
exprOperator {
symbol = "+"
rhs = NULL
}
},
expect("(neg (lit null))") {
exprUnary {
op = Expr.Unary.Op.NEG
expr = NULL
exprOperator {
symbol = "-"
rhs = NULL
}
},
// we don't really need to test _all_ binary operators
expect("(plus (lit null) (lit null))") {
exprBinary {
op = Expr.Binary.Op.PLUS
exprOperator {
symbol = "+"
lhs = NULL
rhs = NULL
}
Expand Down
Loading
Loading