From 00ccf1e471dfe4392c33279c3c57fcb1a26f1645 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Fri, 17 Mar 2023 16:43:31 +0100 Subject: [PATCH 1/4] evalengine: Implement integer division and modulo Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/arithmetic.go | 489 ++++++++++++++---- .../vtgate/evalengine/compiler_arithmetic.go | 161 +++++- go/vt/vtgate/evalengine/compiler_asm.go | 197 +++++++ go/vt/vtgate/evalengine/expr_arithmetic.go | 39 +- .../evalengine/internal/decimal/decimal.go | 12 +- .../internal/decimal/decimal_test.go | 6 +- go/vt/vtgate/evalengine/testcases/cases.go | 2 +- go/vt/vtgate/evalengine/translate.go | 4 + 8 files changed, 786 insertions(+), 124 deletions(-) diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 5b6c0ee538c..89ecb6fd22f 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -26,12 +26,17 @@ import ( "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" ) func dataOutOfRangeError[N1, N2 constraints.Integer | constraints.Float](v1 N1, v2 N2, typ, sign string) error { return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2) } +func dataOutOfRangeErrorDecimal(v1 decimal.Decimal, v2 decimal.Decimal, typ, sign string) error { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1.String(), sign, v2.String()) +} + func addNumericWithError(left, right eval) (eval, error) { v1, v2 := makeNumericAndPrioritize(left, right) switch v1 := v1.(type) { @@ -127,6 +132,108 @@ func divideNumericWithError(left, right eval, precise bool) (eval, error) { return mathDiv_xx(v1, v2, divPrecisionIncrement) } +func integerDivideNumericWithError(left, right eval, precise bool) (eval, error) { + v1 := evalToNumeric(left) + v2 := evalToNumeric(right) + switch v1 := v1.(type) { + case *evalInt64: + switch v2 := v2.(type) { + case *evalInt64: + return mathIntDiv_ii(v1, v2) + case *evalUint64: + return mathIntDiv_iu(v1, v2) + case *evalFloat: + return mathIntDiv_di(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + case *evalDecimal: + return mathIntDiv_di(v1.toDecimal(0, 0), v2) + } + case *evalUint64: + switch v2 := v2.(type) { + case *evalInt64: + return mathIntDiv_ui(v1, v2) + case *evalUint64: + return mathIntDiv_uu(v1, v2) + case *evalFloat: + return mathIntDiv_du(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + case *evalDecimal: + return mathIntDiv_du(v1.toDecimal(0, 0), v2) + } + case *evalFloat: + switch v2 := v2.(type) { + case *evalUint64: + return mathIntDiv_du(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + default: + return mathIntDiv_di(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + } + case *evalDecimal: + switch v2 := v2.(type) { + case *evalUint64: + return mathIntDiv_du(v1, v2.toDecimal(0, 0)) + default: + return mathIntDiv_di(v1, v2.toDecimal(0, 0)) + } + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", evalToSQLValue(v1), evalToSQLValue(v2)) +} + +func modNumericWithError(left, right eval, precise bool) (eval, error) { + v1 := evalToNumeric(left) + v2 := evalToNumeric(right) + + switch v1 := v1.(type) { + case *evalInt64: + switch v2 := v2.(type) { + case *evalInt64: + return mathMod_ii(v1, v2) + case *evalUint64: + return mathMod_iu(v1, v2) + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + case *evalDecimal: + return mathMod_dd(v1.toDecimal(0, 0), v2) + } + case *evalUint64: + switch v2 := v2.(type) { + case *evalInt64: + return mathMod_ui(v1, v2) + case *evalUint64: + return mathMod_uu(v1, v2) + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + case *evalDecimal: + return mathMod_dd(v1.toDecimal(0, 0), v2) + } + case *evalDecimal: + switch v2 := v2.(type) { + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + default: + return mathMod_dd(v1, v2.toDecimal(0, 0)) + } + case *evalFloat: + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1, v2f) + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", evalToSQLValue(v1), evalToSQLValue(v2)) +} + // makeNumericAndPrioritize reorders the input parameters // to be Float64, Decimal, Uint64, Int64. func makeNumericAndPrioritize(left, right eval) (evalNumeric, evalNumeric) { @@ -162,28 +269,68 @@ func mathAdd_ii0(v1, v2 int64) (int64, error) { return result, nil } -func mathSub_ii(v1, v2 int64) (*evalInt64, error) { - result, err := mathSub_ii0(v1, v2) - return newEvalInt64(result), err +func mathAdd_ui(v1 uint64, v2 int64) (*evalUint64, error) { + result, err := mathAdd_ui0(v1, v2) + return newEvalUint64(result), err } -func mathSub_ii0(v1, v2 int64) (int64, error) { - result := v1 - v2 - if (result < v1) != (v2 > 0) { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "-") +func mathAdd_ui0(v1 uint64, v2 int64) (uint64, error) { + result := v1 + uint64(v2) + if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") } return result, nil } -func mathMul_ii(v1, v2 int64) (*evalInt64, error) { - result, err := mathMul_ii0(v1, v2) +func mathAdd_uu(v1, v2 uint64) (*evalUint64, error) { + result, err := mathAdd_uu0(v1, v2) + return newEvalUint64(result), err +} + +func mathAdd_uu0(v1, v2 uint64) (uint64, error) { + result := v1 + v2 + if result < v1 || result < v2 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") + } + return result, nil +} + +var errDecimalOutOfRange = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") + +func mathAdd_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathAdd_ff(v1, v2f.f), nil +} + +func mathAdd_ff(v1, v2 float64) *evalFloat { + return newEvalFloat(v1 + v2) +} + +func mathAdd_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { + return mathAdd_dd(v1, v2.toDecimal(0, 0)) +} + +func mathAdd_dd(v1, v2 *evalDecimal) *evalDecimal { + return newEvalDecimalWithPrec(v1.dec.Add(v2.dec), maxprec(v1.length, v2.length)) +} + +func mathAdd_dd0(v1, v2 *evalDecimal) { + v1.dec = v1.dec.Add(v2.dec) + v1.length = maxprec(v1.length, v2.length) +} + +func mathSub_ii(v1, v2 int64) (*evalInt64, error) { + result, err := mathSub_ii0(v1, v2) return newEvalInt64(result), err } -func mathMul_ii0(v1, v2 int64) (int64, error) { - result := v1 * v2 - if v1 != 0 && result/v1 != v2 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "*") +func mathSub_ii0(v1, v2 int64) (int64, error) { + result := v1 - v2 + if (result < v1) != (v2 > 0) { + return 0, dataOutOfRangeError(v1, v2, "BIGINT", "-") } return result, nil } @@ -200,19 +347,6 @@ func mathSub_iu0(v1 int64, v2 uint64) (uint64, error) { return mathSub_uu0(uint64(v1), v2) } -func mathAdd_ui(v1 uint64, v2 int64) (*evalUint64, error) { - result, err := mathAdd_ui0(v1, v2) - return newEvalUint64(result), err -} - -func mathAdd_ui0(v1 uint64, v2 int64) (uint64, error) { - result := v1 + uint64(v2) - if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") - } - return result, nil -} - func mathSub_ui(v1 uint64, v2 int64) (*evalUint64, error) { result, err := mathSub_ui0(v1, v2) return newEvalUint64(result), err @@ -229,45 +363,82 @@ func mathSub_ui0(v1 uint64, v2 int64) (uint64, error) { return mathSub_uu0(v1, uint64(v2)) } -func mathMul_ui(v1 uint64, v2 int64) (*evalUint64, error) { - result, err := mathMul_ui0(v1, v2) +func mathSub_uu(v1, v2 uint64) (*evalUint64, error) { + result, err := mathSub_uu0(v1, v2) return newEvalUint64(result), err } -func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { - if v1 == 0 || v2 == 0 { - return 0, nil +func mathSub_uu0(v1, v2 uint64) (uint64, error) { + result := v1 - v2 + if v2 > v1 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } - if v2 < 0 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") + return result, nil +} + +func mathSub_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange } - return mathMul_uu0(v1, uint64(v2)) + return mathSub_ff(v1, v2f.f), nil } -func mathAdd_uu(v1, v2 uint64) (*evalUint64, error) { - result, err := mathAdd_uu0(v1, v2) - return newEvalUint64(result), err +func mathSub_xf(v1 evalNumeric, v2 float64) (*evalFloat, error) { + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathSub_ff(v1f.f, v2), nil } -func mathAdd_uu0(v1, v2 uint64) (uint64, error) { - result := v1 + v2 - if result < v1 || result < v2 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") +func mathSub_ff(v1, v2 float64) *evalFloat { + return newEvalFloat(v1 - v2) +} + +func mathSub_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { + return mathSub_dd(v1, v2.toDecimal(0, 0)) +} + +func mathSub_xd(v1 evalNumeric, v2 *evalDecimal) *evalDecimal { + return mathSub_dd(v1.toDecimal(0, 0), v2) +} + +func mathSub_dd(v1, v2 *evalDecimal) *evalDecimal { + return newEvalDecimalWithPrec(v1.dec.Sub(v2.dec), maxprec(v1.length, v2.length)) +} + +func mathSub_dd0(v1, v2 *evalDecimal) { + v1.dec = v1.dec.Sub(v2.dec) + v1.length = maxprec(v1.length, v2.length) +} + +func mathMul_ii(v1, v2 int64) (*evalInt64, error) { + result, err := mathMul_ii0(v1, v2) + return newEvalInt64(result), err +} + +func mathMul_ii0(v1, v2 int64) (int64, error) { + result := v1 * v2 + if v1 != 0 && result/v1 != v2 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT", "*") } return result, nil } -func mathSub_uu(v1, v2 uint64) (*evalUint64, error) { - result, err := mathSub_uu0(v1, v2) +func mathMul_ui(v1 uint64, v2 int64) (*evalUint64, error) { + result, err := mathMul_ui0(v1, v2) return newEvalUint64(result), err } -func mathSub_uu0(v1, v2 uint64) (uint64, error) { - result := v1 - v2 - if v2 > v1 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") +func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { + if v1 == 0 || v2 == 0 { + return 0, nil } - return result, nil + if v2 < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") + } + return mathMul_uu0(v1, uint64(v2)) } func mathMul_uu(v1, v2 uint64) (*evalUint64, error) { @@ -286,28 +457,6 @@ func mathMul_uu0(v1, v2 uint64) (uint64, error) { return result, nil } -var errDecimalOutOfRange = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") - -func mathAdd_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { - v2f, ok := v2.toFloat() - if !ok { - return nil, errDecimalOutOfRange - } - return mathAdd_ff(v1, v2f.f), nil -} - -func mathAdd_ff(v1, v2 float64) *evalFloat { - return newEvalFloat(v1 + v2) -} - -func mathSub_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { - v2f, ok := v2.toFloat() - if !ok { - return nil, errDecimalOutOfRange - } - return mathSub_ff(v1, v2f.f), nil -} - func mathMul_fx(v1 float64, v2 evalNumeric) (eval, error) { v2f, ok := v2.toFloat() if !ok { @@ -327,36 +476,6 @@ func maxprec(a, b int32) int32 { return b } -func mathAdd_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { - return mathAdd_dd(v1, v2.toDecimal(0, 0)) -} - -func mathAdd_dd(v1, v2 *evalDecimal) *evalDecimal { - return newEvalDecimalWithPrec(v1.dec.Add(v2.dec), maxprec(v1.length, v2.length)) -} - -func mathAdd_dd0(v1, v2 *evalDecimal) { - v1.dec = v1.dec.Add(v2.dec) - v1.length = maxprec(v1.length, v2.length) -} - -func mathSub_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { - return mathSub_dd(v1, v2.toDecimal(0, 0)) -} - -func mathSub_xd(v1 evalNumeric, v2 *evalDecimal) *evalDecimal { - return mathSub_dd(v1.toDecimal(0, 0), v2) -} - -func mathSub_dd(v1, v2 *evalDecimal) *evalDecimal { - return newEvalDecimalWithPrec(v1.dec.Sub(v2.dec), maxprec(v1.length, v2.length)) -} - -func mathSub_dd0(v1, v2 *evalDecimal) { - v1.dec = v1.dec.Sub(v2.dec) - v1.length = maxprec(v1.length, v2.length) -} - func mathMul_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { return mathMul_dd(v1, v2.toDecimal(0, 0)) } @@ -413,16 +532,174 @@ func mathDiv_ff0(v1, v2 float64) (float64, error) { return result, nil } -func mathSub_xf(v1 evalNumeric, v2 float64) (*evalFloat, error) { - v1f, ok := v1.toFloat() +func mathIntDiv_ii(v1, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result := v1.i / v2.i + return newEvalInt64(result), nil +} + +func mathIntDiv_iu(v1 *evalInt64, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + result, err := mathIntDiv_iu0(v1.i, v2.u) + return newEvalUint64(result), err +} + +func mathIntDiv_iu0(v1 int64, v2 uint64) (uint64, error) { + if v1 < 0 { + if v2 >= math.MaxInt64 { + // We know here that v2 is always so large the result + // must be 0. + return 0, nil + } + result := v1 / int64(v2) + if result < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + return uint64(result), nil + + } + return uint64(v1) / v2, nil +} + +func mathIntDiv_ui(v1 *evalUint64, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result, err := mathIntDiv_ui0(v1.u, v2.i) + return newEvalUint64(result), err +} + +func mathIntDiv_ui0(v1 uint64, v2 int64) (uint64, error) { + if v2 < 0 { + if v1 >= math.MaxInt64 { + // We know that v1 is always large here and with v2, the result + // must be at least -1 so we can't store this in the available range. + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + // Safe to cast since we know it fits in int64 when we get here. + result := int64(v1) / v2 + if result < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + return uint64(result), nil + } + return v1 / uint64(v2), nil +} + +func mathIntDiv_uu(v1, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalUint64(v1.u / v2.u), nil +} + +func mathIntDiv_di(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + result, err := mathIntDiv_di0(v1, v2) + return newEvalInt64(result), err +} + +func mathIntDiv_di0(v1, v2 *evalDecimal) (int64, error) { + div, _ := v1.dec.QuoRem(v2.dec, 0) + result, ok := div.Int64() if !ok { - return nil, errDecimalOutOfRange + return 0, dataOutOfRangeErrorDecimal(v1.dec, v2.dec, "BIGINT", "DIV") } - return mathSub_ff(v1f.f, v2), nil + return result, nil } -func mathSub_ff(v1, v2 float64) *evalFloat { - return newEvalFloat(v1 - v2) +func mathIntDiv_du(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + result, err := mathIntDiv_du0(v1, v2) + return newEvalUint64(result), err +} + +func mathIntDiv_du0(v1, v2 *evalDecimal) (uint64, error) { + div, _ := v1.dec.QuoRem(v2.dec, 0) + result, ok := div.Uint64() + if !ok { + return 0, dataOutOfRangeErrorDecimal(v1.dec, v2.dec, "BIGINT UNSIGNED", "DIV") + } + return result, nil +} + +func mathMod_ii(v1, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + return newEvalInt64(v1.i % v2.i), nil +} + +func mathMod_iu(v1 *evalInt64, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalInt64(mathMod_iu0(v1.i, v2.u)), nil +} + +func mathMod_iu0(v1 int64, v2 uint64) int64 { + if v1 == math.MinInt64 && v2 == math.MaxInt64+1 { + return 0 + } + if v2 > math.MaxInt64 { + return v1 + } + return v1 % int64(v2) +} + +func mathMod_ui(v1 *evalUint64, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result, err := mathMod_ui0(v1.u, v2.i) + return newEvalUint64(result), err +} + +func mathMod_ui0(v1 uint64, v2 int64) (uint64, error) { + if v2 < 0 { + return v1 % uint64(-v2), nil + } + return v1 % uint64(v2), nil +} + +func mathMod_uu(v1, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalUint64(v1.u % v2.u), nil +} + +func mathMod_ff(v1, v2 *evalFloat) (eval, error) { + if v2.f == 0.0 { + return nil, nil + } + return newEvalFloat(math.Mod(v1.f, v2.f)), nil +} + +func mathMod_dd(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + + dec, prec := mathMod_dd0(v1, v2) + return newEvalDecimalWithPrec(dec, prec), nil +} + +func mathMod_dd0(v1, v2 *evalDecimal) (decimal.Decimal, int32) { + length := v1.length + if v2.length > length { + length = v2.length + } + _, rem := v1.dec.QuoRem(v2.dec, 0) + return rem, length } func parseStringToFloat(str string) float64 { diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go index a81afacc53a..772ca623c50 100644 --- a/go/vt/vtgate/evalengine/compiler_arithmetic.go +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -16,7 +16,11 @@ limitations under the License. package evalengine -import "vitess.io/vitess/go/sqltypes" +import ( + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { arg, err := c.compileExpr(expr.Inner) @@ -64,8 +68,12 @@ func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) { return c.compileArithmeticMul(expr.Left, expr.Right) case *opArithDiv: return c.compileArithmeticDiv(expr.Left, expr.Right) + case *opArithIntDiv: + return c.compileArithmeticIntDiv(expr.Left, expr.Right) + case *opArithMod: + return c.compileArithmeticMod(expr.Left, expr.Right) default: - panic("unexpected arithmetic operator") + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") } } @@ -299,3 +307,152 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil } } + +func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(lt, rt) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.asm.IntDiv_ii() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.IntDiv_iu() + case sqltypes.Float64: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_di() + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.IntDiv_ui() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.IntDiv_uu() + case sqltypes.Float64: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_du() + } + case sqltypes.Float64: + switch rt.Type { + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_di() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + default: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + } + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Decimal: + c.asm.IntDiv_di() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + default: + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + } + } + c.asm.jumpDestination(skip) + return ct, nil +} + +func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(lt, rt) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + switch lt.Type { + case sqltypes.Int64: + ct.Type = sqltypes.Int64 + switch rt.Type { + case sqltypes.Int64: + c.asm.Mod_ii() + case sqltypes.Uint64: + c.asm.Mod_iu() + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + c.asm.Convert_xd(2, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + switch rt.Type { + case sqltypes.Int64: + c.asm.Mod_ui() + case sqltypes.Uint64: + c.asm.Mod_uu() + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + c.asm.Convert_xd(2, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + switch rt.Type { + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + default: + c.asm.Convert_xd(1, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(1) + c.asm.Mod_ff() + } + + c.asm.jumpDestination(skip) + return ct, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 4e484e9996d..771d5e08bfc 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -984,6 +984,203 @@ func (asm *assembler) Div_ff() { }, "DIV FLOAT64(SP-2), FLOAT64(SP-1)") } +func (asm *assembler) IntDiv_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = l.i / r.i + } + vm.sp-- + return 1 + }, "INTDIV INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) IntDiv_iu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + r.u, vm.err = mathIntDiv_iu0(l.i, r.u) + vm.stack[vm.sp-2] = r + } + vm.sp-- + return 1 + }, "INTDIV INT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) IntDiv_ui() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u, vm.err = mathIntDiv_ui0(l.u, r.i) + } + vm.sp-- + return 1 + }, "INTDIV UINT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) IntDiv_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u = l.u / r.u + } + vm.sp-- + return 1 + }, "INTDIV UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) IntDiv_di() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + var res int64 + res, vm.err = mathIntDiv_di0(l, r) + vm.stack[vm.sp-2] = vm.arena.newEvalInt64(res) + } + vm.sp-- + return 1 + }, "INTDIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (asm *assembler) IntDiv_du() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + var res uint64 + res, vm.err = mathIntDiv_du0(l, r) + vm.stack[vm.sp-2] = vm.arena.newEvalUint64(res) + } + vm.sp-- + return 1 + }, "UINTDIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (asm *assembler) Mod_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = l.i % r.i + } + vm.sp-- + return 1 + }, "MOD INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mod_iu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = mathMod_iu0(l.i, r.u) + } + vm.sp-- + return 1 + }, "MOD INT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) Mod_ui() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u, vm.err = mathMod_ui0(l.u, r.i) + } + vm.sp-- + return 1 + }, "MOD UINT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mod_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u = l.u % r.u + } + vm.sp-- + return 1 + }, "MOD UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) Mod_ff() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + if r.f == 0.0 { + vm.stack[vm.sp-2] = nil + } else { + l.f = math.Mod(l.f, r.f) + } + vm.sp-- + return 1 + }, "MOD FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (asm *assembler) Mod_dd() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + l.dec, l.length = mathMod_dd0(l, r) + } + vm.sp-- + return 1 + }, "MOD DECIMAL(SP-2), DECIMAL(SP-1)") +} + func (asm *assembler) Fn_ASCII() { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalBytes) diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index a7c8e12ed4d..63323d0f4bf 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -35,10 +35,12 @@ type ( String() string } - opArithAdd struct{} - opArithSub struct{} - opArithMul struct{} - opArithDiv struct{} + opArithAdd struct{} + opArithSub struct{} + opArithMul struct{} + opArithDiv struct{} + opArithIntDiv struct{} + opArithMod struct{} ) var _ Expr = (*ArithmeticExpr)(nil) @@ -47,6 +49,8 @@ var _ opArith = (*opArithAdd)(nil) var _ opArith = (*opArithSub)(nil) var _ opArith = (*opArithMul)(nil) var _ opArith = (*opArithDiv)(nil) +var _ opArith = (*opArithIntDiv)(nil) +var _ opArith = (*opArithMod)(nil) func (b *ArithmeticExpr) eval(env *ExpressionEnv) (eval, error) { left, right, err := b.arguments(env) @@ -78,9 +82,22 @@ func (b *ArithmeticExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { switch b.Op.(type) { case *opArithDiv: if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 { - return sqltypes.Float64, flags + return sqltypes.Float64, flags | flagNullable } - return sqltypes.Decimal, flags + return sqltypes.Decimal, flags | flagNullable + case *opArithIntDiv: + if t1 == sqltypes.Uint64 || t2 == sqltypes.Uint64 { + return sqltypes.Uint64, flags | flagNullable + } + return sqltypes.Int64, flags | flagNullable + case *opArithMod: + if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 { + return sqltypes.Float64, flags | flagNullable + } + if t1 == sqltypes.Decimal || t2 == sqltypes.Decimal { + return sqltypes.Decimal, flags | flagNullable + } + return t1, flags | flagNullable } switch t1 { @@ -122,6 +139,16 @@ func (op *opArithDiv) eval(left, right eval) (eval, error) { } func (op *opArithDiv) String() string { return "/" } +func (op *opArithIntDiv) eval(left, right eval) (eval, error) { + return integerDivideNumericWithError(left, right, true) +} +func (op *opArithIntDiv) String() string { return "DIV" } + +func (op *opArithMod) eval(left, right eval) (eval, error) { + return modNumericWithError(left, right, true) +} +func (op *opArithMod) String() string { return "DIV" } + func (n *NegateExpr) eval(env *ExpressionEnv) (eval, error) { e, err := n.Inner.eval(env) if err != nil { diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal.go b/go/vt/vtgate/evalengine/internal/decimal/decimal.go index b93a1b9069b..14b708601b0 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal.go @@ -362,7 +362,7 @@ func (d Decimal) Div(d2 Decimal, scaleIncr int32) Decimal { scaleIncr = 0 } scale := myBigDigits(fracLeft+fracRight+scaleIncr) * 9 - q, _ := d.quoRem(d2, scale) + q, _ := d.QuoRem(d2, scale) return q } @@ -372,15 +372,15 @@ func (d Decimal) div(d2 Decimal) Decimal { return d.divRound(d2, int32(divisionPrecision)) } -// quoRem does division with remainder -// d.quoRem(d2,precision) returns quotient q and remainder r such that +// QuoRem does division with remainder +// d.QuoRem(d2,precision) returns quotient q and remainder r such that // // d = d2 * q + r, q an integer multiple of 10^(-precision) // 0 <= r < abs(d2) * 10 ^(-precision) if d>=0 // 0 >= r > -abs(d2) * 10 ^(-precision) if d<0 // // Note that precision<0 is allowed as input. -func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { +func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) { d.ensureInitialized() d2.ensureInitialized() if d2.value.Sign() == 0 { @@ -389,7 +389,7 @@ func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { scale := -precision e := int64(d.exp - d2.exp - scale) if e > math.MaxInt32 || e < math.MinInt32 { - panic("overflow in decimal quoRem") + panic("overflow in decimal QuoRem") } var aa, bb, expo big.Int var scalerest int32 @@ -428,7 +428,7 @@ func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { // Note that precision<0 is allowed as input. func (d Decimal) divRound(d2 Decimal, precision int32) Decimal { // quoRem already checks initialization - q, r := d.quoRem(d2, precision) + q, r := d.QuoRem(d2, precision) // the actual rounding decision is based on comparing r*10^precision and d2/2 // instead compare 2 r 10 ^precision and d2 diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go index 333d24e5d4d..0f0c16763a7 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go @@ -746,11 +746,11 @@ func TestDecimal_QuoRem(t *testing.T) { d, _ := NewFromString(inp4.d) d2, _ := NewFromString(inp4.d2) prec := inp4.exp - q, r := d.quoRem(d2, prec) + q, r := d.QuoRem(d2, prec) expectedQ, _ := NewFromString(inp4.q) expectedR, _ := NewFromString(inp4.r) if !q.Equal(expectedQ) || !r.Equal(expectedR) { - t.Errorf("bad quoRem division %s , %s , %d got %v, %v expected %s , %s", + t.Errorf("bad QuoRem division %s , %s , %d got %v, %v expected %s , %s", inp4.d, inp4.d2, prec, q, r, inp4.q, inp4.r) } if !d.Equal(d2.mul(q).Add(r)) { @@ -813,7 +813,7 @@ func TestDecimal_QuoRem2(t *testing.T) { } d2 := tc.d2 prec := tc.prec - q, r := d.quoRem(d2, prec) + q, r := d.QuoRem(d2, prec) // rule 1: d = d2*q +r if !d.Equal(d2.mul(q).Add(r)) { t.Errorf("not fitting, d=%v, d2=%v, prec=%d, q=%v, r=%v", diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index d93cb3b837b..360217d092a 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -537,7 +537,7 @@ func Types(yield Query) { } func Arithmetic(yield Query) { - operators := []string{"+", "-", "*", "/"} + operators := []string{"+", "-", "*", "/", "DIV", "%", "MOD"} for _, op := range operators { for _, lhs := range inputConversions { diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 4ace1b54f4f..b3751cdfbcf 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -234,6 +234,10 @@ func (ast *astCompiler) translateBinaryExpr(binary *sqlparser.BinaryExpr) (Expr, return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithMul{}}, nil case sqlparser.DivOp: return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithDiv{}}, nil + case sqlparser.IntDivOp: + return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithIntDiv{}}, nil + case sqlparser.ModOp: + return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithMod{}}, nil case sqlparser.BitAndOp: return &BitwiseExpr{BinaryExpr: binaryExpr, Op: &opBitAnd{}}, nil case sqlparser.BitOrOp: From 7b83dbb64b41b255e6848a39743fdfea66a06e05 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sun, 19 Mar 2023 13:23:12 +0100 Subject: [PATCH 2/4] evalengine: Implement NOT and logical operations This adds the NOT and generic logical operations to the compiler and fixes a number of existing bugs in the evalengine. Specifically NOT is currently broken as we don't translate it properly. Some main issues are that we need to ensure lazy evaluation for the logical operations but also needing it for arithmetic as well. All cases where we've been pushing the boolean singleton value need to be fixed as well in the compiler, because we inline update things in arithmetic operations and we'd update the singleton value before. It also needs to split parsing into JSON from using an argument as a partial JSON value. This specifically is needed for CAST() with an input string where it should parse it. Additionally, convering a JSON boolean to numeric needs to create a floating point value, not an integer one. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/arena.go | 7 + go/vt/vtgate/evalengine/compiler.go | 6 + .../vtgate/evalengine/compiler_arithmetic.go | 43 ++-- go/vt/vtgate/evalengine/compiler_asm.go | 225 +++++++++++++++--- go/vt/vtgate/evalengine/compiler_bit.go | 20 +- go/vt/vtgate/evalengine/compiler_compare.go | 99 ++++++++ go/vt/vtgate/evalengine/compiler_fn.go | 11 +- go/vt/vtgate/evalengine/compiler_json.go | 24 +- go/vt/vtgate/evalengine/compiler_test.go | 30 ++- go/vt/vtgate/evalengine/eval.go | 15 ++ go/vt/vtgate/evalengine/eval_json.go | 29 +++ go/vt/vtgate/evalengine/eval_numeric.go | 4 +- go/vt/vtgate/evalengine/expr_arithmetic.go | 9 +- go/vt/vtgate/evalengine/expr_bit.go | 9 +- go/vt/vtgate/evalengine/expr_logical.go | 96 +++++--- go/vt/vtgate/evalengine/fn_base64.go | 8 +- go/vt/vtgate/evalengine/fn_json.go | 4 +- go/vt/vtgate/evalengine/testcases/cases.go | 32 ++- go/vt/vtgate/evalengine/testcases/inputs.go | 6 + go/vt/vtgate/evalengine/translate.go | 8 +- go/vt/vtgate/evalengine/translate_card.go | 2 + 21 files changed, 581 insertions(+), 106 deletions(-) diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index bda082f99cf..0a4627636ae 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -60,6 +60,13 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } +func (a *Arena) newEvalBool(b bool) *evalInt64 { + if b { + return a.newEvalInt64(1) + } + return a.newEvalInt64(0) +} + func (a *Arena) newEvalInt64(i int64) *evalInt64 { if cap(a.aInt64) > len(a.aInt64) { a.aInt64 = a.aInt64[:len(a.aInt64)+1] diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 24ff089c047..9138816855d 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -148,6 +148,12 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case *InExpr: return c.compileIn(expr) + case *NotExpr: + return c.compileNot(expr) + + case *LogicalExpr: + return c.compileLogical(expr) + case callable: return c.compileFn(expr) diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go index 772ca623c50..972e0b0b539 100644 --- a/go/vt/vtgate/evalengine/compiler_arithmetic.go +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -100,6 +100,7 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { @@ -107,7 +108,7 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { } swap := false - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) lt, rt, swap = c.compileNumericPriority(lt, rt) @@ -144,7 +145,8 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { sumtype = sqltypes.Float64 } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sumtype, Col: collationNumeric}, nil } @@ -153,13 +155,14 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) @@ -221,7 +224,8 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { panic("did not compile?") } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: subtype, Col: collationNumeric}, nil } @@ -230,6 +234,7 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { @@ -237,7 +242,7 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { } swap := false - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) lt, rt, swap = c.compileNumericPriority(lt, rt) @@ -274,7 +279,8 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { multype = sqltypes.Decimal } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: multype, Col: collationNumeric}, nil } @@ -283,29 +289,32 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } + skip2 := c.compileNullCheck2(lt, rt) - skip := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) + ct := ctype{Col: collationNumeric, Flag: flagNullable} if lt.Type == sqltypes.Float64 || rt.Type == sqltypes.Float64 { + ct.Type = sqltypes.Float64 c.compileToFloat(lt, 2) c.compileToFloat(rt, 1) c.asm.Div_ff() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil } else { + ct.Type = sqltypes.Decimal c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil } + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) + return ct, nil } func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) { @@ -313,13 +322,14 @@ func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) @@ -383,7 +393,8 @@ func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) { c.asm.IntDiv_di() } } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ct, nil } @@ -392,13 +403,14 @@ func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) @@ -453,6 +465,7 @@ func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) { c.asm.Mod_ff() } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ct, nil } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 771d5e08bfc..0bf28cac748 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -363,7 +363,7 @@ func (asm *assembler) BitwiseNot_u() { func (asm *assembler) Cmp_eq() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp == 0) vm.sp++ return 1 }, "CMPFLAG EQ") @@ -375,7 +375,7 @@ func (asm *assembler) Cmp_eq_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp == 0) } vm.sp++ return 1 @@ -385,7 +385,7 @@ func (asm *assembler) Cmp_eq_n() { func (asm *assembler) Cmp_ge() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp >= 0) vm.sp++ return 1 }, "CMPFLAG GE") @@ -397,7 +397,7 @@ func (asm *assembler) Cmp_ge_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp >= 0) } vm.sp++ return 1 @@ -407,7 +407,7 @@ func (asm *assembler) Cmp_ge_n() { func (asm *assembler) Cmp_gt() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp > 0) vm.sp++ return 1 }, "CMPFLAG GT") @@ -419,7 +419,7 @@ func (asm *assembler) Cmp_gt_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp > 0) } vm.sp++ return 1 @@ -429,7 +429,7 @@ func (asm *assembler) Cmp_gt_n() { func (asm *assembler) Cmp_le() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp <= 0) vm.sp++ return 1 }, "CMPFLAG LE") @@ -441,7 +441,7 @@ func (asm *assembler) Cmp_le_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp <= 0) } vm.sp++ return 1 @@ -451,7 +451,7 @@ func (asm *assembler) Cmp_le_n() { func (asm *assembler) Cmp_lt() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp < 0) vm.sp++ return 1 }, "CMPFLAG LT") @@ -463,7 +463,7 @@ func (asm *assembler) Cmp_lt_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp < 0) } vm.sp++ return 1 @@ -472,7 +472,7 @@ func (asm *assembler) Cmp_lt_n() { func (asm *assembler) Cmp_ne() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp != 0) vm.sp++ return 1 }, "CMPFLAG NE") @@ -484,7 +484,7 @@ func (asm *assembler) Cmp_ne_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp != 0) } vm.sp++ return 1 @@ -701,7 +701,7 @@ func (asm *assembler) CmpTupleNullsafe() { var equals bool equals, vm.err = evalCompareTuplesNullSafe(l.t, r.t) - vm.stack[vm.sp-2] = newEvalBool(equals) + vm.stack[vm.sp-2] = vm.arena.newEvalBool(equals) vm.sp -= 1 return 1 }, "CMP NULLSAFE TUPLE(SP-2), TUPLE(SP-1)") @@ -719,11 +719,24 @@ func (asm *assembler) Collate(col collations.ID) { func (asm *assembler) Convert_bB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) return 1 }, "CONV VARBINARY(SP-%d), BOOL", offset) } +func (asm *assembler) Convert_jB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalJSON) + switch arg.Type() { + case json.TypeNumber: + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(parseStringToFloat(arg.String()) != 0.0) + default: + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(true) + } + return 1 + }, "CONV JSON(SP-%d), BOOL", offset) +} + func (asm *assembler) Convert_bj(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset].(*evalBytes) @@ -732,6 +745,14 @@ func (asm *assembler) Convert_bj(offset int) { }, "CONV VARBINARY(SP-%d), JSON", offset) } +func (asm *assembler) ConvertArg_cj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset], vm.err = evalConvertArg_cj(arg) + return 1 + }, "CONVA VARCHAR(SP-%d), JSON", offset) +} + func (asm *assembler) Convert_cj(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset].(*evalBytes) @@ -743,7 +764,7 @@ func (asm *assembler) Convert_cj(offset int) { func (asm *assembler) Convert_dB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) return 1 }, "CONV DECIMAL(SP-%d), BOOL", offset) } @@ -763,7 +784,7 @@ func (asm *assembler) Convert_dbit(offset int) { func (asm *assembler) Convert_fB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) return 1 }, "CONV FLOAT64(SP-%d), BOOL", offset) } @@ -790,7 +811,7 @@ func (asm *assembler) Convert_hex(offset int) { func (asm *assembler) Convert_iB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalInt64).i != 0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0) return 1 }, "CONV INT64(SP-%d), BOOL", offset) } @@ -848,7 +869,7 @@ func (asm *assembler) Convert_Nj(offset int) { func (asm *assembler) Convert_uB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalUint64).u != 0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0) return 1 }, "CONV UINT64(SP-%d), BOOL", offset) } @@ -1379,7 +1400,7 @@ func (asm *assembler) Fn_COLLATION(col collations.TypedCollation) { }, "FN COLLATION (SP-1)") } -func (asm *assembler) Fn_FROM_BASE64() { +func (asm *assembler) Fn_FROM_BASE64(t sqltypes.Type) { asm.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) @@ -1390,7 +1411,7 @@ func (asm *assembler) Fn_FROM_BASE64() { vm.stack[vm.sp-1] = nil return 1 } - str.tt = int16(sqltypes.VarBinary) + str.tt = int16(t) str.bytes = decoded[:n] return 1 }, "FN FROM_BASE64 VARCHAR(SP-1)") @@ -1439,7 +1460,7 @@ func (asm *assembler) Fn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path) break } } - vm.stack[vm.sp-1] = newEvalBool(matched) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(matched) return 1 }, "FN JSON_CONTAINS_PATH, SP-1, 'one', [static]") case jsonMatchAll: @@ -1453,7 +1474,7 @@ func (asm *assembler) Fn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path) break } } - vm.stack[vm.sp-1] = newEvalBool(matched) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(matched) return 1 }, "FN JSON_CONTAINS_PATH, SP-1, 'all', [static]") } @@ -1812,7 +1833,7 @@ func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { vm.hash.Reset() lhs.(hashable).Hash(&vm.hash) _, in := table[vm.hash.Sum128()] - vm.stack[vm.sp-1] = newEvalBool(!in) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(!in) } return 1 }, "NOT IN (SP-1), [static table]") @@ -1823,7 +1844,7 @@ func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { vm.hash.Reset() lhs.(hashable).Hash(&vm.hash) _, in := table[vm.hash.Sum128()] - vm.stack[vm.sp-1] = newEvalBool(in) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(in) } return 1 }, "IN (SP-1), [static table]") @@ -1862,11 +1883,149 @@ func (asm *assembler) In_slow(not bool) { func (asm *assembler) Is(check func(eval) bool) { asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-1] = newEvalBool(check(vm.stack[vm.sp-1])) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(check(vm.stack[vm.sp-1])) return 1 }, "IS (SP-1), [static]") } +func (asm *assembler) Not_i() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalInt64) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.i == 0) + return 1 + }, "NOT INT64(SP-1)") +} + +func (asm *assembler) Not_u() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalUint64) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.u == 0) + return 1 + }, "NOT UINT64(SP-1)") +} + +func (asm *assembler) Not_f() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalFloat) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.f == 0.0) + return 1 + }, "NOT FLOAT64(SP-1)") +} + +func (asm *assembler) Not_d() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalDecimal) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.dec.IsZero()) + return 1 + }, "NOT DECIMAL(SP-1)") +} + +func (asm *assembler) LogicalLeft(opname string) *jump { + switch opname { + case "AND": + j := asm.jumpFrom() + asm.emit(func(vm *VirtualMachine) int { + left, ok := vm.stack[vm.sp-1].(*evalInt64) + if ok && left.i == 0 { + return j.offset() + } + return 1 + }, "AND CHECK INT64(SP-1)") + return j + case "OR": + j := asm.jumpFrom() + asm.emit(func(vm *VirtualMachine) int { + left, ok := vm.stack[vm.sp-1].(*evalInt64) + if ok && left.i != 0 { + left.i = 1 + return j.offset() + } + return 1 + }, "OR CHECK INT64(SP-1)") + return j + case "XOR": + j := asm.jumpFrom() + asm.emit(func(vm *VirtualMachine) int { + if vm.stack[vm.sp-1] == nil { + return j.offset() + } + return 1 + }, "XOR CHECK INT64(SP-1)") + return j + } + return nil +} + +func (asm *assembler) LogicalRight(opname string) { + asm.adjustStack(-1) + switch opname { + case "AND": + asm.emit(func(vm *VirtualMachine) int { + left, lok := vm.stack[vm.sp-2].(*evalInt64) + right, rok := vm.stack[vm.sp-1].(*evalInt64) + + isLeft := lok && left.i != 0 + isRight := rok && right.i != 0 + + if isLeft && isRight { + left.i = 1 + } else if rok && !isRight { + vm.stack[vm.sp-2] = vm.arena.newEvalBool(false) + } else { + vm.stack[vm.sp-2] = nil + } + vm.sp-- + return 1 + }, "AND INT64(SP-2), INT64(SP-1)") + case "OR": + asm.emit(func(vm *VirtualMachine) int { + left, lok := vm.stack[vm.sp-2].(*evalInt64) + right, rok := vm.stack[vm.sp-1].(*evalInt64) + + isLeft := lok && left.i != 0 + isRight := rok && right.i != 0 + + switch { + case !lok: + if isRight { + vm.stack[vm.sp-2] = vm.arena.newEvalBool(true) + } + case !rok: + vm.stack[vm.sp-2] = nil + default: + if isLeft || isRight { + left.i = 1 + } else { + left.i = 0 + } + } + vm.sp-- + return 1 + }, "OR INT64(SP-2), INT64(SP-1)") + case "XOR": + asm.emit(func(vm *VirtualMachine) int { + left := vm.stack[vm.sp-2].(*evalInt64) + right, rok := vm.stack[vm.sp-1].(*evalInt64) + + isLeft := left.i != 0 + isRight := rok && right.i != 0 + + switch { + case !rok: + vm.stack[vm.sp-2] = nil + default: + if isLeft != isRight { + left.i = 1 + } else { + left.i = 0 + } + } + vm.sp-- + return 1 + }, "XOR INT64(SP-2), INT64(SP-1)") + } +} + func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { asm.adjustStack(-1) @@ -1886,11 +2045,7 @@ func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { } match := expr.matchWildcard(bl, br, coercion.col.ID()) - if match { - vm.stack[vm.sp-1] = evalBoolTrue - } else { - vm.stack[vm.sp-1] = evalBoolFalse - } + vm.stack[vm.sp-1] = vm.arena.newEvalBool(match) return 1 }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) } @@ -1904,11 +2059,7 @@ func (asm *assembler) Like_collate(expr *LikeExpr, collation collations.Collatio vm.sp-- match := expr.matchWildcard(l.bytes, r.bytes, collation.ID()) - if match { - vm.stack[vm.sp-1] = evalBoolTrue - } else { - vm.stack[vm.sp-1] = evalBoolFalse - } + vm.stack[vm.sp-1] = vm.arena.newEvalBool(match) return 1 }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } @@ -2250,12 +2401,12 @@ func (asm *assembler) PushNull() { func (asm *assembler) SetBool(offset int, b bool) { if b { asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalBoolTrue + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(true) return 1 }, "SET (SP-%d), BOOL(true)", offset) } else { asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalBoolFalse + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(false) return 1 }, "SET (SP-%d), BOOL(false)", offset) } diff --git a/go/vt/vtgate/evalengine/compiler_bit.go b/go/vt/vtgate/evalengine/compiler_bit.go index 2bf3f0d98d3..bc37dd7bc8d 100644 --- a/go/vt/vtgate/evalengine/compiler_bit.go +++ b/go/vt/vtgate/evalengine/compiler_bit.go @@ -41,17 +41,20 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, asm_ins_bb, asm_ins_u return ctype{}, err } + skip1 := c.compileNullCheck1(lt) + rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { asm_ins_bb() - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } } @@ -60,7 +63,8 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, asm_ins_bb, asm_ins_u rt = c.compileToBitwiseUint64(rt, 1) asm_ins_uu() - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil } @@ -70,12 +74,14 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err return ctype{}, err } + skip1 := c.compileNullCheck1(lt) + rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) if lt.Type == sqltypes.VarBinary && !lt.isHexOrBitLiteral() { _ = c.compileToUint64(rt, 1) @@ -84,7 +90,8 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err } else { c.asm.BitShiftRight_bu() } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } @@ -97,7 +104,8 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err c.asm.BitShiftRight_uu() } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil } diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index 15d28370ef5..bece3bd4aaa 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -376,3 +376,102 @@ func (c *compiler) compileIn(expr *InExpr) (ctype, error) { } return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil } + +func (c *compiler) compileNot(expr *NotExpr) (ctype, error) { + arg, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, nil + } + + skip := c.compileNullCheck1(arg) + + switch arg.Type { + case sqltypes.Null: + // No-op. + case sqltypes.Int64: + c.asm.Not_i() + case sqltypes.Uint64: + c.asm.Not_u() + case sqltypes.Float64: + c.asm.Not_f() + case sqltypes.Decimal: + c.asm.Not_d() + case sqltypes.VarChar, sqltypes.VarBinary: + if arg.isHexOrBitLiteral() { + c.asm.Convert_xu(1) + c.asm.Not_u() + } else { + c.asm.Convert_bB(1) + c.asm.Not_i() + } + case sqltypes.TypeJSON: + c.asm.Convert_jB(1) + c.asm.Not_i() + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported Not check: %s", arg.Type) + } + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil +} + +func (c *compiler) compileLogical(expr *LogicalExpr) (ctype, error) { + lt, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + switch lt.Type { + case sqltypes.Null, sqltypes.Int64: + // No-op. + case sqltypes.Uint64: + c.asm.Convert_uB(1) + case sqltypes.Float64: + c.asm.Convert_fB(1) + case sqltypes.Decimal: + c.asm.Convert_dB(1) + case sqltypes.VarChar, sqltypes.VarBinary: + if lt.isHexOrBitLiteral() { + c.asm.Convert_xu(1) + c.asm.Convert_uB(1) + } else { + c.asm.Convert_bB(1) + } + case sqltypes.TypeJSON: + c.asm.Convert_jB(1) + default: + c.asm.Convert_bB(1) + } + + jump := c.asm.LogicalLeft(expr.opname) + + rt, err := c.compileExpr(expr.Right) + if err != nil { + return ctype{}, err + } + + switch rt.Type { + case sqltypes.Null, sqltypes.Int64: + // No-op. + case sqltypes.Uint64: + c.asm.Convert_uB(1) + case sqltypes.Float64: + c.asm.Convert_fB(1) + case sqltypes.Decimal: + c.asm.Convert_dB(1) + case sqltypes.VarChar, sqltypes.VarBinary: + if rt.isHexOrBitLiteral() { + c.asm.Convert_xu(1) + c.asm.Convert_uB(1) + } else { + c.asm.Convert_bB(1) + } + case sqltypes.TypeJSON: + c.asm.Convert_jB(1) + default: + c.asm.Convert_bB(1) + } + + c.asm.LogicalRight(expr.opname) + c.asm.jumpDestination(jump) + return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index b178d713031..c2a38db608d 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -261,16 +261,21 @@ func (c *compiler) compileFn_FROM_BASE64(call *builtinFromBase64) (ctype, error) skip := c.compileNullCheck1(str) + t := sqltypes.VarBinary + if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { + t = sqltypes.Blob + } + switch { case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): default: - c.asm.Convert_xc(1, sqltypes.VarBinary, c.defaultCollation, 0, false) + c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) } - c.asm.Fn_FROM_BASE64() + c.asm.Fn_FROM_BASE64(t) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + return ctype{Type: t, Col: collationBinary}, nil } func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index f6232394849..e786709ee35 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -66,6 +66,26 @@ func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil } +func (c *compiler) compileArgToJSON(doct ctype, offset int) (ctype, error) { + switch doct.Type { + case sqltypes.TypeJSON: + return doct, nil + case sqltypes.Float64: + c.asm.Convert_fj(offset) + case sqltypes.Int64, sqltypes.Uint64, sqltypes.Decimal: + c.asm.Convert_nj(offset, doct.Flag&flagIsBoolean != 0) + case sqltypes.VarChar: + c.asm.ConvertArg_cj(offset) + case sqltypes.VarBinary: + c.asm.Convert_bj(offset) + case sqltypes.Null: + c.asm.Convert_Nj(offset) + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", doct.Type) + } + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + func (c *compiler) compileFn_JSON_ARRAY(call *builtinJSONArray) (ctype, error) { for _, arg := range call.Arguments { tt, err := c.compileExpr(arg) @@ -73,7 +93,7 @@ func (c *compiler) compileFn_JSON_ARRAY(call *builtinJSONArray) (ctype, error) { return ctype{}, err } - _, err = c.compileToJSON(tt, 1) + _, err = c.compileArgToJSON(tt, 1) if err != nil { return ctype{}, err } @@ -95,7 +115,7 @@ func (c *compiler) compileFn_JSON_OBJECT(call *builtinJSONObject) (ctype, error) if err != nil { return ctype{}, err } - _, err = c.compileToJSON(val, 1) + _, err = c.compileArgToJSON(val, 1) if err != nil { return ctype{}, err } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index ac03d186607..cb76489c8c0 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -186,7 +186,7 @@ func (d *debugCompiler) Stack(old, new int) { d.t.Logf("\tsp = %d -> %d", old, new) } -func TestCompiler(t *testing.T) { +func TestCompilerSingle(t *testing.T) { var testCases = []struct { expression string values []sqltypes.Value @@ -255,6 +255,22 @@ func TestCompiler(t *testing.T) { expression: `JSON_ARRAY(true, 1.0)`, result: `JSON("[true, 1.0]")`, }, + { + expression: `cast(true as json) + 0`, + result: `FLOAT64(1)`, + }, + { + expression: `CAST(CAST(0 AS JSON) AS CHAR(16))`, + result: `VARCHAR("0")`, + }, + { + expression: `1 OR cast('invalid' as json)`, + result: `INT64(1)`, + }, + { + expression: `NULL AND 1`, + result: `NULL`, + }, } for _, tc := range testCases { @@ -269,6 +285,16 @@ func TestCompiler(t *testing.T) { t.Fatal(err) } + env := evalengine.EmptyExpressionEnv() + env.Row = tc.values + expected, err := env.Evaluate(converted) + if err != nil { + t.Fatal(err) + } + if expected.String() != tc.result { + t.Fatalf("bad evaluation from eval engine: got %s, want %s", expected.String(), tc.result) + } + compiled, err := evalengine.Compile(converted, makeFields(tc.values), evalengine.WithAssemblerLog(&debugCompiler{t})) if err != nil { t.Fatal(err) @@ -283,7 +309,7 @@ func TestCompiler(t *testing.T) { } if res.String() != tc.result { - t.Fatalf("bad evaluation: got %s, want %s (iteration %d)", res, tc.result, i) + t.Fatalf("bad evaluation from compiler: got %s, want %s (iteration %d)", res, tc.result, i) } } }) diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 0c740cb3efc..d4d4e9b3539 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -130,7 +130,22 @@ func evalIsTruthy(e eval) boolean { case *evalDecimal: return makeboolean(!e.dec.IsZero()) case *evalBytes: + if e.isHexLiteral { + hex, ok := e.toNumericHex() + if !ok { + // overflow + return makeboolean(true) + } + return makeboolean(hex.u != 0) + } return makeboolean(parseStringToFloat(e.string()) != 0.0) + case *evalJSON: + switch e.Type() { + case json.TypeNumber: + return makeboolean(parseStringToFloat(e.Raw()) != 0.0) + default: + return makeboolean(true) + } default: panic("unhandled case: evalIsTruthy") } diff --git a/go/vt/vtgate/evalengine/eval_json.go b/go/vt/vtgate/evalengine/eval_json.go index 8275f241c70..56222d18946 100644 --- a/go/vt/vtgate/evalengine/eval_json.go +++ b/go/vt/vtgate/evalengine/eval_json.go @@ -92,6 +92,15 @@ func evalConvert_nj(e evalNumeric) *evalJSON { } func evalConvert_cj(e *evalBytes) (*evalJSON, error) { + jsonText, err := charset.Convert(nil, charset.Charset_utf8mb4{}, e.bytes, e.col.Collation.Get().Charset()) + if err != nil { + return nil, err + } + var p json.Parser + return p.ParseBytes(jsonText) +} + +func evalConvertArg_cj(e *evalBytes) (*evalJSON, error) { jsonText, err := charset.Convert(nil, charset.Charset_utf8mb4{}, e.bytes, e.col.Collation.Get().Charset()) if err != nil { return nil, err @@ -118,3 +127,23 @@ func evalToJSON(e eval) (*evalJSON, error) { return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", e.SQLType()) } } + +func argToJSON(e eval) (*evalJSON, error) { + switch e := e.(type) { + case nil: + return json.ValueNull, nil + case *evalJSON: + return e, nil + case *evalFloat: + return evalConvert_fj(e), nil + case evalNumeric: + return evalConvert_nj(e), nil + case *evalBytes: + if sqltypes.IsBinary(e.SQLType()) { + return evalConvert_bj(e), nil + } + return evalConvertArg_cj(e) + default: + return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", e.SQLType()) + } +} diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index 4d4dc3ef133..b6628ab4dfe 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -111,9 +111,9 @@ func evalToNumeric(e eval) evalNumeric { case *evalJSON: switch e.Type() { case json.TypeTrue: - return newEvalBool(true) + return &evalFloat{f: 1.0} case json.TypeFalse: - return newEvalBool(false) + return &evalFloat{f: 0.0} case json.TypeNumber, json.TypeString: return &evalFloat{f: parseStringToFloat(e.Raw())} default: diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 63323d0f4bf..b592c681b41 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -53,8 +53,13 @@ var _ opArith = (*opArithIntDiv)(nil) var _ opArith = (*opArithMod)(nil) func (b *ArithmeticExpr) eval(env *ExpressionEnv) (eval, error) { - left, right, err := b.arguments(env) - if left == nil || right == nil || err != nil { + left, err := b.Left.eval(env) + if left == nil || err != nil { + return nil, err + } + + right, err := b.Right.eval(env) + if right == nil || err != nil { return nil, err } return b.Op.eval(left, right) diff --git a/go/vt/vtgate/evalengine/expr_bit.go b/go/vt/vtgate/evalengine/expr_bit.go index 1e68438b45f..3e87fdd2fe4 100644 --- a/go/vt/vtgate/evalengine/expr_bit.go +++ b/go/vt/vtgate/evalengine/expr_bit.go @@ -173,8 +173,13 @@ func (o opBitAnd) BitwiseOp() string { return "&" } var errBitwiseOperandsLength = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Binary operands of bitwise operators must be of equal length") func (bit *BitwiseExpr) eval(env *ExpressionEnv) (eval, error) { - l, r, err := bit.arguments(env) - if l == nil || r == nil || err != nil { + l, err := bit.Left.eval(env) + if l == nil || err != nil { + return nil, err + } + + r, err := bit.Right.eval(env) + if r == nil || err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index 79ed6b97d90..63a7e170adb 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -25,7 +25,7 @@ import ( type ( LogicalExpr struct { BinaryExpr - op func(left, right boolean) boolean + op func(left, right Expr, env *ExpressionEnv) (boolean, error) opname string } @@ -100,55 +100,100 @@ func (left boolean) not() boolean { } } -func (left boolean) and(right boolean) boolean { +func opAnd(le, re Expr, env *ExpressionEnv) (boolean, error) { // Logical AND. // Evaluates to 1 if all operands are nonzero and not NULL, to 0 if one or more operands are 0, otherwise NULL is returned. + l, err := le.eval(env) + if err != nil { + return boolNULL, err + } + + left := evalIsTruthy(l) + if left == boolFalse { + return boolFalse, nil + } + + r, err := re.eval(env) + if err != nil { + return boolNULL, err + } + right := evalIsTruthy(r) + switch { case left == boolTrue && right == boolTrue: - return boolTrue - case left == boolFalse || right == boolFalse: - return boolFalse + return boolTrue, nil + case right == boolFalse: + return boolFalse, nil default: - return boolNULL + return boolNULL, nil } } -func (left boolean) or(right boolean) boolean { +func opOr(le, re Expr, env *ExpressionEnv) (boolean, error) { // Logical OR. When both operands are non-NULL, the result is 1 if any operand is nonzero, and 0 otherwise. // With a NULL operand, the result is 1 if the other operand is nonzero, and NULL otherwise. // If both operands are NULL, the result is NULL. + l, err := le.eval(env) + if err != nil { + return boolNULL, err + } + + left := evalIsTruthy(l) + if left == boolTrue { + return boolTrue, nil + } + + r, err := re.eval(env) + if err != nil { + return boolNULL, err + } + right := evalIsTruthy(r) + switch { case left == boolNULL: if right == boolTrue { - return boolTrue + return boolTrue, nil } - return boolNULL + return boolNULL, nil case right == boolNULL: - if left == boolTrue { - return boolTrue - } - return boolNULL + return boolNULL, nil default: - if left == boolTrue || right == boolTrue { - return boolTrue + if right == boolTrue { + return boolTrue, nil } - return boolFalse + return boolFalse, nil } } -func (left boolean) xor(right boolean) boolean { +func opXor(le, re Expr, env *ExpressionEnv) (boolean, error) { // Logical XOR. Returns NULL if either operand is NULL. // For non-NULL operands, evaluates to 1 if an odd number of operands is nonzero, otherwise 0 is returned. + l, err := le.eval(env) + if err != nil { + return boolNULL, err + } + + left := evalIsTruthy(l) + if left == boolNULL { + return boolNULL, nil + } + + r, err := re.eval(env) + if err != nil { + return boolNULL, err + } + right := evalIsTruthy(r) + switch { case left == boolNULL || right == boolNULL: - return boolNULL + return boolNULL, nil default: if left != right { - return boolTrue + return boolTrue, nil } - return boolFalse + return boolFalse, nil } } @@ -162,21 +207,18 @@ func (n *NotExpr) eval(env *ExpressionEnv) (eval, error) { func (n *NotExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { _, flags := n.Inner.typeof(env) - return sqltypes.Uint64, flags + return sqltypes.Int64, flags | flagIsBoolean } func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) { - left, right, err := l.arguments(env) - if err != nil { - return nil, err - } - return l.op(evalIsTruthy(left), evalIsTruthy(right)).eval(), nil + res, err := l.op(l.Left, l.Right, env) + return res.eval(), err } func (l *LogicalExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { _, f1 := l.Left.typeof(env) _, f2 := l.Right.typeof(env) - return sqltypes.Uint64, f1 | f2 + return sqltypes.Int64, f1 | f2 | flagIsBoolean } func (i *IsExpr) eval(env *ExpressionEnv) (eval, error) { diff --git a/go/vt/vtgate/evalengine/fn_base64.go b/go/vt/vtgate/evalengine/fn_base64.go index 7cdc2bafe98..5b6372710ec 100644 --- a/go/vt/vtgate/evalengine/fn_base64.go +++ b/go/vt/vtgate/evalengine/fn_base64.go @@ -76,12 +76,18 @@ func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) { b := evalToBinary(arg) decoded := make([]byte, mysqlBase64.DecodedLen(len(b.bytes))) if n, err := mysqlBase64.Decode(decoded, b.bytes); err == nil { + if arg.SQLType() == sqltypes.Text || arg.SQLType() == sqltypes.TypeJSON { + return newEvalRaw(sqltypes.Blob, decoded[:n], collationBinary), nil + } return newEvalBinary(decoded[:n]), nil } return nil, nil } func (call *builtinFromBase64) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { - _, f := call.Arguments[0].typeof(env) + tt, f := call.Arguments[0].typeof(env) + if tt == sqltypes.Text || tt == sqltypes.TypeJSON { + return sqltypes.Blob, f | flagNullable + } return sqltypes.VarBinary, f | flagNullable } diff --git a/go/vt/vtgate/evalengine/fn_json.go b/go/vt/vtgate/evalengine/fn_json.go index 75e88c82d75..ed48a325981 100644 --- a/go/vt/vtgate/evalengine/fn_json.go +++ b/go/vt/vtgate/evalengine/fn_json.go @@ -168,7 +168,7 @@ func (call *builtinJSONObject) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - val1, err := evalToJSON(val) + val1, err := argToJSON(val) if err != nil { return nil, err } @@ -189,7 +189,7 @@ func (call *builtinJSONArray) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - arg1, err := evalToJSON(arg) + arg1, err := argToJSON(arg) if err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 360217d092a..5d7fe233491 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -54,6 +54,8 @@ var Cases = []TestCase{ {Run: LikeComparison}, {Run: MultiComparisons}, {Run: IsStatement}, + {Run: NotStatement}, + {Run: LogicalStatement}, {Run: TupleComparisons}, {Run: Comparisons}, {Run: InStatement}, @@ -121,7 +123,7 @@ func JSONObject(yield Query) { func CharsetConversionOperators(yield Query) { var introducers = []string{ - "", "_latin1", "_utf8mb4", "_utf8", "_binary", + "", "_lat21 in1", "_utf8mb4", "_utf8", "_binary", } var contents = []string{ `"foobar"`, `X'4D7953514C'`, @@ -436,6 +438,12 @@ func BitwiseOperators(yield Query) { yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) } } + + for _, lhs := range inputConversions { + for _, rhs := range inputConversions { + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + } + } } } @@ -720,6 +728,23 @@ func IsStatement(yield Query) { } } +func NotStatement(yield Query) { + for _, i := range inputConversions { + yield(fmt.Sprintf("NOT %s", i), nil) + } +} + +func LogicalStatement(yield Query) { + var ops = []string{"AND", "OR", "XOR"} + for _, op := range ops { + for _, l := range inputConversions { + for _, r := range inputConversions { + yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + } + } + } +} + func TupleComparisons(yield Query) { var elems = []string{"NULL", "-1", "0", "1"} var operators = []string{"=", "!=", "<=>", "<", "<=", ">", ">="} @@ -872,5 +897,10 @@ func InStatement(yield Query) { yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) }) } diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 94f818dc0f3..744ffdb05b7 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -97,6 +97,12 @@ var inputConversions = []string{ "18446744073709540000e0", "-18446744073709540000e0", "JSON_OBJECT()", "JSON_ARRAY()", + "cast(0 as json)", "cast(1 as json)", + "cast(true as json)", "cast(false as json)", + "cast('{}' as json)", "cast('[]' as json)", + "cast('null' as json)", "cast('true' as json)", "cast('false' as json)", + "cast('1' as json)", "cast('1.1' as json)", "cast('-1.1' as json)", + "cast('\"foo\"' as json)", "cast('invalid' as json)", } const inputPi = "314159265358979323846264338327950288419716939937510582097494459" diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index b3751cdfbcf..03225143280 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -106,14 +106,14 @@ func (ast *astCompiler) translateLogicalExpr(opname string, left, right sqlparse return nil, err } - var logic func(l, r boolean) boolean + var logic func(l, r Expr, env *ExpressionEnv) (boolean, error) switch opname { case "AND": - logic = func(l, r boolean) boolean { return l.and(r) } + logic = func(l, r Expr, env *ExpressionEnv) (boolean, error) { return opAnd(l, r, env) } case "OR": - logic = func(l, r boolean) boolean { return l.or(r) } + logic = func(l, r Expr, env *ExpressionEnv) (boolean, error) { return opOr(l, r, env) } case "XOR": - logic = func(l, r boolean) boolean { return l.xor(r) } + logic = func(l, r Expr, env *ExpressionEnv) (boolean, error) { return opXor(l, r, env) } default: panic("unexpected logical operator") } diff --git a/go/vt/vtgate/evalengine/translate_card.go b/go/vt/vtgate/evalengine/translate_card.go index 650563daaea..f6d50664dce 100644 --- a/go/vt/vtgate/evalengine/translate_card.go +++ b/go/vt/vtgate/evalengine/translate_card.go @@ -134,6 +134,8 @@ func (ast *astCompiler) cardExpr(expr Expr) error { return ast.cardUnary(expr.Inner) case *BitwiseNotExpr: return ast.cardUnary(expr.Inner) + case *NotExpr: + return ast.cardUnary(expr.Inner) case *ArithmeticExpr: return ast.cardBinary(expr.Left, expr.Right) case *LogicalExpr: From c6cef552ce15093aa84552cfa80880b8a230a891 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 20 Mar 2023 11:08:23 +0100 Subject: [PATCH 3/4] Fix tests Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/testcases/cases.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 5d7fe233491..848a9cb4029 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -123,7 +123,7 @@ func JSONObject(yield Query) { func CharsetConversionOperators(yield Query) { var introducers = []string{ - "", "_lat21 in1", "_utf8mb4", "_utf8", "_binary", + "", "_latin1", "_utf8mb4", "_utf8", "_binary", } var contents = []string{ `"foobar"`, `X'4D7953514C'`, @@ -729,13 +729,16 @@ func IsStatement(yield Query) { } func NotStatement(yield Query) { - for _, i := range inputConversions { - yield(fmt.Sprintf("NOT %s", i), nil) + var ops = []string{"NOT", "!"} + for _, op := range ops { + for _, i := range inputConversions { + yield(fmt.Sprintf("%s %s", op, i), nil) + } } } func LogicalStatement(yield Query) { - var ops = []string{"AND", "OR", "XOR"} + var ops = []string{"AND", "&&", "OR", "||", "XOR"} for _, op := range ops { for _, l := range inputConversions { for _, r := range inputConversions { From 44f2a218ea9845dc71c389a52d537b76c7621ed6 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 20 Mar 2023 11:31:39 +0100 Subject: [PATCH 4/4] evalengine/compiler: Use helpers where appropriate Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler.go | 4 ++-- go/vt/vtgate/evalengine/compiler_compare.go | 4 ++-- go/vt/vtgate/evalengine/compiler_fn.go | 14 +++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 9138816855d..ff0765362fa 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -316,7 +316,7 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { } func (c *compiler) compileNullCheck1(ct ctype) *jump { - if ct.Flag&flagNullable != 0 { + if ct.nullable() { j := c.asm.jumpFrom() c.asm.NullCheck1(j) return j @@ -325,7 +325,7 @@ func (c *compiler) compileNullCheck1(ct ctype) *jump { } func (c *compiler) compileNullCheck2(lt, rt ctype) *jump { - if lt.Flag&flagNullable != 0 || rt.Flag&flagNullable != 0 { + if lt.nullable() || rt.nullable() { j := c.asm.jumpFrom() c.asm.NullCheck2(j) return j diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index bece3bd4aaa..cb57c8ed932 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -271,7 +271,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { skip := c.compileNullCheck2(lt, rt) - if !sqltypes.IsText(lt.Type) && !sqltypes.IsBinary(lt.Type) { + if !lt.isTextual() { c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) lt.Col = collations.TypedCollation{ Collation: c.defaultCollation, @@ -280,7 +280,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { } } - if !sqltypes.IsText(rt.Type) && !sqltypes.IsBinary(rt.Type) { + if !rt.isTextual() { c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) rt.Col = collations.TypedCollation{ Collation: c.defaultCollation, diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index c2a38db608d..7ada401256a 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -211,7 +211,7 @@ func (c *compiler) compileFn_REPEAT(expr *builtinRepeat) (ctype, error) { skip := c.compileNullCheck2(str, repeat) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -236,7 +236,7 @@ func (c *compiler) compileFn_TO_BASE64(call *builtinToBase64) (ctype, error) { } switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) } @@ -267,7 +267,7 @@ func (c *compiler) compileFn_FROM_BASE64(call *builtinFromBase64) (ctype, error) } switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) } @@ -287,7 +287,7 @@ func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { skip := c.compileNullCheck1(str) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -307,7 +307,7 @@ func (c *compiler) compileFn_xLENGTH(call callable, asm_ins func()) (ctype, erro skip := c.compileNullCheck1(str) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -327,7 +327,7 @@ func (c *compiler) compileFn_ASCII(call *builtinASCII) (ctype, error) { skip := c.compileNullCheck1(str) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -360,7 +360,7 @@ func (c *compiler) compileFn_HEX(call *builtinHex) (ctype, error) { switch { case sqltypes.IsNumber(str.Type), sqltypes.IsDecimal(str.Type): c.asm.Fn_HEX_d(col) - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): c.asm.Fn_HEX_c(t, col) default: c.asm.Convert_xc(1, t, c.defaultCollation, 0, false)