Skip to content

Commit

Permalink
evalengine: Implement NOT and logical operations
Browse files Browse the repository at this point in the history
This adds the NOT and generic logical operations to the compiler and
fixes a number of existing bugs in the evalengine. Specifically NOT is
currently broken as we don't translate it properly.

Some main issues are that we need to ensure lazy evaluation for the
logical operations but also needing it for arithmetic as well.

All cases where we've been pushing the boolean singleton value need to
be fixed as well in the compiler, because we inline update things in
arithmetic operations and we'd update the singleton value before.

It also needs to split parsing into JSON from using an argument as a
partial JSON value. This specifically is needed for CAST() with an input
string where it should parse it. Additionally, convering a JSON boolean
to numeric needs to create a floating point value, not an integer one.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Mar 19, 2023
1 parent 00ccf1e commit 7b83dbb
Show file tree
Hide file tree
Showing 21 changed files with 581 additions and 106 deletions.
7 changes: 7 additions & 0 deletions go/vt/vtgate/evalengine/arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal {
return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d)
}

func (a *Arena) newEvalBool(b bool) *evalInt64 {
if b {
return a.newEvalInt64(1)
}
return a.newEvalInt64(0)
}

func (a *Arena) newEvalInt64(i int64) *evalInt64 {
if cap(a.aInt64) > len(a.aInt64) {
a.aInt64 = a.aInt64[:len(a.aInt64)+1]
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) {
case *InExpr:
return c.compileIn(expr)

case *NotExpr:
return c.compileNot(expr)

case *LogicalExpr:
return c.compileLogical(expr)

case callable:
return c.compileFn(expr)

Expand Down
43 changes: 28 additions & 15 deletions go/vt/vtgate/evalengine/compiler_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) {
if err != nil {
return ctype{}, err
}
skip1 := c.compileNullCheck1(lt)

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

swap := false
skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)
lt, rt, swap = c.compileNumericPriority(lt, rt)
Expand Down Expand Up @@ -144,7 +145,8 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) {
sumtype = sqltypes.Float64
}

c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1)
c.asm.jumpDestination(skip2)
return ctype{Type: sumtype, Col: collationNumeric}, nil
}

Expand All @@ -153,13 +155,14 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) {
if err != nil {
return ctype{}, err
}
skip1 := c.compileNullCheck1(lt)

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)

Expand Down Expand Up @@ -221,7 +224,8 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) {
panic("did not compile?")
}

c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1)
c.asm.jumpDestination(skip2)
return ctype{Type: subtype, Col: collationNumeric}, nil
}

Expand All @@ -230,14 +234,15 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) {
if err != nil {
return ctype{}, err
}
skip1 := c.compileNullCheck1(lt)

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

swap := false
skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)
lt, rt, swap = c.compileNumericPriority(lt, rt)
Expand Down Expand Up @@ -274,7 +279,8 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) {
multype = sqltypes.Decimal
}

c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1)
c.asm.jumpDestination(skip2)
return ctype{Type: multype, Col: collationNumeric}, nil
}

Expand All @@ -283,43 +289,47 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) {
if err != nil {
return ctype{}, err
}
skip1 := c.compileNullCheck1(lt)

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}
skip2 := c.compileNullCheck2(lt, rt)

skip := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)

ct := ctype{Col: collationNumeric, Flag: flagNullable}
if lt.Type == sqltypes.Float64 || rt.Type == sqltypes.Float64 {
ct.Type = sqltypes.Float64
c.compileToFloat(lt, 2)
c.compileToFloat(rt, 1)
c.asm.Div_ff()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil
} else {
ct.Type = sqltypes.Decimal
c.compileToDecimal(lt, 2)
c.compileToDecimal(rt, 1)
c.asm.Div_dd()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil
}
c.asm.jumpDestination(skip1)
c.asm.jumpDestination(skip2)
return ct, nil
}

func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) {
lt, err := c.compileExpr(left)
if err != nil {
return ctype{}, err
}
skip1 := c.compileNullCheck1(lt)

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)

Expand Down Expand Up @@ -383,7 +393,8 @@ func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) {
c.asm.IntDiv_di()
}
}
c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1)
c.asm.jumpDestination(skip2)
return ct, nil
}

Expand All @@ -392,13 +403,14 @@ func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) {
if err != nil {
return ctype{}, err
}
skip1 := c.compileNullCheck1(lt)

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)

Expand Down Expand Up @@ -453,6 +465,7 @@ func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) {
c.asm.Mod_ff()
}

c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1)
c.asm.jumpDestination(skip2)
return ct, nil
}
Loading

0 comments on commit 7b83dbb

Please sign in to comment.