diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 5b6c0ee538c..89ecb6fd22f 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -26,12 +26,17 @@ import ( "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" ) func dataOutOfRangeError[N1, N2 constraints.Integer | constraints.Float](v1 N1, v2 N2, typ, sign string) error { return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2) } +func dataOutOfRangeErrorDecimal(v1 decimal.Decimal, v2 decimal.Decimal, typ, sign string) error { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1.String(), sign, v2.String()) +} + func addNumericWithError(left, right eval) (eval, error) { v1, v2 := makeNumericAndPrioritize(left, right) switch v1 := v1.(type) { @@ -127,6 +132,108 @@ func divideNumericWithError(left, right eval, precise bool) (eval, error) { return mathDiv_xx(v1, v2, divPrecisionIncrement) } +func integerDivideNumericWithError(left, right eval, precise bool) (eval, error) { + v1 := evalToNumeric(left) + v2 := evalToNumeric(right) + switch v1 := v1.(type) { + case *evalInt64: + switch v2 := v2.(type) { + case *evalInt64: + return mathIntDiv_ii(v1, v2) + case *evalUint64: + return mathIntDiv_iu(v1, v2) + case *evalFloat: + return mathIntDiv_di(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + case *evalDecimal: + return mathIntDiv_di(v1.toDecimal(0, 0), v2) + } + case *evalUint64: + switch v2 := v2.(type) { + case *evalInt64: + return mathIntDiv_ui(v1, v2) + case *evalUint64: + return mathIntDiv_uu(v1, v2) + case *evalFloat: + return mathIntDiv_du(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + case *evalDecimal: + return mathIntDiv_du(v1.toDecimal(0, 0), v2) + } + case *evalFloat: + switch v2 := v2.(type) { + case *evalUint64: + return mathIntDiv_du(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + default: + return mathIntDiv_di(v1.toDecimal(0, 0), v2.toDecimal(0, 0)) + } + case *evalDecimal: + switch v2 := v2.(type) { + case *evalUint64: + return mathIntDiv_du(v1, v2.toDecimal(0, 0)) + default: + return mathIntDiv_di(v1, v2.toDecimal(0, 0)) + } + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", evalToSQLValue(v1), evalToSQLValue(v2)) +} + +func modNumericWithError(left, right eval, precise bool) (eval, error) { + v1 := evalToNumeric(left) + v2 := evalToNumeric(right) + + switch v1 := v1.(type) { + case *evalInt64: + switch v2 := v2.(type) { + case *evalInt64: + return mathMod_ii(v1, v2) + case *evalUint64: + return mathMod_iu(v1, v2) + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + case *evalDecimal: + return mathMod_dd(v1.toDecimal(0, 0), v2) + } + case *evalUint64: + switch v2 := v2.(type) { + case *evalInt64: + return mathMod_ui(v1, v2) + case *evalUint64: + return mathMod_uu(v1, v2) + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + case *evalDecimal: + return mathMod_dd(v1.toDecimal(0, 0), v2) + } + case *evalDecimal: + switch v2 := v2.(type) { + case *evalFloat: + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1f, v2) + default: + return mathMod_dd(v1, v2.toDecimal(0, 0)) + } + case *evalFloat: + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathMod_ff(v1, v2f) + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", evalToSQLValue(v1), evalToSQLValue(v2)) +} + // makeNumericAndPrioritize reorders the input parameters // to be Float64, Decimal, Uint64, Int64. func makeNumericAndPrioritize(left, right eval) (evalNumeric, evalNumeric) { @@ -162,28 +269,68 @@ func mathAdd_ii0(v1, v2 int64) (int64, error) { return result, nil } -func mathSub_ii(v1, v2 int64) (*evalInt64, error) { - result, err := mathSub_ii0(v1, v2) - return newEvalInt64(result), err +func mathAdd_ui(v1 uint64, v2 int64) (*evalUint64, error) { + result, err := mathAdd_ui0(v1, v2) + return newEvalUint64(result), err } -func mathSub_ii0(v1, v2 int64) (int64, error) { - result := v1 - v2 - if (result < v1) != (v2 > 0) { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "-") +func mathAdd_ui0(v1 uint64, v2 int64) (uint64, error) { + result := v1 + uint64(v2) + if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") } return result, nil } -func mathMul_ii(v1, v2 int64) (*evalInt64, error) { - result, err := mathMul_ii0(v1, v2) +func mathAdd_uu(v1, v2 uint64) (*evalUint64, error) { + result, err := mathAdd_uu0(v1, v2) + return newEvalUint64(result), err +} + +func mathAdd_uu0(v1, v2 uint64) (uint64, error) { + result := v1 + v2 + if result < v1 || result < v2 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") + } + return result, nil +} + +var errDecimalOutOfRange = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") + +func mathAdd_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathAdd_ff(v1, v2f.f), nil +} + +func mathAdd_ff(v1, v2 float64) *evalFloat { + return newEvalFloat(v1 + v2) +} + +func mathAdd_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { + return mathAdd_dd(v1, v2.toDecimal(0, 0)) +} + +func mathAdd_dd(v1, v2 *evalDecimal) *evalDecimal { + return newEvalDecimalWithPrec(v1.dec.Add(v2.dec), maxprec(v1.length, v2.length)) +} + +func mathAdd_dd0(v1, v2 *evalDecimal) { + v1.dec = v1.dec.Add(v2.dec) + v1.length = maxprec(v1.length, v2.length) +} + +func mathSub_ii(v1, v2 int64) (*evalInt64, error) { + result, err := mathSub_ii0(v1, v2) return newEvalInt64(result), err } -func mathMul_ii0(v1, v2 int64) (int64, error) { - result := v1 * v2 - if v1 != 0 && result/v1 != v2 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT", "*") +func mathSub_ii0(v1, v2 int64) (int64, error) { + result := v1 - v2 + if (result < v1) != (v2 > 0) { + return 0, dataOutOfRangeError(v1, v2, "BIGINT", "-") } return result, nil } @@ -200,19 +347,6 @@ func mathSub_iu0(v1 int64, v2 uint64) (uint64, error) { return mathSub_uu0(uint64(v1), v2) } -func mathAdd_ui(v1 uint64, v2 int64) (*evalUint64, error) { - result, err := mathAdd_ui0(v1, v2) - return newEvalUint64(result), err -} - -func mathAdd_ui0(v1 uint64, v2 int64) (uint64, error) { - result := v1 + uint64(v2) - if v2 < 0 && v1 < uint64(-v2) || v2 > 0 && (result < v1 || result < uint64(v2)) { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") - } - return result, nil -} - func mathSub_ui(v1 uint64, v2 int64) (*evalUint64, error) { result, err := mathSub_ui0(v1, v2) return newEvalUint64(result), err @@ -229,45 +363,82 @@ func mathSub_ui0(v1 uint64, v2 int64) (uint64, error) { return mathSub_uu0(v1, uint64(v2)) } -func mathMul_ui(v1 uint64, v2 int64) (*evalUint64, error) { - result, err := mathMul_ui0(v1, v2) +func mathSub_uu(v1, v2 uint64) (*evalUint64, error) { + result, err := mathSub_uu0(v1, v2) return newEvalUint64(result), err } -func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { - if v1 == 0 || v2 == 0 { - return 0, nil +func mathSub_uu0(v1, v2 uint64) (uint64, error) { + result := v1 - v2 + if v2 > v1 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") } - if v2 < 0 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") + return result, nil +} + +func mathSub_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { + v2f, ok := v2.toFloat() + if !ok { + return nil, errDecimalOutOfRange } - return mathMul_uu0(v1, uint64(v2)) + return mathSub_ff(v1, v2f.f), nil } -func mathAdd_uu(v1, v2 uint64) (*evalUint64, error) { - result, err := mathAdd_uu0(v1, v2) - return newEvalUint64(result), err +func mathSub_xf(v1 evalNumeric, v2 float64) (*evalFloat, error) { + v1f, ok := v1.toFloat() + if !ok { + return nil, errDecimalOutOfRange + } + return mathSub_ff(v1f.f, v2), nil } -func mathAdd_uu0(v1, v2 uint64) (uint64, error) { - result := v1 + v2 - if result < v1 || result < v2 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "+") +func mathSub_ff(v1, v2 float64) *evalFloat { + return newEvalFloat(v1 - v2) +} + +func mathSub_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { + return mathSub_dd(v1, v2.toDecimal(0, 0)) +} + +func mathSub_xd(v1 evalNumeric, v2 *evalDecimal) *evalDecimal { + return mathSub_dd(v1.toDecimal(0, 0), v2) +} + +func mathSub_dd(v1, v2 *evalDecimal) *evalDecimal { + return newEvalDecimalWithPrec(v1.dec.Sub(v2.dec), maxprec(v1.length, v2.length)) +} + +func mathSub_dd0(v1, v2 *evalDecimal) { + v1.dec = v1.dec.Sub(v2.dec) + v1.length = maxprec(v1.length, v2.length) +} + +func mathMul_ii(v1, v2 int64) (*evalInt64, error) { + result, err := mathMul_ii0(v1, v2) + return newEvalInt64(result), err +} + +func mathMul_ii0(v1, v2 int64) (int64, error) { + result := v1 * v2 + if v1 != 0 && result/v1 != v2 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT", "*") } return result, nil } -func mathSub_uu(v1, v2 uint64) (*evalUint64, error) { - result, err := mathSub_uu0(v1, v2) +func mathMul_ui(v1 uint64, v2 int64) (*evalUint64, error) { + result, err := mathMul_ui0(v1, v2) return newEvalUint64(result), err } -func mathSub_uu0(v1, v2 uint64) (uint64, error) { - result := v1 - v2 - if v2 > v1 { - return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "-") +func mathMul_ui0(v1 uint64, v2 int64) (uint64, error) { + if v1 == 0 || v2 == 0 { + return 0, nil } - return result, nil + if v2 < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "*") + } + return mathMul_uu0(v1, uint64(v2)) } func mathMul_uu(v1, v2 uint64) (*evalUint64, error) { @@ -286,28 +457,6 @@ func mathMul_uu0(v1, v2 uint64) (uint64, error) { return result, nil } -var errDecimalOutOfRange = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") - -func mathAdd_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { - v2f, ok := v2.toFloat() - if !ok { - return nil, errDecimalOutOfRange - } - return mathAdd_ff(v1, v2f.f), nil -} - -func mathAdd_ff(v1, v2 float64) *evalFloat { - return newEvalFloat(v1 + v2) -} - -func mathSub_fx(v1 float64, v2 evalNumeric) (*evalFloat, error) { - v2f, ok := v2.toFloat() - if !ok { - return nil, errDecimalOutOfRange - } - return mathSub_ff(v1, v2f.f), nil -} - func mathMul_fx(v1 float64, v2 evalNumeric) (eval, error) { v2f, ok := v2.toFloat() if !ok { @@ -327,36 +476,6 @@ func maxprec(a, b int32) int32 { return b } -func mathAdd_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { - return mathAdd_dd(v1, v2.toDecimal(0, 0)) -} - -func mathAdd_dd(v1, v2 *evalDecimal) *evalDecimal { - return newEvalDecimalWithPrec(v1.dec.Add(v2.dec), maxprec(v1.length, v2.length)) -} - -func mathAdd_dd0(v1, v2 *evalDecimal) { - v1.dec = v1.dec.Add(v2.dec) - v1.length = maxprec(v1.length, v2.length) -} - -func mathSub_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { - return mathSub_dd(v1, v2.toDecimal(0, 0)) -} - -func mathSub_xd(v1 evalNumeric, v2 *evalDecimal) *evalDecimal { - return mathSub_dd(v1.toDecimal(0, 0), v2) -} - -func mathSub_dd(v1, v2 *evalDecimal) *evalDecimal { - return newEvalDecimalWithPrec(v1.dec.Sub(v2.dec), maxprec(v1.length, v2.length)) -} - -func mathSub_dd0(v1, v2 *evalDecimal) { - v1.dec = v1.dec.Sub(v2.dec) - v1.length = maxprec(v1.length, v2.length) -} - func mathMul_dx(v1 *evalDecimal, v2 evalNumeric) *evalDecimal { return mathMul_dd(v1, v2.toDecimal(0, 0)) } @@ -413,16 +532,174 @@ func mathDiv_ff0(v1, v2 float64) (float64, error) { return result, nil } -func mathSub_xf(v1 evalNumeric, v2 float64) (*evalFloat, error) { - v1f, ok := v1.toFloat() +func mathIntDiv_ii(v1, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result := v1.i / v2.i + return newEvalInt64(result), nil +} + +func mathIntDiv_iu(v1 *evalInt64, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + result, err := mathIntDiv_iu0(v1.i, v2.u) + return newEvalUint64(result), err +} + +func mathIntDiv_iu0(v1 int64, v2 uint64) (uint64, error) { + if v1 < 0 { + if v2 >= math.MaxInt64 { + // We know here that v2 is always so large the result + // must be 0. + return 0, nil + } + result := v1 / int64(v2) + if result < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + return uint64(result), nil + + } + return uint64(v1) / v2, nil +} + +func mathIntDiv_ui(v1 *evalUint64, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result, err := mathIntDiv_ui0(v1.u, v2.i) + return newEvalUint64(result), err +} + +func mathIntDiv_ui0(v1 uint64, v2 int64) (uint64, error) { + if v2 < 0 { + if v1 >= math.MaxInt64 { + // We know that v1 is always large here and with v2, the result + // must be at least -1 so we can't store this in the available range. + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + // Safe to cast since we know it fits in int64 when we get here. + result := int64(v1) / v2 + if result < 0 { + return 0, dataOutOfRangeError(v1, v2, "BIGINT UNSIGNED", "DIV") + } + return uint64(result), nil + } + return v1 / uint64(v2), nil +} + +func mathIntDiv_uu(v1, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalUint64(v1.u / v2.u), nil +} + +func mathIntDiv_di(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + result, err := mathIntDiv_di0(v1, v2) + return newEvalInt64(result), err +} + +func mathIntDiv_di0(v1, v2 *evalDecimal) (int64, error) { + div, _ := v1.dec.QuoRem(v2.dec, 0) + result, ok := div.Int64() if !ok { - return nil, errDecimalOutOfRange + return 0, dataOutOfRangeErrorDecimal(v1.dec, v2.dec, "BIGINT", "DIV") } - return mathSub_ff(v1f.f, v2), nil + return result, nil } -func mathSub_ff(v1, v2 float64) *evalFloat { - return newEvalFloat(v1 - v2) +func mathIntDiv_du(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + result, err := mathIntDiv_du0(v1, v2) + return newEvalUint64(result), err +} + +func mathIntDiv_du0(v1, v2 *evalDecimal) (uint64, error) { + div, _ := v1.dec.QuoRem(v2.dec, 0) + result, ok := div.Uint64() + if !ok { + return 0, dataOutOfRangeErrorDecimal(v1.dec, v2.dec, "BIGINT UNSIGNED", "DIV") + } + return result, nil +} + +func mathMod_ii(v1, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + return newEvalInt64(v1.i % v2.i), nil +} + +func mathMod_iu(v1 *evalInt64, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalInt64(mathMod_iu0(v1.i, v2.u)), nil +} + +func mathMod_iu0(v1 int64, v2 uint64) int64 { + if v1 == math.MinInt64 && v2 == math.MaxInt64+1 { + return 0 + } + if v2 > math.MaxInt64 { + return v1 + } + return v1 % int64(v2) +} + +func mathMod_ui(v1 *evalUint64, v2 *evalInt64) (eval, error) { + if v2.i == 0 { + return nil, nil + } + result, err := mathMod_ui0(v1.u, v2.i) + return newEvalUint64(result), err +} + +func mathMod_ui0(v1 uint64, v2 int64) (uint64, error) { + if v2 < 0 { + return v1 % uint64(-v2), nil + } + return v1 % uint64(v2), nil +} + +func mathMod_uu(v1, v2 *evalUint64) (eval, error) { + if v2.u == 0 { + return nil, nil + } + return newEvalUint64(v1.u % v2.u), nil +} + +func mathMod_ff(v1, v2 *evalFloat) (eval, error) { + if v2.f == 0.0 { + return nil, nil + } + return newEvalFloat(math.Mod(v1.f, v2.f)), nil +} + +func mathMod_dd(v1, v2 *evalDecimal) (eval, error) { + if v2.dec.IsZero() { + return nil, nil + } + + dec, prec := mathMod_dd0(v1, v2) + return newEvalDecimalWithPrec(dec, prec), nil +} + +func mathMod_dd0(v1, v2 *evalDecimal) (decimal.Decimal, int32) { + length := v1.length + if v2.length > length { + length = v2.length + } + _, rem := v1.dec.QuoRem(v2.dec, 0) + return rem, length } func parseStringToFloat(str string) float64 { diff --git a/go/vt/vtgate/evalengine/compiler_arithmetic.go b/go/vt/vtgate/evalengine/compiler_arithmetic.go index a81afacc53a..772ca623c50 100644 --- a/go/vt/vtgate/evalengine/compiler_arithmetic.go +++ b/go/vt/vtgate/evalengine/compiler_arithmetic.go @@ -16,7 +16,11 @@ limitations under the License. package evalengine -import "vitess.io/vitess/go/sqltypes" +import ( + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) { arg, err := c.compileExpr(expr.Inner) @@ -64,8 +68,12 @@ func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) { return c.compileArithmeticMul(expr.Left, expr.Right) case *opArithDiv: return c.compileArithmeticDiv(expr.Left, expr.Right) + case *opArithIntDiv: + return c.compileArithmeticIntDiv(expr.Left, expr.Right) + case *opArithMod: + return c.compileArithmeticMod(expr.Left, expr.Right) default: - panic("unexpected arithmetic operator") + return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") } } @@ -299,3 +307,152 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) { return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil } } + +func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(lt, rt) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + switch lt.Type { + case sqltypes.Int64: + switch rt.Type { + case sqltypes.Int64: + c.asm.IntDiv_ii() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.IntDiv_iu() + case sqltypes.Float64: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_di() + } + case sqltypes.Uint64: + switch rt.Type { + case sqltypes.Int64: + c.asm.IntDiv_ui() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.IntDiv_uu() + case sqltypes.Float64: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_du() + } + case sqltypes.Float64: + switch rt.Type { + case sqltypes.Decimal: + c.asm.Convert_xd(2, 0, 0) + c.asm.IntDiv_di() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + default: + c.asm.Convert_xd(2, 0, 0) + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + } + case sqltypes.Decimal: + switch rt.Type { + case sqltypes.Decimal: + c.asm.IntDiv_di() + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_du() + default: + c.asm.Convert_xd(1, 0, 0) + c.asm.IntDiv_di() + } + } + c.asm.jumpDestination(skip) + return ct, nil +} + +func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) { + lt, err := c.compileExpr(left) + if err != nil { + return ctype{}, err + } + + rt, err := c.compileExpr(right) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(lt, rt) + lt = c.compileToNumeric(lt, 2) + rt = c.compileToNumeric(rt, 1) + + ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + switch lt.Type { + case sqltypes.Int64: + ct.Type = sqltypes.Int64 + switch rt.Type { + case sqltypes.Int64: + c.asm.Mod_ii() + case sqltypes.Uint64: + c.asm.Mod_iu() + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + c.asm.Convert_xd(2, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Uint64: + ct.Type = sqltypes.Uint64 + switch rt.Type { + case sqltypes.Int64: + c.asm.Mod_ui() + case sqltypes.Uint64: + c.asm.Mod_uu() + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + c.asm.Convert_xd(2, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Decimal: + ct.Type = sqltypes.Decimal + switch rt.Type { + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(2) + c.asm.Mod_ff() + default: + c.asm.Convert_xd(1, 0, 0) + c.asm.Mod_dd() + } + case sqltypes.Float64: + ct.Type = sqltypes.Float64 + c.asm.Convert_xf(1) + c.asm.Mod_ff() + } + + c.asm.jumpDestination(skip) + return ct, nil +} diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 4e484e9996d..771d5e08bfc 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -984,6 +984,203 @@ func (asm *assembler) Div_ff() { }, "DIV FLOAT64(SP-2), FLOAT64(SP-1)") } +func (asm *assembler) IntDiv_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = l.i / r.i + } + vm.sp-- + return 1 + }, "INTDIV INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) IntDiv_iu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + r.u, vm.err = mathIntDiv_iu0(l.i, r.u) + vm.stack[vm.sp-2] = r + } + vm.sp-- + return 1 + }, "INTDIV INT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) IntDiv_ui() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u, vm.err = mathIntDiv_ui0(l.u, r.i) + } + vm.sp-- + return 1 + }, "INTDIV UINT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) IntDiv_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u = l.u / r.u + } + vm.sp-- + return 1 + }, "INTDIV UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) IntDiv_di() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + var res int64 + res, vm.err = mathIntDiv_di0(l, r) + vm.stack[vm.sp-2] = vm.arena.newEvalInt64(res) + } + vm.sp-- + return 1 + }, "INTDIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (asm *assembler) IntDiv_du() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + var res uint64 + res, vm.err = mathIntDiv_du0(l, r) + vm.stack[vm.sp-2] = vm.arena.newEvalUint64(res) + } + vm.sp-- + return 1 + }, "UINTDIV DECIMAL(SP-2), DECIMAL(SP-1)") +} + +func (asm *assembler) Mod_ii() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = l.i % r.i + } + vm.sp-- + return 1 + }, "MOD INT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mod_iu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalInt64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.i = mathMod_iu0(l.i, r.u) + } + vm.sp-- + return 1 + }, "MOD INT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) Mod_ui() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalInt64) + if r.i == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u, vm.err = mathMod_ui0(l.u, r.i) + } + vm.sp-- + return 1 + }, "MOD UINT64(SP-2), INT64(SP-1)") +} + +func (asm *assembler) Mod_uu() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalUint64) + r := vm.stack[vm.sp-1].(*evalUint64) + if r.u == 0 { + vm.stack[vm.sp-2] = nil + } else { + l.u = l.u % r.u + } + vm.sp-- + return 1 + }, "MOD UINT64(SP-2), UINT64(SP-1)") +} + +func (asm *assembler) Mod_ff() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalFloat) + r := vm.stack[vm.sp-1].(*evalFloat) + if r.f == 0.0 { + vm.stack[vm.sp-2] = nil + } else { + l.f = math.Mod(l.f, r.f) + } + vm.sp-- + return 1 + }, "MOD FLOAT64(SP-2), FLOAT64(SP-1)") +} + +func (asm *assembler) Mod_dd() { + asm.adjustStack(-1) + + asm.emit(func(vm *VirtualMachine) int { + l := vm.stack[vm.sp-2].(*evalDecimal) + r := vm.stack[vm.sp-1].(*evalDecimal) + if r.dec.IsZero() { + vm.stack[vm.sp-2] = nil + } else { + l.dec, l.length = mathMod_dd0(l, r) + } + vm.sp-- + return 1 + }, "MOD DECIMAL(SP-2), DECIMAL(SP-1)") +} + func (asm *assembler) Fn_ASCII() { asm.emit(func(vm *VirtualMachine) int { arg := vm.stack[vm.sp-1].(*evalBytes) diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index a7c8e12ed4d..63323d0f4bf 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -35,10 +35,12 @@ type ( String() string } - opArithAdd struct{} - opArithSub struct{} - opArithMul struct{} - opArithDiv struct{} + opArithAdd struct{} + opArithSub struct{} + opArithMul struct{} + opArithDiv struct{} + opArithIntDiv struct{} + opArithMod struct{} ) var _ Expr = (*ArithmeticExpr)(nil) @@ -47,6 +49,8 @@ var _ opArith = (*opArithAdd)(nil) var _ opArith = (*opArithSub)(nil) var _ opArith = (*opArithMul)(nil) var _ opArith = (*opArithDiv)(nil) +var _ opArith = (*opArithIntDiv)(nil) +var _ opArith = (*opArithMod)(nil) func (b *ArithmeticExpr) eval(env *ExpressionEnv) (eval, error) { left, right, err := b.arguments(env) @@ -78,9 +82,22 @@ func (b *ArithmeticExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) { switch b.Op.(type) { case *opArithDiv: if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 { - return sqltypes.Float64, flags + return sqltypes.Float64, flags | flagNullable } - return sqltypes.Decimal, flags + return sqltypes.Decimal, flags | flagNullable + case *opArithIntDiv: + if t1 == sqltypes.Uint64 || t2 == sqltypes.Uint64 { + return sqltypes.Uint64, flags | flagNullable + } + return sqltypes.Int64, flags | flagNullable + case *opArithMod: + if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 { + return sqltypes.Float64, flags | flagNullable + } + if t1 == sqltypes.Decimal || t2 == sqltypes.Decimal { + return sqltypes.Decimal, flags | flagNullable + } + return t1, flags | flagNullable } switch t1 { @@ -122,6 +139,16 @@ func (op *opArithDiv) eval(left, right eval) (eval, error) { } func (op *opArithDiv) String() string { return "/" } +func (op *opArithIntDiv) eval(left, right eval) (eval, error) { + return integerDivideNumericWithError(left, right, true) +} +func (op *opArithIntDiv) String() string { return "DIV" } + +func (op *opArithMod) eval(left, right eval) (eval, error) { + return modNumericWithError(left, right, true) +} +func (op *opArithMod) String() string { return "DIV" } + func (n *NegateExpr) eval(env *ExpressionEnv) (eval, error) { e, err := n.Inner.eval(env) if err != nil { diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal.go b/go/vt/vtgate/evalengine/internal/decimal/decimal.go index b93a1b9069b..14b708601b0 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal.go @@ -362,7 +362,7 @@ func (d Decimal) Div(d2 Decimal, scaleIncr int32) Decimal { scaleIncr = 0 } scale := myBigDigits(fracLeft+fracRight+scaleIncr) * 9 - q, _ := d.quoRem(d2, scale) + q, _ := d.QuoRem(d2, scale) return q } @@ -372,15 +372,15 @@ func (d Decimal) div(d2 Decimal) Decimal { return d.divRound(d2, int32(divisionPrecision)) } -// quoRem does division with remainder -// d.quoRem(d2,precision) returns quotient q and remainder r such that +// QuoRem does division with remainder +// d.QuoRem(d2,precision) returns quotient q and remainder r such that // // d = d2 * q + r, q an integer multiple of 10^(-precision) // 0 <= r < abs(d2) * 10 ^(-precision) if d>=0 // 0 >= r > -abs(d2) * 10 ^(-precision) if d<0 // // Note that precision<0 is allowed as input. -func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { +func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) { d.ensureInitialized() d2.ensureInitialized() if d2.value.Sign() == 0 { @@ -389,7 +389,7 @@ func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { scale := -precision e := int64(d.exp - d2.exp - scale) if e > math.MaxInt32 || e < math.MinInt32 { - panic("overflow in decimal quoRem") + panic("overflow in decimal QuoRem") } var aa, bb, expo big.Int var scalerest int32 @@ -428,7 +428,7 @@ func (d Decimal) quoRem(d2 Decimal, precision int32) (Decimal, Decimal) { // Note that precision<0 is allowed as input. func (d Decimal) divRound(d2 Decimal, precision int32) Decimal { // quoRem already checks initialization - q, r := d.quoRem(d2, precision) + q, r := d.QuoRem(d2, precision) // the actual rounding decision is based on comparing r*10^precision and d2/2 // instead compare 2 r 10 ^precision and d2 diff --git a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go index 333d24e5d4d..0f0c16763a7 100644 --- a/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go +++ b/go/vt/vtgate/evalengine/internal/decimal/decimal_test.go @@ -746,11 +746,11 @@ func TestDecimal_QuoRem(t *testing.T) { d, _ := NewFromString(inp4.d) d2, _ := NewFromString(inp4.d2) prec := inp4.exp - q, r := d.quoRem(d2, prec) + q, r := d.QuoRem(d2, prec) expectedQ, _ := NewFromString(inp4.q) expectedR, _ := NewFromString(inp4.r) if !q.Equal(expectedQ) || !r.Equal(expectedR) { - t.Errorf("bad quoRem division %s , %s , %d got %v, %v expected %s , %s", + t.Errorf("bad QuoRem division %s , %s , %d got %v, %v expected %s , %s", inp4.d, inp4.d2, prec, q, r, inp4.q, inp4.r) } if !d.Equal(d2.mul(q).Add(r)) { @@ -813,7 +813,7 @@ func TestDecimal_QuoRem2(t *testing.T) { } d2 := tc.d2 prec := tc.prec - q, r := d.quoRem(d2, prec) + q, r := d.QuoRem(d2, prec) // rule 1: d = d2*q +r if !d.Equal(d2.mul(q).Add(r)) { t.Errorf("not fitting, d=%v, d2=%v, prec=%d, q=%v, r=%v", diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index d93cb3b837b..360217d092a 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -537,7 +537,7 @@ func Types(yield Query) { } func Arithmetic(yield Query) { - operators := []string{"+", "-", "*", "/"} + operators := []string{"+", "-", "*", "/", "DIV", "%", "MOD"} for _, op := range operators { for _, lhs := range inputConversions { diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 4ace1b54f4f..b3751cdfbcf 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -234,6 +234,10 @@ func (ast *astCompiler) translateBinaryExpr(binary *sqlparser.BinaryExpr) (Expr, return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithMul{}}, nil case sqlparser.DivOp: return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithDiv{}}, nil + case sqlparser.IntDivOp: + return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithIntDiv{}}, nil + case sqlparser.ModOp: + return &ArithmeticExpr{BinaryExpr: binaryExpr, Op: &opArithMod{}}, nil case sqlparser.BitAndOp: return &BitwiseExpr{BinaryExpr: binaryExpr, Op: &opBitAnd{}}, nil case sqlparser.BitOrOp: