diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index bda082f99cf..0a4627636ae 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -60,6 +60,13 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } +func (a *Arena) newEvalBool(b bool) *evalInt64 { + if b { + return a.newEvalInt64(1) + } + return a.newEvalInt64(0) +} + func (a *Arena) newEvalInt64(i int64) *evalInt64 { if cap(a.aInt64) > len(a.aInt64) { a.aInt64 = a.aInt64[:len(a.aInt64)+1] diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 5b6c0ee538c..89ecb6fd22f 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -26,12 +26,17 @@ import ( "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" ) func dataOutOfRangeError[N1, N2 constraints.Integer | constraints.Float](v1 N1, v2 N2, typ, sign string) error { return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2) } +func dataOutOfRangeErrorDecimal(v1 decimal.Decimal, v2 decimal.Decimal, typ, sign string) error { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1.String(), sign, v2.String()) +} + func addNumericWithError(left, right eval) (eval, error) { v1, v2 := makeNumericAndPrioritize(left, right) switch v1 := v1.(type) { @@ -127,6 +132,108 @@ func divideNumericWithError(left, right eval, precise bool) (eval, error) { return mathDiv_xx(v1, v2, divPrecisionIncrement) } +func integerDivideNumericWithError(left, right eval, precise bool) (eval, error) { + v1 := evalToNumeric(left) + v2 := evalToNumeric(right) + switch v1 := v1.(type) { + case *evalInt64: + switch v2 := v2.(type) { + case *evalInt64: + return mathIntDiv_ii(v1, v2) + case *evalUint64: + return mathIntDiv_iu(v1, v2) + case *evalFloat: + return mathIntDiv_di(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + case *evalDecimal: + return mathIntDiv_di(v1.toDecimal(0, 0), v2) + } + case *evalUint64: + switch v2 := v2.(type) { + case *evalInt64: + return mathIntDiv_ui(v1, v2) + case *evalUint64: + return mathIntDiv_uu(v1, v2) + case *evalFloat: + return mathIntDiv_du(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + case *evalDecimal: + return mathIntDiv_du(v1.toDecimal(0, 0), v2) + } + case *evalFloat: + switch v2 := v2.(type) { + case *evalUint64: + return mathIntDiv_du(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + default: + return mathIntDiv_di(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + } + case *evalDecimal: + switch v2 := v2.(type) { + case *evalUint64: + return mathIntDiv_du(v1, v2.toDecimal(0, 0)) + default: + return mathIntDiv_di(v1, v2.toDecimal(0, 0)) + } + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", evalToSQLValue(v1), evalToSQLValue(v2)) +} + +func modNumericWithError(left, right eval, precise bool) (eval, error) { + v1 := evalToNumeric(left) + v2 := evalToNumeric(right) + + switch v1 := v1.(type) { + case *evalInt64: + switch v2 := v2.(type) { + case *evalInt64: + return mathMod_ii(v1, v2) + case *evalUint64: + return mathMod_iu(v1, v2) + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + case *evalDecimal: + return mathMod_dd(v1.toDecimal(0, 0), v2) + } + case *evalUint64: + switch v2 := v2.(type) { + case *evalInt64: + return mathMod_ui(v1, v2) + case *evalUint64: + return mathMod_uu(v1, v2) + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + case *evalDecimal: + return mathMod_dd(v1.toDecimal(0, 0), v2) + } + case *evalDecimal: + switch v2 := v2.(type) { + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + default: + return mathMod_dd(v1, v2.toDecimal(0, 0)) + } + case *evalFloat: + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1, v2f) + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", evalToSQLValue(v1), evalToSQLValue(v2)) +} + // makeNumericAndPrioritize reorders the input parameters // to be Float64, Decimal, Uint64, Int64. func makeNumericAndPrioritize(left, right eval) (evalNumeric, evalNumeric) { @@ -162,28 +269,68 @@ func mathAdd_ii0(v1, v2 int64) (int64, error) { return result, nil } -func mathSub_ii(v1, v2 int64) (*evalInt64, error) { - result, err := mathSub_ii0(v1, v2) - return newEvalInt64(result), err +func mathAdd_ui(v1 uint64, v2 int64) (*evalUint64, error) { + result, err := mathAdd_ui0(v1, v2) + return newEvalUint64(result), err } -func mathSub_ii0(v1, v2 int64) (int64, error) { - result := v1 - v2 - if (result < v1) != (v2 > 0) { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "-") +func mathAdd_ui0(v1 uint64, v2 int64) (uint64, error) { + result := v1 + uint64(v2) + if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") } return result, nil } -func mathMul_ii(v1, v2 int64) (*evalInt64, error) { - result, err := mathMul_ii0(v1, v2) +func mathAdd_uu(v1, v2 uint64) (*evalUint64, error) { + result, err := mathAdd_uu0(v1, v2) + return newEvalUint64(result), err +} + +func mathAdd_uu0(v1, v2 uint64) (uint64, error) { + result := v1 + v2 + if result < v1 || result < v2 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") + } + return result, nil +} + +var errDecimalOutOfRange = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") + +func mathAdd_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathAdd_ff(v1, v2f.f), nil +} + +func mathAdd_ff(v1, v2 float64) *evalFloat { + return newEvalFloat(v1 + v2) +} + +func mathAdd_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { + return mathAdd_dd(v1, v2.toDecimal(0, 0)) +} + +func mathAdd_dd(v1, v2 *evalDecimal) *evalDecimal { + return newEvalDecimalWithPrec(v1.dec.Add(v2.dec), maxprec(v1.length, v2.length)) +} + +func mathAdd_dd0(v1, v2 *evalDecimal) { + v1.dec = v1.dec.Add(v2.dec) + v1.length = maxprec(v1.length, v2.length) +} + +func mathSub_ii(v1, v2 int64) (*evalInt64, error) { + result, err := mathSub_ii0(v1, v2) return newEvalInt64(result), err } -func mathMul_ii0(v1, v2 int64) (int64, error) { - result := v1 * v2 - if v1 != 0 && result/v1 != v2 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "*") +func mathSub_ii0(v1, v2 int64) (int64, error) { + result := v1 - v2 + if (result < v1) != (v2 > 0) { + return 0, dataOutOfRangeError(v1, v2, "BIGINT", "-") } return result, nil } @@ -200,19 +347,6 @@ func mathSub_iu0(v1 int64, v2 uint64) (uint64, error) { return mathSub_uu0(uint64(v1), v2) } -func mathAdd_ui(v1 uint64, v2 int64) (*evalUint64, error) { - result, err := mathAdd_ui0(v1, v2) - return newEvalUint64(result), err -} - -func mathAdd_ui0(v1 uint64, v2 int64) (uint64, error) { - result := v1 + uint64(v2) - if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") - } - return result, nil -} - func mathSub_ui(v1 uint64, v2 int64) (*evalUint64, error) { result, err := mathSub_ui0(v1, v2) return newEvalUint64(result), err @@ -229,45 +363,82 @@ func mathSub_ui0(v1 uint64, v2 int64) (uint64, error) { return mathSub_uu0(v1, uint64(v2)) } -func mathMul_ui(v1 uint64, v2 int64) (*evalUint64, error) { - result, err := mathMul_ui0(v1, v2) +func mathSub_uu(v1, v2 uint64) (*evalUint64, error) { + result, err := mathSub_uu0(v1, v2) return newEvalUint64(result), err } -func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { - if v1 == 0 || v2 == 0 { - return 0, nil +func mathSub_uu0(v1, v2 uint64) (uint64, error) { + result := v1 - v2 + if v2 > v1 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } - if v2 < 0 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") + return result, nil +} + +func mathSub_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange } - return mathMul_uu0(v1, uint64(v2)) + return mathSub_ff(v1, v2f.f), nil } -func mathAdd_uu(v1, v2 uint64) (*evalUint64, error) { - result, err := mathAdd_uu0(v1, v2) - return newEvalUint64(result), err +func mathSub_xf(v1 evalNumeric, v2 float64) (*evalFloat, error) { + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathSub_ff(v1f.f, v2), nil } -func mathAdd_uu0(v1, v2 uint64) (uint64, error) { - result := v1 + v2 - if result < v1 || result < v2 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") +func mathSub_ff(v1, v2 float64) *evalFloat { + return newEvalFloat(v1 - v2) +} + +func mathSub_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { + return mathSub_dd(v1, v2.toDecimal(0, 0)) +} + +func mathSub_xd(v1 evalNumeric, v2 *evalDecimal) *evalDecimal { + return mathSub_dd(v1.toDecimal(0, 0), v2) +} + +func mathSub_dd(v1, v2 *evalDecimal) *evalDecimal { + return newEvalDecimalWithPrec(v1.dec.Sub(v2.dec), maxprec(v1.length, v2.length)) +} + +func mathSub_dd0(v1, v2 *evalDecimal) { + v1.dec = v1.dec.Sub(v2.dec) + v1.length = maxprec(v1.length, v2.length) +} + +func mathMul_ii(v1, v2 int64) (*evalInt64, error) { + result, err := mathMul_ii0(v1, v2) + return newEvalInt64(result), err +} + +func mathMul_ii0(v1, v2 int64) (int64, error) { + result := v1 * v2 + if v1 != 0 && result/v1 != v2 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT", "*") } return result, nil } -func mathSub_uu(v1, v2 uint64) (*evalUint64, error) { - result, err := mathSub_uu0(v1, v2) +func mathMul_ui(v1 uint64, v2 int64) (*evalUint64, error) { + result, err := mathMul_ui0(v1, v2) return newEvalUint64(result), err } -func mathSub_uu0(v1, v2 uint64) (uint64, error) { - result := v1 - v2 - if v2 > v1 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") +func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { + if v1 == 0 || v2 == 0 { + return 0, nil } - return result, nil + if v2 < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") + } + return mathMul_uu0(v1, uint64(v2)) } func mathMul_uu(v1, v2 uint64) (*evalUint64, error) { @@ -286,28 +457,6 @@ func mathMul_uu0(v1, v2 uint64) (uint64, error) { return result, nil } -var errDecimalOutOfRange = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") - -func mathAdd_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { - v2f, ok := v2.toFloat() - if !ok { - return nil, errDecimalOutOfRange - } - return mathAdd_ff(v1, v2f.f), nil -} - -func mathAdd_ff(v1, v2 float64) *evalFloat { - return newEvalFloat(v1 + v2) -} - -func mathSub_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { - v2f, ok := v2.toFloat() - if !ok { - return nil, errDecimalOutOfRange - } - return mathSub_ff(v1, v2f.f), nil -} - func mathMul_fx(v1 float64, v2 evalNumeric) (eval, error) { v2f, ok := v2.toFloat() if !ok { @@ -327,36 +476,6 @@ func maxprec(a, b int32) int32 { return b } -func mathAdd_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { - return mathAdd_dd(v1, v2.toDecimal(0, 0)) -} - -func mathAdd_dd(v1, v2 *evalDecimal) *evalDecimal { - return newEvalDecimalWithPrec(v1.dec.Add(v2.dec), maxprec(v1.length, v2.length)) -} - -func mathAdd_dd0(v1, v2 *evalDecimal) { - v1.dec = v1.dec.Add(v2.dec) - v1.length = maxprec(v1.length, v2.length) -} - -func mathSub_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { - return mathSub_dd(v1, v2.toDecimal(0, 0)) -} - -func mathSub_xd(v1 evalNumeric, v2 *evalDecimal) *evalDecimal { - return mathSub_dd(v1.toDecimal(0, 0), v2) -} - -func mathSub_dd(v1, v2 *evalDecimal) *evalDecimal { - return newEvalDecimalWithPrec(v1.dec.Sub(v2.dec), maxprec(v1.length, v2.length)) -} - -func mathSub_dd0(v1, v2 *evalDecimal) { - v1.dec = v1.dec.Sub(v2.dec) - v1.length = maxprec(v1.length, v2.length) -} - func mathMul_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { return mathMul_dd(v1, v2.toDecimal(0, 0)) } @@ -413,16 +532,174 @@ func mathDiv_ff0(v1, v2 float64) (float64, error) { return result, nil } -func mathSub_xf(v1 evalNumeric, v2 float64) (*evalFloat, error) { - v1f, ok := v1.toFloat() +func mathIntDiv_ii(v1, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result := v1.i / v2.i + return newEvalInt64(result), nil +} + +func mathIntDiv_iu(v1 *evalInt64, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + result, err := mathIntDiv_iu0(v1.i, v2.u) + return newEvalUint64(result), err +} + +func mathIntDiv_iu0(v1 int64, v2 uint64) (uint64, error) { + if v1 < 0 { + if v2 >= math.MaxInt64 { + // We know here that v2 is always so large the result + // must be 0. + return 0, nil + } + result := v1 / int64(v2) + if result < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + return uint64(result), nil + + } + return uint64(v1) / v2, nil +} + +func mathIntDiv_ui(v1 *evalUint64, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result, err := mathIntDiv_ui0(v1.u, v2.i) + return newEvalUint64(result), err +} + +func mathIntDiv_ui0(v1 uint64, v2 int64) (uint64, error) { + if v2 < 0 { + if v1 >= math.MaxInt64 { + // We know that v1 is always large here and with v2, the result + // must be at least -1 so we can't store this in the available range. + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + // Safe to cast since we know it fits in int64 when we get here. + result := int64(v1) / v2 + if result < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + return uint64(result), nil + } + return v1 / uint64(v2), nil +} + +func mathIntDiv_uu(v1, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalUint64(v1.u / v2.u), nil +} + +func mathIntDiv_di(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + result, err := mathIntDiv_di0(v1, v2) + return newEvalInt64(result), err +} + +func mathIntDiv_di0(v1, v2 *evalDecimal) (int64, error) { + div, _ := v1.dec.QuoRem(v2.dec, 0) + result, ok := div.Int64() if !ok { - return nil, errDecimalOutOfRange + return 0, dataOutOfRangeErrorDecimal(v1.dec, v2.dec, "BIGINT", "DIV") } - return mathSub_ff(v1f.f, v2), nil + return result, nil } -func mathSub_ff(v1, v2 float64) *evalFloat { - return newEvalFloat(v1 - v2) +func mathIntDiv_du(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + result, err := mathIntDiv_du0(v1, v2) + return newEvalUint64(result), err +} + +func mathIntDiv_du0(v1, v2 *evalDecimal) (uint64, error) { + div, _ := v1.dec.QuoRem(v2.dec, 0) + result, ok := div.Uint64() + if !ok { + return 0, dataOutOfRangeErrorDecimal(v1.dec, v2.dec, "BIGINT UNSIGNED", "DIV") + } + return result, nil +} + +func mathMod_ii(v1, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + return newEvalInt64(v1.i % v2.i), nil +} + +func mathMod_iu(v1 *evalInt64, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalInt64(mathMod_iu0(v1.i, v2.u)), nil +} + +func mathMod_iu0(v1 int64, v2 uint64) int64 { + if v1 == math.MinInt64 && v2 == math.MaxInt64+1 { + return 0 + } + if v2 > math.MaxInt64 { + return v1 + } + return v1 % int64(v2) +} + +func mathMod_ui(v1 *evalUint64, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result, err := mathMod_ui0(v1.u, v2.i) + return newEvalUint64(result), err +} + +func mathMod_ui0(v1 uint64, v2 int64) (uint64, error) { + if v2 < 0 { + return v1 % uint64(-v2), nil + } + return v1 % uint64(v2), nil +} + +func mathMod_uu(v1, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalUint64(v1.u % v2.u), nil +} + +func mathMod_ff(v1, v2 *evalFloat) (eval, error) { + if v2.f == 0.0 { + return nil, nil + } + return newEvalFloat(math.Mod(v1.f, v2.f)), nil +} + +func mathMod_dd(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + + dec, prec := mathMod_dd0(v1, v2) + return newEvalDecimalWithPrec(dec, prec), nil +} + +func mathMod_dd0(v1, v2 *evalDecimal) (decimal.Decimal, int32) { + length := v1.length + if v2.length > length { + length = v2.length + } + _, rem := v1.dec.QuoRem(v2.dec, 0) + return rem, length } func parseStringToFloat(str string) float64 { diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 24ff089c047..ff0765362fa 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -148,6 +148,12 @@ func (c *compiler) compileExpr(expr Expr) (ctype, error) { case *InExpr: return c.compileIn(expr) + case *NotExpr: + return c.compileNot(expr) + + case *LogicalExpr: + return c.compileLogical(expr) + case callable: return c.compileFn(expr) @@ -310,7 +316,7 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { } func (c *compiler) compileNullCheck1(ct ctype) *jump { - if ct.Flag&flagNullable != 0 { + if ct.nullable() { j := c.asm.jumpFrom() c.asm.NullCheck1(j) return j @@ -319,7 +325,7 @@ func (c *compiler) compileNullCheck1(ct ctype) *jump { } func (c *compiler) compileNullCheck2(lt, rt ctype) *jump { - if lt.Flag&flagNullable != 0 || rt.Flag&flagNullable != 0 { + if lt.nullable() || rt.nullable() { j := c.asm.jumpFrom() c.asm.NullCheck2(j) return j diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go index a81afacc53a..972e0b0b539 100644 --- a/go/vt/vtgate/evalengine/compiler_arithmetic.go +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -16,7 +16,11 @@ limitations under the License. package evalengine -import "vitess.io/vitess/go/sqltypes" +import ( + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { arg, err := c.compileExpr(expr.Inner) @@ -64,8 +68,12 @@ func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) { return c.compileArithmeticMul(expr.Left, expr.Right) case *opArithDiv: return c.compileArithmeticDiv(expr.Left, expr.Right) + case *opArithIntDiv: + return c.compileArithmeticIntDiv(expr.Left, expr.Right) + case *opArithMod: + return c.compileArithmeticMod(expr.Left, expr.Right) default: - panic("unexpected arithmetic operator") + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") } } @@ -92,6 +100,7 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { @@ -99,7 +108,7 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { } swap := false - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) lt, rt, swap = c.compileNumericPriority(lt, rt) @@ -136,7 +145,8 @@ func (c *compiler) compileArithmeticAdd(left, right Expr) (ctype, error) { sumtype = sqltypes.Float64 } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sumtype, Col: collationNumeric}, nil } @@ -145,13 +155,14 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) @@ -213,7 +224,8 @@ func (c *compiler) compileArithmeticSub(left, right Expr) (ctype, error) { panic("did not compile?") } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: subtype, Col: collationNumeric}, nil } @@ -222,6 +234,7 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { @@ -229,7 +242,7 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { } swap := false - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) lt, rt, swap = c.compileNumericPriority(lt, rt) @@ -266,7 +279,8 @@ func (c *compiler) compileArithmeticMul(left, right Expr) (ctype, error) { multype = sqltypes.Decimal } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: multype, Col: collationNumeric}, nil } @@ -275,27 +289,183 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { if err != nil { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } + skip2 := c.compileNullCheck2(lt, rt) - skip := c.compileNullCheck2(lt, rt) lt = c.compileToNumeric(lt, 2) rt = c.compileToNumeric(rt, 1) + ct := ctype{Col: collationNumeric, Flag: flagNullable} if lt.Type == sqltypes.Float64 || rt.Type == sqltypes.Float64 { + ct.Type = sqltypes.Float64 c.compileToFloat(lt, 2) c.compileToFloat(rt, 1) c.asm.Div_ff() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil } else { + ct.Type = sqltypes.Decimal c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() - c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil } + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) + return ct, nil +} + +func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + skip1 := c.compileNullCheck1(lt) + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip2 := c.compileNullCheck2(lt, rt) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.asm.IntDiv_ii() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.IntDiv_iu() + case sqltypes.Float64: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_di() + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.IntDiv_ui() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.IntDiv_uu() + case sqltypes.Float64: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_du() + } + case sqltypes.Float64: + switch rt.Type { + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_di() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + default: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + } + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Decimal: + c.asm.IntDiv_di() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + default: + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + } + } + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) + return ct, nil +} + +func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + skip1 := c.compileNullCheck1(lt) + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip2 := c.compileNullCheck2(lt, rt) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + switch lt.Type { + case sqltypes.Int64: + ct.Type = sqltypes.Int64 + switch rt.Type { + case sqltypes.Int64: + c.asm.Mod_ii() + case sqltypes.Uint64: + c.asm.Mod_iu() + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + c.asm.Convert_xd(2, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + switch rt.Type { + case sqltypes.Int64: + c.asm.Mod_ui() + case sqltypes.Uint64: + c.asm.Mod_uu() + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + c.asm.Convert_xd(2, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + switch rt.Type { + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + default: + c.asm.Convert_xd(1, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(1) + c.asm.Mod_ff() + } + + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) + return ct, nil } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 4e484e9996d..0bf28cac748 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -363,7 +363,7 @@ func (asm *assembler) BitwiseNot_u() { func (asm *assembler) Cmp_eq() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp == 0) vm.sp++ return 1 }, "CMPFLAG EQ") @@ -375,7 +375,7 @@ func (asm *assembler) Cmp_eq_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp == 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp == 0) } vm.sp++ return 1 @@ -385,7 +385,7 @@ func (asm *assembler) Cmp_eq_n() { func (asm *assembler) Cmp_ge() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp >= 0) vm.sp++ return 1 }, "CMPFLAG GE") @@ -397,7 +397,7 @@ func (asm *assembler) Cmp_ge_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp >= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp >= 0) } vm.sp++ return 1 @@ -407,7 +407,7 @@ func (asm *assembler) Cmp_ge_n() { func (asm *assembler) Cmp_gt() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp > 0) vm.sp++ return 1 }, "CMPFLAG GT") @@ -419,7 +419,7 @@ func (asm *assembler) Cmp_gt_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp > 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp > 0) } vm.sp++ return 1 @@ -429,7 +429,7 @@ func (asm *assembler) Cmp_gt_n() { func (asm *assembler) Cmp_le() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp <= 0) vm.sp++ return 1 }, "CMPFLAG LE") @@ -441,7 +441,7 @@ func (asm *assembler) Cmp_le_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp <= 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp <= 0) } vm.sp++ return 1 @@ -451,7 +451,7 @@ func (asm *assembler) Cmp_le_n() { func (asm *assembler) Cmp_lt() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp < 0) vm.sp++ return 1 }, "CMPFLAG LT") @@ -463,7 +463,7 @@ func (asm *assembler) Cmp_lt_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp < 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp < 0) } vm.sp++ return 1 @@ -472,7 +472,7 @@ func (asm *assembler) Cmp_lt_n() { func (asm *assembler) Cmp_ne() { asm.adjustStack(1) asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp != 0) vm.sp++ return 1 }, "CMPFLAG NE") @@ -484,7 +484,7 @@ func (asm *assembler) Cmp_ne_n() { if vm.flags.null { vm.stack[vm.sp] = nil } else { - vm.stack[vm.sp] = newEvalBool(vm.flags.cmp != 0) + vm.stack[vm.sp] = vm.arena.newEvalBool(vm.flags.cmp != 0) } vm.sp++ return 1 @@ -701,7 +701,7 @@ func (asm *assembler) CmpTupleNullsafe() { var equals bool equals, vm.err = evalCompareTuplesNullSafe(l.t, r.t) - vm.stack[vm.sp-2] = newEvalBool(equals) + vm.stack[vm.sp-2] = vm.arena.newEvalBool(equals) vm.sp -= 1 return 1 }, "CMP NULLSAFE TUPLE(SP-2), TUPLE(SP-1)") @@ -719,11 +719,24 @@ func (asm *assembler) Collate(col collations.ID) { func (asm *assembler) Convert_bB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && parseStringToFloat(arg.(*evalBytes).string()) != 0.0) return 1 }, "CONV VARBINARY(SP-%d), BOOL", offset) } +func (asm *assembler) Convert_jB(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalJSON) + switch arg.Type() { + case json.TypeNumber: + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(parseStringToFloat(arg.String()) != 0.0) + default: + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(true) + } + return 1 + }, "CONV JSON(SP-%d), BOOL", offset) +} + func (asm *assembler) Convert_bj(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset].(*evalBytes) @@ -732,6 +745,14 @@ func (asm *assembler) Convert_bj(offset int) { }, "CONV VARBINARY(SP-%d), JSON", offset) } +func (asm *assembler) ConvertArg_cj(offset int) { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-offset].(*evalBytes) + vm.stack[vm.sp-offset], vm.err = evalConvertArg_cj(arg) + return 1 + }, "CONVA VARCHAR(SP-%d), JSON", offset) +} + func (asm *assembler) Convert_cj(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset].(*evalBytes) @@ -743,7 +764,7 @@ func (asm *assembler) Convert_cj(offset int) { func (asm *assembler) Convert_dB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) return 1 }, "CONV DECIMAL(SP-%d), BOOL", offset) } @@ -763,7 +784,7 @@ func (asm *assembler) Convert_dbit(offset int) { func (asm *assembler) Convert_fB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) return 1 }, "CONV FLOAT64(SP-%d), BOOL", offset) } @@ -790,7 +811,7 @@ func (asm *assembler) Convert_hex(offset int) { func (asm *assembler) Convert_iB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalInt64).i != 0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0) return 1 }, "CONV INT64(SP-%d), BOOL", offset) } @@ -848,7 +869,7 @@ func (asm *assembler) Convert_Nj(offset int) { func (asm *assembler) Convert_uB(offset int) { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-offset] - vm.stack[vm.sp-offset] = newEvalBool(arg != nil && arg.(*evalUint64).u != 0) + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0) return 1 }, "CONV UINT64(SP-%d), BOOL", offset) } @@ -984,6 +1005,203 @@ func (asm *assembler) Div_ff() { }, "DIV FLOAT64(SP-2), FLOAT64(SP-1)") } +func (asm *assembler) IntDiv_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = l.i / r.i + } + vm.sp-- + return 1 + }, "INTDIV INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) IntDiv_iu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + r.u, vm.err = mathIntDiv_iu0(l.i, r.u) + vm.stack[vm.sp-2] = r + } + vm.sp-- + return 1 + }, "INTDIV INT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) IntDiv_ui() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u, vm.err = mathIntDiv_ui0(l.u, r.i) + } + vm.sp-- + return 1 + }, "INTDIV UINT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) IntDiv_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u = l.u / r.u + } + vm.sp-- + return 1 + }, "INTDIV UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) IntDiv_di() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + var res int64 + res, vm.err = mathIntDiv_di0(l, r) + vm.stack[vm.sp-2] = vm.arena.newEvalInt64(res) + } + vm.sp-- + return 1 + }, "INTDIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (asm *assembler) IntDiv_du() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + var res uint64 + res, vm.err = mathIntDiv_du0(l, r) + vm.stack[vm.sp-2] = vm.arena.newEvalUint64(res) + } + vm.sp-- + return 1 + }, "UINTDIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (asm *assembler) Mod_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = l.i % r.i + } + vm.sp-- + return 1 + }, "MOD INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mod_iu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = mathMod_iu0(l.i, r.u) + } + vm.sp-- + return 1 + }, "MOD INT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) Mod_ui() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u, vm.err = mathMod_ui0(l.u, r.i) + } + vm.sp-- + return 1 + }, "MOD UINT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mod_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u = l.u % r.u + } + vm.sp-- + return 1 + }, "MOD UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) Mod_ff() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + if r.f == 0.0 { + vm.stack[vm.sp-2] = nil + } else { + l.f = math.Mod(l.f, r.f) + } + vm.sp-- + return 1 + }, "MOD FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (asm *assembler) Mod_dd() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + l.dec, l.length = mathMod_dd0(l, r) + } + vm.sp-- + return 1 + }, "MOD DECIMAL(SP-2), DECIMAL(SP-1)") +} + func (asm *assembler) Fn_ASCII() { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalBytes) @@ -1182,7 +1400,7 @@ func (asm *assembler) Fn_COLLATION(col collations.TypedCollation) { }, "FN COLLATION (SP-1)") } -func (asm *assembler) Fn_FROM_BASE64() { +func (asm *assembler) Fn_FROM_BASE64(t sqltypes.Type) { asm.emit(func(vm *VirtualMachine) int { str := vm.stack[vm.sp-1].(*evalBytes) @@ -1193,7 +1411,7 @@ func (asm *assembler) Fn_FROM_BASE64() { vm.stack[vm.sp-1] = nil return 1 } - str.tt = int16(sqltypes.VarBinary) + str.tt = int16(t) str.bytes = decoded[:n] return 1 }, "FN FROM_BASE64 VARCHAR(SP-1)") @@ -1242,7 +1460,7 @@ func (asm *assembler) Fn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path) break } } - vm.stack[vm.sp-1] = newEvalBool(matched) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(matched) return 1 }, "FN JSON_CONTAINS_PATH, SP-1, 'one', [static]") case jsonMatchAll: @@ -1256,7 +1474,7 @@ func (asm *assembler) Fn_JSON_CONTAINS_PATH(match jsonMatch, paths []*json.Path) break } } - vm.stack[vm.sp-1] = newEvalBool(matched) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(matched) return 1 }, "FN JSON_CONTAINS_PATH, SP-1, 'all', [static]") } @@ -1615,7 +1833,7 @@ func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { vm.hash.Reset() lhs.(hashable).Hash(&vm.hash) _, in := table[vm.hash.Sum128()] - vm.stack[vm.sp-1] = newEvalBool(!in) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(!in) } return 1 }, "NOT IN (SP-1), [static table]") @@ -1626,7 +1844,7 @@ func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { vm.hash.Reset() lhs.(hashable).Hash(&vm.hash) _, in := table[vm.hash.Sum128()] - vm.stack[vm.sp-1] = newEvalBool(in) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(in) } return 1 }, "IN (SP-1), [static table]") @@ -1665,11 +1883,149 @@ func (asm *assembler) In_slow(not bool) { func (asm *assembler) Is(check func(eval) bool) { asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-1] = newEvalBool(check(vm.stack[vm.sp-1])) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(check(vm.stack[vm.sp-1])) return 1 }, "IS (SP-1), [static]") } +func (asm *assembler) Not_i() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalInt64) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.i == 0) + return 1 + }, "NOT INT64(SP-1)") +} + +func (asm *assembler) Not_u() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalUint64) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.u == 0) + return 1 + }, "NOT UINT64(SP-1)") +} + +func (asm *assembler) Not_f() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalFloat) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.f == 0.0) + return 1 + }, "NOT FLOAT64(SP-1)") +} + +func (asm *assembler) Not_d() { + asm.emit(func(vm *VirtualMachine) int { + arg := vm.stack[vm.sp-1].(*evalDecimal) + vm.stack[vm.sp-1] = vm.arena.newEvalBool(arg.dec.IsZero()) + return 1 + }, "NOT DECIMAL(SP-1)") +} + +func (asm *assembler) LogicalLeft(opname string) *jump { + switch opname { + case "AND": + j := asm.jumpFrom() + asm.emit(func(vm *VirtualMachine) int { + left, ok := vm.stack[vm.sp-1].(*evalInt64) + if ok && left.i == 0 { + return j.offset() + } + return 1 + }, "AND CHECK INT64(SP-1)") + return j + case "OR": + j := asm.jumpFrom() + asm.emit(func(vm *VirtualMachine) int { + left, ok := vm.stack[vm.sp-1].(*evalInt64) + if ok && left.i != 0 { + left.i = 1 + return j.offset() + } + return 1 + }, "OR CHECK INT64(SP-1)") + return j + case "XOR": + j := asm.jumpFrom() + asm.emit(func(vm *VirtualMachine) int { + if vm.stack[vm.sp-1] == nil { + return j.offset() + } + return 1 + }, "XOR CHECK INT64(SP-1)") + return j + } + return nil +} + +func (asm *assembler) LogicalRight(opname string) { + asm.adjustStack(-1) + switch opname { + case "AND": + asm.emit(func(vm *VirtualMachine) int { + left, lok := vm.stack[vm.sp-2].(*evalInt64) + right, rok := vm.stack[vm.sp-1].(*evalInt64) + + isLeft := lok && left.i != 0 + isRight := rok && right.i != 0 + + if isLeft && isRight { + left.i = 1 + } else if rok && !isRight { + vm.stack[vm.sp-2] = vm.arena.newEvalBool(false) + } else { + vm.stack[vm.sp-2] = nil + } + vm.sp-- + return 1 + }, "AND INT64(SP-2), INT64(SP-1)") + case "OR": + asm.emit(func(vm *VirtualMachine) int { + left, lok := vm.stack[vm.sp-2].(*evalInt64) + right, rok := vm.stack[vm.sp-1].(*evalInt64) + + isLeft := lok && left.i != 0 + isRight := rok && right.i != 0 + + switch { + case !lok: + if isRight { + vm.stack[vm.sp-2] = vm.arena.newEvalBool(true) + } + case !rok: + vm.stack[vm.sp-2] = nil + default: + if isLeft || isRight { + left.i = 1 + } else { + left.i = 0 + } + } + vm.sp-- + return 1 + }, "OR INT64(SP-2), INT64(SP-1)") + case "XOR": + asm.emit(func(vm *VirtualMachine) int { + left := vm.stack[vm.sp-2].(*evalInt64) + right, rok := vm.stack[vm.sp-1].(*evalInt64) + + isLeft := left.i != 0 + isRight := rok && right.i != 0 + + switch { + case !rok: + vm.stack[vm.sp-2] = nil + default: + if isLeft != isRight { + left.i = 1 + } else { + left.i = 0 + } + } + vm.sp-- + return 1 + }, "XOR INT64(SP-2), INT64(SP-1)") + } +} + func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { asm.adjustStack(-1) @@ -1689,11 +2045,7 @@ func (asm *assembler) Like_coerce(expr *LikeExpr, coercion *compiledCoercion) { } match := expr.matchWildcard(bl, br, coercion.col.ID()) - if match { - vm.stack[vm.sp-1] = evalBoolTrue - } else { - vm.stack[vm.sp-1] = evalBoolFalse - } + vm.stack[vm.sp-1] = vm.arena.newEvalBool(match) return 1 }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COERCE AND COLLATE '%s'", coercion.col.Name()) } @@ -1707,11 +2059,7 @@ func (asm *assembler) Like_collate(expr *LikeExpr, collation collations.Collatio vm.sp-- match := expr.matchWildcard(l.bytes, r.bytes, collation.ID()) - if match { - vm.stack[vm.sp-1] = evalBoolTrue - } else { - vm.stack[vm.sp-1] = evalBoolFalse - } + vm.stack[vm.sp-1] = vm.arena.newEvalBool(match) return 1 }, "LIKE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name()) } @@ -2053,12 +2401,12 @@ func (asm *assembler) PushNull() { func (asm *assembler) SetBool(offset int, b bool) { if b { asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalBoolTrue + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(true) return 1 }, "SET (SP-%d), BOOL(true)", offset) } else { asm.emit(func(vm *VirtualMachine) int { - vm.stack[vm.sp-offset] = evalBoolFalse + vm.stack[vm.sp-offset] = vm.arena.newEvalBool(false) return 1 }, "SET (SP-%d), BOOL(false)", offset) } diff --git a/go/vt/vtgate/evalengine/compiler_bit.go b/go/vt/vtgate/evalengine/compiler_bit.go index 2bf3f0d98d3..bc37dd7bc8d 100644 --- a/go/vt/vtgate/evalengine/compiler_bit.go +++ b/go/vt/vtgate/evalengine/compiler_bit.go @@ -41,17 +41,20 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, asm_ins_bb, asm_ins_u return ctype{}, err } + skip1 := c.compileNullCheck1(lt) + rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) if lt.Type == sqltypes.VarBinary && rt.Type == sqltypes.VarBinary { if !lt.isHexOrBitLiteral() || !rt.isHexOrBitLiteral() { asm_ins_bb() - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } } @@ -60,7 +63,8 @@ func (c *compiler) compileBitwiseOp(left Expr, right Expr, asm_ins_bb, asm_ins_u rt = c.compileToBitwiseUint64(rt, 1) asm_ins_uu() - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil } @@ -70,12 +74,14 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err return ctype{}, err } + skip1 := c.compileNullCheck1(lt) + rt, err := c.compileExpr(right) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck2(lt, rt) if lt.Type == sqltypes.VarBinary && !lt.isHexOrBitLiteral() { _ = c.compileToUint64(rt, 1) @@ -84,7 +90,8 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err } else { c.asm.BitShiftRight_bu() } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } @@ -97,7 +104,8 @@ func (c *compiler) compileBitwiseShift(left Expr, right Expr, i int) (ctype, err c.asm.BitShiftRight_uu() } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1) + c.asm.jumpDestination(skip2) return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil } diff --git a/go/vt/vtgate/evalengine/compiler_compare.go b/go/vt/vtgate/evalengine/compiler_compare.go index 15d28370ef5..cb57c8ed932 100644 --- a/go/vt/vtgate/evalengine/compiler_compare.go +++ b/go/vt/vtgate/evalengine/compiler_compare.go @@ -271,7 +271,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { skip := c.compileNullCheck2(lt, rt) - if !sqltypes.IsText(lt.Type) && !sqltypes.IsBinary(lt.Type) { + if !lt.isTextual() { c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) lt.Col = collations.TypedCollation{ Collation: c.defaultCollation, @@ -280,7 +280,7 @@ func (c *compiler) compileLike(expr *LikeExpr) (ctype, error) { } } - if !sqltypes.IsText(rt.Type) && !sqltypes.IsBinary(rt.Type) { + if !rt.isTextual() { c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) rt.Col = collations.TypedCollation{ Collation: c.defaultCollation, @@ -376,3 +376,102 @@ func (c *compiler) compileIn(expr *InExpr) (ctype, error) { } return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil } + +func (c *compiler) compileNot(expr *NotExpr) (ctype, error) { + arg, err := c.compileExpr(expr.Inner) + if err != nil { + return ctype{}, nil + } + + skip := c.compileNullCheck1(arg) + + switch arg.Type { + case sqltypes.Null: + // No-op. + case sqltypes.Int64: + c.asm.Not_i() + case sqltypes.Uint64: + c.asm.Not_u() + case sqltypes.Float64: + c.asm.Not_f() + case sqltypes.Decimal: + c.asm.Not_d() + case sqltypes.VarChar, sqltypes.VarBinary: + if arg.isHexOrBitLiteral() { + c.asm.Convert_xu(1) + c.asm.Not_u() + } else { + c.asm.Convert_bB(1) + c.asm.Not_i() + } + case sqltypes.TypeJSON: + c.asm.Convert_jB(1) + c.asm.Not_i() + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "unsupported Not check: %s", arg.Type) + } + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil +} + +func (c *compiler) compileLogical(expr *LogicalExpr) (ctype, error) { + lt, err := c.compileExpr(expr.Left) + if err != nil { + return ctype{}, err + } + + switch lt.Type { + case sqltypes.Null, sqltypes.Int64: + // No-op. + case sqltypes.Uint64: + c.asm.Convert_uB(1) + case sqltypes.Float64: + c.asm.Convert_fB(1) + case sqltypes.Decimal: + c.asm.Convert_dB(1) + case sqltypes.VarChar, sqltypes.VarBinary: + if lt.isHexOrBitLiteral() { + c.asm.Convert_xu(1) + c.asm.Convert_uB(1) + } else { + c.asm.Convert_bB(1) + } + case sqltypes.TypeJSON: + c.asm.Convert_jB(1) + default: + c.asm.Convert_bB(1) + } + + jump := c.asm.LogicalLeft(expr.opname) + + rt, err := c.compileExpr(expr.Right) + if err != nil { + return ctype{}, err + } + + switch rt.Type { + case sqltypes.Null, sqltypes.Int64: + // No-op. + case sqltypes.Uint64: + c.asm.Convert_uB(1) + case sqltypes.Float64: + c.asm.Convert_fB(1) + case sqltypes.Decimal: + c.asm.Convert_dB(1) + case sqltypes.VarChar, sqltypes.VarBinary: + if rt.isHexOrBitLiteral() { + c.asm.Convert_xu(1) + c.asm.Convert_uB(1) + } else { + c.asm.Convert_bB(1) + } + case sqltypes.TypeJSON: + c.asm.Convert_jB(1) + default: + c.asm.Convert_bB(1) + } + + c.asm.LogicalRight(expr.opname) + c.asm.jumpDestination(jump) + return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_fn.go b/go/vt/vtgate/evalengine/compiler_fn.go index b178d713031..7ada401256a 100644 --- a/go/vt/vtgate/evalengine/compiler_fn.go +++ b/go/vt/vtgate/evalengine/compiler_fn.go @@ -211,7 +211,7 @@ func (c *compiler) compileFn_REPEAT(expr *builtinRepeat) (ctype, error) { skip := c.compileNullCheck2(str, repeat) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(2, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -236,7 +236,7 @@ func (c *compiler) compileFn_TO_BASE64(call *builtinToBase64) (ctype, error) { } switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) } @@ -261,16 +261,21 @@ func (c *compiler) compileFn_FROM_BASE64(call *builtinFromBase64) (ctype, error) skip := c.compileNullCheck1(str) + t := sqltypes.VarBinary + if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { + t = sqltypes.Blob + } + switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: - c.asm.Convert_xc(1, sqltypes.VarBinary, c.defaultCollation, 0, false) + c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) } - c.asm.Fn_FROM_BASE64() + c.asm.Fn_FROM_BASE64(t) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + return ctype{Type: t, Col: collationBinary}, nil } func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { @@ -282,7 +287,7 @@ func (c *compiler) compileFn_CCASE(call *builtinChangeCase) (ctype, error) { skip := c.compileNullCheck1(str) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -302,7 +307,7 @@ func (c *compiler) compileFn_xLENGTH(call callable, asm_ins func()) (ctype, erro skip := c.compileNullCheck1(str) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -322,7 +327,7 @@ func (c *compiler) compileFn_ASCII(call *builtinASCII) (ctype, error) { skip := c.compileNullCheck1(str) switch { - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): default: c.asm.Convert_xc(1, sqltypes.VarChar, c.defaultCollation, 0, false) } @@ -355,7 +360,7 @@ func (c *compiler) compileFn_HEX(call *builtinHex) (ctype, error) { switch { case sqltypes.IsNumber(str.Type), sqltypes.IsDecimal(str.Type): c.asm.Fn_HEX_d(col) - case sqltypes.IsText(str.Type) || sqltypes.IsBinary(str.Type): + case str.isTextual(): c.asm.Fn_HEX_c(t, col) default: c.asm.Convert_xc(1, t, c.defaultCollation, 0, false) diff --git a/go/vt/vtgate/evalengine/compiler_json.go b/go/vt/vtgate/evalengine/compiler_json.go index f6232394849..e786709ee35 100644 --- a/go/vt/vtgate/evalengine/compiler_json.go +++ b/go/vt/vtgate/evalengine/compiler_json.go @@ -66,6 +66,26 @@ func (c *compiler) compileToJSON(doct ctype, offset int) (ctype, error) { return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil } +func (c *compiler) compileArgToJSON(doct ctype, offset int) (ctype, error) { + switch doct.Type { + case sqltypes.TypeJSON: + return doct, nil + case sqltypes.Float64: + c.asm.Convert_fj(offset) + case sqltypes.Int64, sqltypes.Uint64, sqltypes.Decimal: + c.asm.Convert_nj(offset, doct.Flag&flagIsBoolean != 0) + case sqltypes.VarChar: + c.asm.ConvertArg_cj(offset) + case sqltypes.VarBinary: + c.asm.Convert_bj(offset) + case sqltypes.Null: + c.asm.Convert_Nj(offset) + default: + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", doct.Type) + } + return ctype{Type: sqltypes.TypeJSON, Col: collationJSON}, nil +} + func (c *compiler) compileFn_JSON_ARRAY(call *builtinJSONArray) (ctype, error) { for _, arg := range call.Arguments { tt, err := c.compileExpr(arg) @@ -73,7 +93,7 @@ func (c *compiler) compileFn_JSON_ARRAY(call *builtinJSONArray) (ctype, error) { return ctype{}, err } - _, err = c.compileToJSON(tt, 1) + _, err = c.compileArgToJSON(tt, 1) if err != nil { return ctype{}, err } @@ -95,7 +115,7 @@ func (c *compiler) compileFn_JSON_OBJECT(call *builtinJSONObject) (ctype, error) if err != nil { return ctype{}, err } - _, err = c.compileToJSON(val, 1) + _, err = c.compileArgToJSON(val, 1) if err != nil { return ctype{}, err } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index ac03d186607..cb76489c8c0 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -186,7 +186,7 @@ func (d *debugCompiler) Stack(old, new int) { d.t.Logf("\tsp = %d -> %d", old, new) } -func TestCompiler(t *testing.T) { +func TestCompilerSingle(t *testing.T) { var testCases = []struct { expression string values []sqltypes.Value @@ -255,6 +255,22 @@ func TestCompiler(t *testing.T) { expression: `JSON_ARRAY(true, 1.0)`, result: `JSON("[true, 1.0]")`, }, + { + expression: `cast(true as json) + 0`, + result: `FLOAT64(1)`, + }, + { + expression: `CAST(CAST(0 AS JSON) AS CHAR(16))`, + result: `VARCHAR("0")`, + }, + { + expression: `1 OR cast('invalid' as json)`, + result: `INT64(1)`, + }, + { + expression: `NULL AND 1`, + result: `NULL`, + }, } for _, tc := range testCases { @@ -269,6 +285,16 @@ func TestCompiler(t *testing.T) { t.Fatal(err) } + env := evalengine.EmptyExpressionEnv() + env.Row = tc.values + expected, err := env.Evaluate(converted) + if err != nil { + t.Fatal(err) + } + if expected.String() != tc.result { + t.Fatalf("bad evaluation from eval engine: got %s, want %s", expected.String(), tc.result) + } + compiled, err := evalengine.Compile(converted, makeFields(tc.values), evalengine.WithAssemblerLog(&debugCompiler{t})) if err != nil { t.Fatal(err) @@ -283,7 +309,7 @@ func TestCompiler(t *testing.T) { } if res.String() != tc.result { - t.Fatalf("bad evaluation: got %s, want %s (iteration %d)", res, tc.result, i) + t.Fatalf("bad evaluation from compiler: got %s, want %s (iteration %d)", res, tc.result, i) } } }) diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 0c740cb3efc..d4d4e9b3539 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -130,7 +130,22 @@ func evalIsTruthy(e eval) boolean { case *evalDecimal: return makeboolean(!e.dec.IsZero()) case *evalBytes: + if e.isHexLiteral { + hex, ok := e.toNumericHex() + if !ok { + // overflow + return makeboolean(true) + } + return makeboolean(hex.u != 0) + } return makeboolean(parseStringToFloat(e.string()) != 0.0) + case *evalJSON: + switch e.Type() { + case json.TypeNumber: + return makeboolean(parseStringToFloat(e.Raw()) != 0.0) + default: + return makeboolean(true) + } default: panic("unhandled case: evalIsTruthy") } diff --git a/go/vt/vtgate/evalengine/eval_json.go b/go/vt/vtgate/evalengine/eval_json.go index 8275f241c70..56222d18946 100644 --- a/go/vt/vtgate/evalengine/eval_json.go +++ b/go/vt/vtgate/evalengine/eval_json.go @@ -92,6 +92,15 @@ func evalConvert_nj(e evalNumeric) *evalJSON { } func evalConvert_cj(e *evalBytes) (*evalJSON, error) { + jsonText, err := charset.Convert(nil, charset.Charset_utf8mb4{}, e.bytes, e.col.Collation.Get().Charset()) + if err != nil { + return nil, err + } + var p json.Parser + return p.ParseBytes(jsonText) +} + +func evalConvertArg_cj(e *evalBytes) (*evalJSON, error) { jsonText, err := charset.Convert(nil, charset.Charset_utf8mb4{}, e.bytes, e.col.Collation.Get().Charset()) if err != nil { return nil, err @@ -118,3 +127,23 @@ func evalToJSON(e eval) (*evalJSON, error) { return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", e.SQLType()) } } + +func argToJSON(e eval) (*evalJSON, error) { + switch e := e.(type) { + case nil: + return json.ValueNull, nil + case *evalJSON: + return e, nil + case *evalFloat: + return evalConvert_fj(e), nil + case evalNumeric: + return evalConvert_nj(e), nil + case *evalBytes: + if sqltypes.IsBinary(e.SQLType()) { + return evalConvert_bj(e), nil + } + return evalConvertArg_cj(e) + default: + return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s AS JSON", e.SQLType()) + } +} diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index 4d4dc3ef133..b6628ab4dfe 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -111,9 +111,9 @@ func evalToNumeric(e eval) evalNumeric { case *evalJSON: switch e.Type() { case json.TypeTrue: - return newEvalBool(true) + return &evalFloat{f: 1.0} case json.TypeFalse: - return newEvalBool(false) + return &evalFloat{f: 0.0} case json.TypeNumber, json.TypeString: return &evalFloat{f: parseStringToFloat(e.Raw())} default: diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index a7c8e12ed4d..b592c681b41 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -35,10 +35,12 @@ type ( String() string } - opArithAdd struct{} - opArithSub struct{} - opArithMul struct{} - opArithDiv struct{} + opArithAdd struct{} + opArithSub struct{} + opArithMul struct{} + opArithDiv struct{} + opArithIntDiv struct{} + opArithMod struct{} ) var _ Expr = (*ArithmeticExpr)(nil) @@ -47,10 +49,17 @@ var _ opArith = (*opArithAdd)(nil) var _ opArith = (*opArithSub)(nil) var _ opArith = (*opArithMul)(nil) var _ opArith = (*opArithDiv)(nil) +var _ opArith = (*opArithIntDiv)(nil) +var _ opArith = (*opArithMod)(nil) func (b *ArithmeticExpr) eval(env *ExpressionEnv) (eval, error) { - left, right, err := b.arguments(env) - if left == nil || right == nil || err != nil { + left, err := b.Left.eval(env) + if left == nil || err != nil { + return nil, err + } + + right, err := b.Right.eval(env) + if right == nil || err != nil { return nil, err } return b.Op.eval(left, right) @@ -78,9 +87,22 @@ func (b *ArithmeticExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { switch b.Op.(type) { case *opArithDiv: if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 { - return sqltypes.Float64, flags + return sqltypes.Float64, flags | flagNullable + } + return sqltypes.Decimal, flags | flagNullable + case *opArithIntDiv: + if t1 == sqltypes.Uint64 || t2 == sqltypes.Uint64 { + return sqltypes.Uint64, flags | flagNullable + } + return sqltypes.Int64, flags | flagNullable + case *opArithMod: + if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 { + return sqltypes.Float64, flags | flagNullable } - return sqltypes.Decimal, flags + if t1 == sqltypes.Decimal || t2 == sqltypes.Decimal { + return sqltypes.Decimal, flags | flagNullable + } + return t1, flags | flagNullable } switch t1 { @@ -122,6 +144,16 @@ func (op *opArithDiv) eval(left, right eval) (eval, error) { } func (op *opArithDiv) String() string { return "/" } +func (op *opArithIntDiv) eval(left, right eval) (eval, error) { + return integerDivideNumericWithError(left, right, true) +} +func (op *opArithIntDiv) String() string { return "DIV" } + +func (op *opArithMod) eval(left, right eval) (eval, error) { + return modNumericWithError(left, right, true) +} +func (op *opArithMod) String() string { return "DIV" } + func (n *NegateExpr) eval(env *ExpressionEnv) (eval, error) { e, err := n.Inner.eval(env) if err != nil { diff --git a/go/vt/vtgate/evalengine/expr_bit.go b/go/vt/vtgate/evalengine/expr_bit.go index 1e68438b45f..3e87fdd2fe4 100644 --- a/go/vt/vtgate/evalengine/expr_bit.go +++ b/go/vt/vtgate/evalengine/expr_bit.go @@ -173,8 +173,13 @@ func (o opBitAnd) BitwiseOp() string { return "&" } var errBitwiseOperandsLength = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Binary operands of bitwise operators must be of equal length") func (bit *BitwiseExpr) eval(env *ExpressionEnv) (eval, error) { - l, r, err := bit.arguments(env) - if l == nil || r == nil || err != nil { + l, err := bit.Left.eval(env) + if l == nil || err != nil { + return nil, err + } + + r, err := bit.Right.eval(env) + if r == nil || err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index 79ed6b97d90..63a7e170adb 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -25,7 +25,7 @@ import ( type ( LogicalExpr struct { BinaryExpr - op func(left, right boolean) boolean + op func(left, right Expr, env *ExpressionEnv) (boolean, error) opname string } @@ -100,55 +100,100 @@ func (left boolean) not() boolean { } } -func (left boolean) and(right boolean) boolean { +func opAnd(le, re Expr, env *ExpressionEnv) (boolean, error) { // Logical AND. // Evaluates to 1 if all operands are nonzero and not NULL, to 0 if one or more operands are 0, otherwise NULL is returned. + l, err := le.eval(env) + if err != nil { + return boolNULL, err + } + + left := evalIsTruthy(l) + if left == boolFalse { + return boolFalse, nil + } + + r, err := re.eval(env) + if err != nil { + return boolNULL, err + } + right := evalIsTruthy(r) + switch { case left == boolTrue && right == boolTrue: - return boolTrue - case left == boolFalse || right == boolFalse: - return boolFalse + return boolTrue, nil + case right == boolFalse: + return boolFalse, nil default: - return boolNULL + return boolNULL, nil } } -func (left boolean) or(right boolean) boolean { +func opOr(le, re Expr, env *ExpressionEnv) (boolean, error) { // Logical OR. When both operands are non-NULL, the result is 1 if any operand is nonzero, and 0 otherwise. // With a NULL operand, the result is 1 if the other operand is nonzero, and NULL otherwise. // If both operands are NULL, the result is NULL. + l, err := le.eval(env) + if err != nil { + return boolNULL, err + } + + left := evalIsTruthy(l) + if left == boolTrue { + return boolTrue, nil + } + + r, err := re.eval(env) + if err != nil { + return boolNULL, err + } + right := evalIsTruthy(r) + switch { case left == boolNULL: if right == boolTrue { - return boolTrue + return boolTrue, nil } - return boolNULL + return boolNULL, nil case right == boolNULL: - if left == boolTrue { - return boolTrue - } - return boolNULL + return boolNULL, nil default: - if left == boolTrue || right == boolTrue { - return boolTrue + if right == boolTrue { + return boolTrue, nil } - return boolFalse + return boolFalse, nil } } -func (left boolean) xor(right boolean) boolean { +func opXor(le, re Expr, env *ExpressionEnv) (boolean, error) { // Logical XOR. Returns NULL if either operand is NULL. // For non-NULL operands, evaluates to 1 if an odd number of operands is nonzero, otherwise 0 is returned. + l, err := le.eval(env) + if err != nil { + return boolNULL, err + } + + left := evalIsTruthy(l) + if left == boolNULL { + return boolNULL, nil + } + + r, err := re.eval(env) + if err != nil { + return boolNULL, err + } + right := evalIsTruthy(r) + switch { case left == boolNULL || right == boolNULL: - return boolNULL + return boolNULL, nil default: if left != right { - return boolTrue + return boolTrue, nil } - return boolFalse + return boolFalse, nil } } @@ -162,21 +207,18 @@ func (n *NotExpr) eval(env *ExpressionEnv) (eval, error) { func (n *NotExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { _, flags := n.Inner.typeof(env) - return sqltypes.Uint64, flags + return sqltypes.Int64, flags | flagIsBoolean } func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) { - left, right, err := l.arguments(env) - if err != nil { - return nil, err - } - return l.op(evalIsTruthy(left), evalIsTruthy(right)).eval(), nil + res, err := l.op(l.Left, l.Right, env) + return res.eval(), err } func (l *LogicalExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { _, f1 := l.Left.typeof(env) _, f2 := l.Right.typeof(env) - return sqltypes.Uint64, f1 | f2 + return sqltypes.Int64, f1 | f2 | flagIsBoolean } func (i *IsExpr) eval(env *ExpressionEnv) (eval, error) { diff --git a/go/vt/vtgate/evalengine/fn_base64.go b/go/vt/vtgate/evalengine/fn_base64.go index 7cdc2bafe98..5b6372710ec 100644 --- a/go/vt/vtgate/evalengine/fn_base64.go +++ b/go/vt/vtgate/evalengine/fn_base64.go @@ -76,12 +76,18 @@ func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) { b := evalToBinary(arg) decoded := make([]byte, mysqlBase64.DecodedLen(len(b.bytes))) if n, err := mysqlBase64.Decode(decoded, b.bytes); err == nil { + if arg.SQLType() == sqltypes.Text || arg.SQLType() == sqltypes.TypeJSON { + return newEvalRaw(sqltypes.Blob, decoded[:n], collationBinary), nil + } return newEvalBinary(decoded[:n]), nil } return nil, nil } func (call *builtinFromBase64) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { - _, f := call.Arguments[0].typeof(env) + tt, f := call.Arguments[0].typeof(env) + if tt == sqltypes.Text || tt == sqltypes.TypeJSON { + return sqltypes.Blob, f | flagNullable + } return sqltypes.VarBinary, f | flagNullable } diff --git a/go/vt/vtgate/evalengine/fn_json.go b/go/vt/vtgate/evalengine/fn_json.go index 75e88c82d75..ed48a325981 100644 --- a/go/vt/vtgate/evalengine/fn_json.go +++ b/go/vt/vtgate/evalengine/fn_json.go @@ -168,7 +168,7 @@ func (call *builtinJSONObject) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - val1, err := evalToJSON(val) + val1, err := argToJSON(val) if err != nil { return nil, err } @@ -189,7 +189,7 @@ func (call *builtinJSONArray) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - arg1, err := evalToJSON(arg) + arg1, err := argToJSON(arg) if err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal.go b/go/vt/vtgate/evalengine/internal/decimal/decimal.go index b93a1b9069b..14b708601b0 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal.go @@ -362,7 +362,7 @@ func (d Decimal) Div(d2 Decimal, scaleIncr int32) Decimal { scaleIncr = 0 } scale := myBigDigits(fracLeft+fracRight+scaleIncr) * 9 - q, _ := d.quoRem(d2, scale) + q, _ := d.QuoRem(d2, scale) return q } @@ -372,15 +372,15 @@ func (d Decimal) div(d2 Decimal) Decimal { return d.divRound(d2, int32(divisionPrecision)) } -// quoRem does division with remainder -// d.quoRem(d2,precision) returns quotient q and remainder r such that +// QuoRem does division with remainder +// d.QuoRem(d2,precision) returns quotient q and remainder r such that // // d = d2 * q + r, q an integer multiple of 10^(-precision) // 0 <= r < abs(d2) * 10 ^(-precision) if d>=0 // 0 >= r > -abs(d2) * 10 ^(-precision) if d<0 // // Note that precision<0 is allowed as input. -func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { +func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) { d.ensureInitialized() d2.ensureInitialized() if d2.value.Sign() == 0 { @@ -389,7 +389,7 @@ func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { scale := -precision e := int64(d.exp - d2.exp - scale) if e > math.MaxInt32 || e < math.MinInt32 { - panic("overflow in decimal quoRem") + panic("overflow in decimal QuoRem") } var aa, bb, expo big.Int var scalerest int32 @@ -428,7 +428,7 @@ func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { // Note that precision<0 is allowed as input. func (d Decimal) divRound(d2 Decimal, precision int32) Decimal { // quoRem already checks initialization - q, r := d.quoRem(d2, precision) + q, r := d.QuoRem(d2, precision) // the actual rounding decision is based on comparing r*10^precision and d2/2 // instead compare 2 r 10 ^precision and d2 diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go index 333d24e5d4d..0f0c16763a7 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go @@ -746,11 +746,11 @@ func TestDecimal_QuoRem(t *testing.T) { d, _ := NewFromString(inp4.d) d2, _ := NewFromString(inp4.d2) prec := inp4.exp - q, r := d.quoRem(d2, prec) + q, r := d.QuoRem(d2, prec) expectedQ, _ := NewFromString(inp4.q) expectedR, _ := NewFromString(inp4.r) if !q.Equal(expectedQ) || !r.Equal(expectedR) { - t.Errorf("bad quoRem division %s , %s , %d got %v, %v expected %s , %s", + t.Errorf("bad QuoRem division %s , %s , %d got %v, %v expected %s , %s", inp4.d, inp4.d2, prec, q, r, inp4.q, inp4.r) } if !d.Equal(d2.mul(q).Add(r)) { @@ -813,7 +813,7 @@ func TestDecimal_QuoRem2(t *testing.T) { } d2 := tc.d2 prec := tc.prec - q, r := d.quoRem(d2, prec) + q, r := d.QuoRem(d2, prec) // rule 1: d = d2*q +r if !d.Equal(d2.mul(q).Add(r)) { t.Errorf("not fitting, d=%v, d2=%v, prec=%d, q=%v, r=%v", diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index d93cb3b837b..848a9cb4029 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -54,6 +54,8 @@ var Cases = []TestCase{ {Run: LikeComparison}, {Run: MultiComparisons}, {Run: IsStatement}, + {Run: NotStatement}, + {Run: LogicalStatement}, {Run: TupleComparisons}, {Run: Comparisons}, {Run: InStatement}, @@ -436,6 +438,12 @@ func BitwiseOperators(yield Query) { yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) } } + + for _, lhs := range inputConversions { + for _, rhs := range inputConversions { + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + } + } } } @@ -537,7 +545,7 @@ func Types(yield Query) { } func Arithmetic(yield Query) { - operators := []string{"+", "-", "*", "/"} + operators := []string{"+", "-", "*", "/", "DIV", "%", "MOD"} for _, op := range operators { for _, lhs := range inputConversions { @@ -720,6 +728,26 @@ func IsStatement(yield Query) { } } +func NotStatement(yield Query) { + var ops = []string{"NOT", "!"} + for _, op := range ops { + for _, i := range inputConversions { + yield(fmt.Sprintf("%s %s", op, i), nil) + } + } +} + +func LogicalStatement(yield Query) { + var ops = []string{"AND", "&&", "OR", "||", "XOR"} + for _, op := range ops { + for _, l := range inputConversions { + for _, r := range inputConversions { + yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + } + } + } +} + func TupleComparisons(yield Query) { var elems = []string{"NULL", "-1", "0", "1"} var operators = []string{"=", "!=", "<=>", "<", "<=", ">", ">="} @@ -872,5 +900,10 @@ func InStatement(yield Query) { yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) }) } diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 94f818dc0f3..744ffdb05b7 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -97,6 +97,12 @@ var inputConversions = []string{ "18446744073709540000e0", "-18446744073709540000e0", "JSON_OBJECT()", "JSON_ARRAY()", + "cast(0 as json)", "cast(1 as json)", + "cast(true as json)", "cast(false as json)", + "cast('{}' as json)", "cast('[]' as json)", + "cast('null' as json)", "cast('true' as json)", "cast('false' as json)", + "cast('1' as json)", "cast('1.1' as json)", "cast('-1.1' as json)", + "cast('\"foo\"' as json)", "cast('invalid' as json)", } const inputPi = "314159265358979323846264338327950288419716939937510582097494459" diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 4ace1b54f4f..03225143280 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -106,14 +106,14 @@ func (ast *astCompiler) translateLogicalExpr(opname string, left, right sqlparse return nil, err } - var logic func(l, r boolean) boolean + var logic func(l, r Expr, env *ExpressionEnv) (boolean, error) switch opname { case "AND": - logic = func(l, r boolean) boolean { return l.and(r) } + logic = func(l, r Expr, env *ExpressionEnv) (boolean, error) { return opAnd(l, r, env) } case "OR": - logic = func(l, r boolean) boolean { return l.or(r) } + logic = func(l, r Expr, env *ExpressionEnv) (boolean, error) { return opOr(l, r, env) } case "XOR": - logic = func(l, r boolean) boolean { return l.xor(r) } + logic = func(l, r Expr, env *ExpressionEnv) (boolean, error) { return opXor(l, r, env) } default: panic("unexpected logical operator") } @@ -234,6 +234,10 @@ func (ast *astCompiler) translateBinaryExpr(binary *sqlparser.BinaryExpr) (Expr, return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithMul{}}, nil case sqlparser.DivOp: return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithDiv{}}, nil + case sqlparser.IntDivOp: + return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithIntDiv{}}, nil + case sqlparser.ModOp: + return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithMod{}}, nil case sqlparser.BitAndOp: return &BitwiseExpr{BinaryExpr: binaryExpr, Op: &opBitAnd{}}, nil case sqlparser.BitOrOp: diff --git a/go/vt/vtgate/evalengine/translate_card.go b/go/vt/vtgate/evalengine/translate_card.go index 650563daaea..f6d50664dce 100644 --- a/go/vt/vtgate/evalengine/translate_card.go +++ b/go/vt/vtgate/evalengine/translate_card.go @@ -134,6 +134,8 @@ func (ast *astCompiler) cardExpr(expr Expr) error { return ast.cardUnary(expr.Inner) case *BitwiseNotExpr: return ast.cardUnary(expr.Inner) + case *NotExpr: + return ast.cardUnary(expr.Inner) case *ArithmeticExpr: return ast.cardBinary(expr.Left, expr.Right) case *LogicalExpr: