Skip to content

Commit

Permalink
Merge pull request #9536 from planetscale/more-evalengine
Browse files Browse the repository at this point in the history
feat: add IS to evalengine
  • Loading branch information
systay authored Jan 19, 2022
2 parents c11b97a + 88193e1 commit 2cc0d5f
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 31 deletions.
12 changes: 12 additions & 0 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 33 additions & 1 deletion go/vt/vtgate/evalengine/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,36 @@ func convertLogicalExpr(opname string, left, right sqlparser.Expr, lookup Conver
}, nil
}

func convertIsExpr(left sqlparser.Expr, op sqlparser.IsExprOperator, lookup ConverterLookup) (Expr, error) {
expr, err := convertExpr(left, lookup)
if err != nil {
return nil, err
}

var check func(result *EvalResult) bool

switch op {
case sqlparser.IsNullOp:
check = func(er *EvalResult) bool { return er.null() }
case sqlparser.IsNotNullOp:
check = func(er *EvalResult) bool { return !er.null() }
case sqlparser.IsTrueOp:
check = func(er *EvalResult) bool { return er.truthy() == boolTrue }
case sqlparser.IsNotTrueOp:
check = func(er *EvalResult) bool { return er.truthy() != boolTrue }
case sqlparser.IsFalseOp:
check = func(er *EvalResult) bool { return er.truthy() == boolFalse }
case sqlparser.IsNotFalseOp:
check = func(er *EvalResult) bool { return er.truthy() != boolFalse }
}

return &IsExpr{
UnaryExpr: UnaryExpr{expr},
Op: op,
Check: check,
}, nil
}

func getCollation(expr sqlparser.Expr, lookup ConverterLookup) collations.TypedCollation {
collation := collations.TypedCollation{
Coercibility: collations.CoerceCoercible,
Expand Down Expand Up @@ -178,7 +208,7 @@ func convertExpr(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) {
case sqlparser.IntVal:
return NewLiteralIntegralFromBytes(node.Bytes())
case sqlparser.FloatVal:
return NewLiteralRealFromBytes(node.Bytes())
return NewLiteralFloatFromBytes(node.Bytes())
case sqlparser.StrVal:
collation := getCollation(e, lookup)
return NewLiteralString(node.Bytes(), collation), nil
Expand Down Expand Up @@ -278,6 +308,8 @@ func convertExpr(e sqlparser.Expr, lookup ConverterLookup) (Expr, error) {
panic("character set introducers are only supported for literals and arguments")
}
return expr, nil
case *sqlparser.IsExpr:
return convertIsExpr(node.Left, node.Right, lookup)
}
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %T", ErrConvertExprNotSupported, e)
}
72 changes: 64 additions & 8 deletions go/vt/vtgate/evalengine/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ func TestEvaluate(t *testing.T) {
expected sqltypes.Value
}

True := sqltypes.NewInt64(1)
False := sqltypes.NewInt64(0)
tests := []testCase{{
expression: "42",
expected: sqltypes.NewInt64(42),
Expand Down Expand Up @@ -172,40 +174,94 @@ func TestEvaluate(t *testing.T) {
expected: sqltypes.NewFloat64(2.2),
}, {
expression: "42 in (41, 42)",
expected: sqltypes.NewInt64(1),
expected: True,
}, {
expression: "42 in (41, 43)",
expected: sqltypes.NewInt64(0),
expected: False,
}, {
expression: "42 in (null, 41, 43)",
expected: NULL,
}, {
expression: "(1,2) in ((1,2), (2,3))",
expected: sqltypes.NewInt64(1),
expected: True,
}, {
expression: "(1,2) = (1,2)",
expected: sqltypes.NewInt64(1),
expected: True,
}, {
expression: "1 = 'sad'",
expected: sqltypes.NewInt64(0),
expected: False,
}, {
expression: "(1,2) = (1,3)",
expected: sqltypes.NewInt64(0),
expected: False,
}, {
expression: "(1,2) = (1,null)",
expected: NULL,
}, {
expression: "(1,2) in ((4,2), (2,3))",
expected: sqltypes.NewInt64(0),
expected: False,
}, {
expression: "(1,2) in ((1,null), (2,3))",
expected: NULL,
}, {
expression: "(1,(1,2,3),(1,(1,2),4),2) = (1,(1,2,3),(1,(1,2),4),2)",
expected: sqltypes.NewInt64(1),
expected: True,
}, {
expression: "(1,(1,2,3),(1,(1,NULL),4),2) = (1,(1,2,3),(1,(1,2),4),2)",
expected: NULL,
}, {
expression: "null is null",
expected: True,
}, {
expression: "true is null",
expected: False,
}, {
expression: "42 is null",
expected: False,
}, {
expression: "null is not null",
expected: False,
}, {
expression: "42 is not null",
expected: True,
}, {
expression: "true is not null",
expected: True,
}, {
expression: "null is true",
expected: False,
}, {
expression: "42 is true",
expected: True,
}, {
expression: "true is true",
expected: True,
}, {
expression: "null is false",
expected: False,
}, {
expression: "42 is false",
expected: False,
}, {
expression: "false is false",
expected: True,
}, {
expression: "null is not true",
expected: True,
}, {
expression: "42 is not true",
expected: False,
}, {
expression: "true is not true",
expected: False,
}, {
expression: "null is not false",
expected: True,
}, {
expression: "42 is not false",
expected: True,
}, {
expression: "false is not false",
expected: False,
}}

for _, test := range tests {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/eval_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func (er *EvalResult) textual() bool {
return sqltypes.IsText(tt) || sqltypes.IsBinary(tt)
}

func (er *EvalResult) nonzero() boolean {
func (er *EvalResult) truthy() boolean {
switch er.type_ {
case sqltypes.Null:
return boolNULL
Expand Down
7 changes: 5 additions & 2 deletions go/vt/vtgate/evalengine/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ func (env *ExpressionEnv) typecheck(expr Expr) {
for _, subexpr := range expr {
env.typecheck(subexpr)
}

case *IsExpr:
env.ensureCardinality(expr.Inner, 1)
}
}

Expand Down Expand Up @@ -302,8 +305,8 @@ func NewLiteralFloat(val float64) Expr {
return lit
}

// NewLiteralRealFromBytes returns a float literal expression from a slice of bytes
func NewLiteralRealFromBytes(val []byte) (Expr, error) {
// NewLiteralFloatFromBytes returns a float literal expression from a slice of bytes
func NewLiteralFloatFromBytes(val []byte) (Expr, error) {
lit := &Literal{}
if bytes.IndexByte(val, 'e') >= 0 || bytes.IndexByte(val, 'E') >= 0 {
fval, err := strconv.ParseFloat(string(val), 64)
Expand Down
20 changes: 20 additions & 0 deletions go/vt/vtgate/evalengine/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"vitess.io/vitess/go/mysql/collations"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
)

func FormatExpr(expr Expr) string {
Expand Down Expand Up @@ -150,3 +151,22 @@ func (n *NotExpr) format(w *formatter, depth int) {
func (b *LogicalExpr) format(w *formatter, depth int) {
w.formatBinary(b.Left, b.opname, b.Right, depth)
}

func (i *IsExpr) format(w *formatter, depth int) {
w.Indent(depth)
i.Inner.format(w, depth)
switch i.Op {
case sqlparser.IsNullOp:
w.WriteString(" IS NULL")
case sqlparser.IsNotNullOp:
w.WriteString(" IS NOT NULL")
case sqlparser.IsTrueOp:
w.WriteString(" IS TRUE")
case sqlparser.IsNotTrueOp:
w.WriteString(" IS NOT TRUE")
case sqlparser.IsFalseOp:
w.WriteString(" IS FALSE")
case sqlparser.IsNotFalseOp:
w.WriteString(" IS NOT FALSE")
}
}
63 changes: 49 additions & 14 deletions go/vt/vtgate/evalengine/integration/comparison_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"fmt"
"strings"
"testing"

"vitess.io/vitess/go/mysql"
)

func perm(a []string, f func([]string)) {
Expand All @@ -39,6 +41,25 @@ func perm1(a []string, f func([]string), i int) {
}
}

func compareRemoteQuery(t *testing.T, conn *mysql.Conn, query string) {
t.Helper()

local, _, localErr := safeEvaluate(query)
if localErr != nil {
t.Errorf("local failure: %v", localErr)
return
}
remote, remoteErr := conn.ExecuteFetch(query, 1, false)
if remoteErr != nil {
t.Errorf("remote failure: %v", remoteErr)
return
}

if local.Value().String() != remote.Rows[0][0].String() {
t.Errorf("mismatch for query %q: local=%v, remote=%v", query, local.Value().String(), remote.Rows[0][0].String())
}
}

func TestAllComparisons(t *testing.T) {
var elems = []string{"NULL", "-1", "0", "1"}
var operators = []string{"=", "!=", "<=>", "<", "<=", ">", ">="}
Expand All @@ -56,22 +77,36 @@ func TestAllComparisons(t *testing.T) {
for i := 0; i < len(tuples); i++ {
for j := 0; j < len(tuples); j++ {
query := fmt.Sprintf("SELECT %s %s %s", tuples[i], op, tuples[j])
local, _, localErr := safeEvaluate(query)
if localErr != nil {
t.Errorf("local failure: %v", localErr)
continue
}
remote, remoteErr := conn.ExecuteFetch(query, 1, false)
if remoteErr != nil {
t.Errorf("remote failure: %v", remoteErr)
continue
}

if local.Value().String() != remote.Rows[0][0].String() {
t.Errorf("mismatch for query %q: local=%v, remote=%v", query, local.Value().String(), remote.Rows[0][0].String())
}
compareRemoteQuery(t, conn, query)
}
}
})
}
}

func TestAllIsStatements(t *testing.T) {
var left = []string{
"NULL", "TRUE", "FALSE",
`1`, `0`, `1.0`, `0.0`, `-1`, `666`,
`"1"`, `"0"`, `"1.0"`, `"0.0"`, `"-1"`, `"666"`,
`"POTATO"`, `""`, `" "`, `" "`,
}
var right = []string{
"NULL",
"NOT NULL",
"TRUE",
"NOT TRUE",
"FALSE",
"NOT FALSE",
}

var conn = mysqlconn(t)
defer conn.Close()

for _, l := range left {
for _, r := range right {
query := fmt.Sprintf("SELECT %s IS %s", l, r)
compareRemoteQuery(t, conn, query)
}
}
}
19 changes: 16 additions & 3 deletions go/vt/vtgate/evalengine/integration/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ type (
dummyCollation collations.ID
)

var rhsOfIs = []string{
"null",
"not null",
"true",
"not true",
"false",
"not false",
}

func (g *gencase) arg(tuple bool) string {
if tuple || g.rand.Intn(g.ratioTuple) == 0 {
var exprs []string
Expand All @@ -63,7 +72,11 @@ func (g *gencase) arg(tuple bool) string {

func (g *gencase) expr() string {
op := g.operators[g.rand.Intn(len(g.operators))]
return fmt.Sprintf("%s %s %s", g.arg(false), op, g.arg(op == "IN" || op == "NOT IN"))
rhs := g.arg(op == "IN" || op == "NOT IN")
if op == "IS" {
rhs = rhsOfIs[g.rand.Intn(len(rhsOfIs))]
}
return fmt.Sprintf("%s %s %s", g.arg(false), op, rhs)
}

func (d dummyCollation) ColumnLookup(_ *sqlparser.ColName) (int, error) {
Expand Down Expand Up @@ -188,10 +201,10 @@ func TestGenerateFuzzCases(t *testing.T) {
ratioSubexpr: 8,
tupleLen: 4,
operators: []string{
"+", "-", "/", "*", "=", "!=", "<=>", "<", "<=", ">", ">=", "IN", "NOT IN", "LIKE", "NOT LIKE",
"+", "-", "/", "*", "=", "!=", "<=>", "<", "<=", ">", ">=", "IN", "NOT IN", "LIKE", "NOT LIKE", "IS",
},
primitives: []string{
"1", "0", "-1", `"foo"`, `"FOO"`, `"fOo"`, "NULL",
"1", "0", "-1", `"foo"`, `"FOO"`, `"fOo"`, "NULL", "12.0",
},
}

Expand Down
Loading

0 comments on commit 2cc0d5f

Please sign in to comment.