diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index f5c18dd43a1..cde014116eb 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -19,7 +19,6 @@ package sqltypes import ( "bytes" "fmt" - "math" "strconv" @@ -28,9 +27,6 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -// TODO(sougou): change these functions to be more permissive. -// Most string to number conversions should quietly convert to 0. - // numeric represents a numeric value extracted from // a Value, used for arithmetic operations. type numeric struct { @@ -50,8 +46,14 @@ func Add(v1, v2 Value) (Value, error) { } lv1, err := newNumeric(v1) + if err != nil { + return NULL, err + } lv2, err := newNumeric(v2) + if err != nil { + return NULL, err + } lresult, err := addNumericWithError(lv1, lv2) if err != nil { @@ -61,6 +63,30 @@ func Add(v1, v2 Value) (Value, error) { return castFromNumeric(lresult, lresult.typ), nil } +// Subtract takes two values and subtracts them +func Subtract(v1, v2 Value) (Value, error) { + if v1.IsNull() || v2.IsNull() { + return NULL, nil + } + + lv1, err := newNumeric(v1) + if err != nil { + return NULL, err + } + + lv2, err := newNumeric(v2) + if err != nil { + return NULL, err + } + + lresult, err := subtractNumericWithError(lv1, lv2) + if err != nil { + return NULL, err + } + + return castFromNumeric(lresult, lresult.typ), nil +} + // NullsafeAdd adds two Values in a null-safe manner. A null value // is treated as 0. If both values are null, then a null is returned. // If both values are not null, a numeric value is built @@ -243,7 +269,10 @@ func ToInt64(v Value) (int64, error) { // ToFloat64 converts Value to float64. func ToFloat64(v Value) (float64, error) { - num, _ := newNumeric(v) + num, err := newNumeric(v) + if err != nil { + return 0, err + } switch num.typ { case Int64: return float64(num.ival), nil @@ -373,7 +402,32 @@ func addNumericWithError(v1, v2 numeric) (numeric, error) { return floatPlusAny(v1.fval, v2), nil } panic("unreachable") +} +func subtractNumericWithError(v1, v2 numeric) (numeric, error) { + switch v1.typ { + case Int64: + switch v2.typ { + case Int64: + return intMinusIntWithError(v1.ival, v2.ival) + case Uint64: + return intMinusUintWithError(v1.ival, v2.uval) + case Float64: + return anyMinusFloat(v1, v2.fval), nil + } + case Uint64: + switch v2.typ { + case Int64: + return uintMinusIntWithError(v1.uval, v2.ival) + case Uint64: + return uintMinusUintWithError(v1.uval, v2.uval) + case Float64: + return anyMinusFloat(v1, v2.fval), nil + } + case Float64: + return floatMinusAny(v1.fval, v2), nil + } + panic("unreachable") } // prioritize reorders the input parameters @@ -388,7 +442,6 @@ func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { if v2.typ == Float64 { return v2, v1 } - } return v1, v2 } @@ -415,36 +468,67 @@ func intPlusIntWithError(v1, v2 int64) (numeric, error) { return numeric{typ: Int64, ival: result}, nil } +func intMinusIntWithError(v1, v2 int64) (numeric, error) { + result := v1 - v2 + + if (result < v1) != (v2 > 0) { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) + } + return numeric{typ: Int64, ival: result}, nil +} + +func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) { + if v1 < 0 || v1 < int64(v2) { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + } + return uintMinusUintWithError(uint64(v1), v2) +} + func uintPlusInt(v1 uint64, v2 int64) numeric { return uintPlusUint(v1, uint64(v2)) } func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) { - if v2 >= math.MaxInt64 && v1 > 0 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) + if v2 < 0 && v1 < uint64(v2) { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } - - //convert to int -> uint is because for numeric operators (such as + or -) - //where one of the operands is an unsigned integer, the result is unsigned by default. + // convert to int -> uint is because for numeric operators (such as + or -) + // where one of the operands is an unsigned integer, the result is unsigned by default. return uintPlusUintWithError(v1, uint64(v2)) } +func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) { + if int64(v1) < v2 && v2 > 0 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + } + // uint - (- int) = uint + int + if v2 < 0 { + return uintPlusIntWithError(v1, -v2) + } + return uintMinusUintWithError(v1, uint64(v2)) +} + func uintPlusUint(v1, v2 uint64) numeric { result := v1 + v2 if result < v2 { return numeric{typ: Float64, fval: float64(v1) + float64(v2)} - } return numeric{typ: Uint64, uval: result} } func uintPlusUintWithError(v1, v2 uint64) (numeric, error) { result := v1 + v2 - if result < v2 { return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } + return numeric{typ: Uint64, uval: result}, nil +} +func uintMinusUintWithError(v1, v2 uint64) (numeric, error) { + result := v1 - v2 + if v2 > v1 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + } return numeric{typ: Uint64, uval: result}, nil } @@ -458,6 +542,26 @@ func floatPlusAny(v1 float64, v2 numeric) numeric { return numeric{typ: Float64, fval: v1 + v2.fval} } +func floatMinusAny(v1 float64, v2 numeric) numeric { + switch v2.typ { + case Int64: + v2.fval = float64(v2.ival) + case Uint64: + v2.fval = float64(v2.uval) + } + return numeric{typ: Float64, fval: v1 - v2.fval} +} + +func anyMinusFloat(v1 numeric, v2 float64) numeric { + switch v1.typ { + case Int64: + v1.fval = float64(v1.ival) + case Uint64: + v1.fval = float64(v1.uval) + } + return numeric{typ: Float64, fval: v1.fval - v2} +} + func castFromNumeric(v numeric, resultType querypb.Type) Value { switch { case IsSigned(resultType): diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go index 81aeba0cf30..227b20c7d4c 100644 --- a/go/sqltypes/arithmetic_test.go +++ b/go/sqltypes/arithmetic_test.go @@ -29,14 +29,13 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func TestAdd(t *testing.T) { +func TestSubtract(t *testing.T) { tcases := []struct { v1, v2 Value out Value err error }{{ - - //All Nulls + // All Nulls v1: NULL, v2: NULL, out: NULL, @@ -51,24 +50,159 @@ func TestAdd(t *testing.T) { v2: NewInt32(1), out: NULL, }, { + // case with negative value + v1: NewInt64(-1), + v2: NewInt64(-2), + out: NewInt64(1), + }, { + // testing for int64 overflow with min negative value + v1: NewInt64(math.MinInt64), + v2: NewInt64(1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 - 1"), + }, { + v1: NewUint64(4), + v2: NewInt64(5), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 4 - 5"), + }, { + // testing uint - int + v1: NewUint64(7), + v2: NewInt64(5), + out: NewUint64(2), + }, { + v1: NewUint64(math.MaxUint64), + v2: NewInt64(0), + out: NewUint64(math.MaxUint64), + }, { + // testing for int64 overflow + v1: NewInt64(math.MinInt64), + v2: NewUint64(0), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0"), + }, { + v1: TestValue(VarChar, "c"), + v2: NewInt64(1), + out: NewInt64(-1), + }, { + v1: NewUint64(1), + v2: TestValue(VarChar, "c"), + out: NewUint64(1), + }, { + // testing for error for parsing float value to uint64 + v1: TestValue(Uint64, "1.2"), + v2: NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // testing for error for parsing float value to uint64 + v1: NewUint64(2), + v2: TestValue(Uint64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // uint64 - uint64 + v1: NewUint64(8), + v2: NewUint64(4), + out: NewUint64(4), + }, { + // testing for float subtraction: float - int + v1: NewFloat64(1.2), + v2: NewInt64(2), + out: NewFloat64(-0.8), + }, { + // testing for float subtraction: float - uint + v1: NewFloat64(1.2), + v2: NewUint64(2), + out: NewFloat64(-0.8), + }, { + v1: NewInt64(-1), + v2: NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -1 - 2"), + }, { + v1: NewInt64(2), + v2: NewUint64(1), + out: NewUint64(1), + }, { + // testing int64 - float64 method + v1: NewInt64(-2), + v2: NewFloat64(1.0), + out: NewFloat64(-3.0), + }, { + // testing uint64 - float64 method + v1: NewUint64(1), + v2: NewFloat64(-2.0), + out: NewFloat64(3.0), + }, { + // testing uint - int to return uintplusint + v1: NewUint64(1), + v2: NewInt64(-2), + out: NewUint64(3), + }, { + // testing for float - float + v1: NewFloat64(1.2), + v2: NewFloat64(3.2), + out: NewFloat64(-2), + }, { + // testing uint - uint if v2 > v1 + v1: NewUint64(2), + v2: NewUint64(4), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 2 - 4"), + }, { + // testing uint - (- int) + v1: NewUint64(1), + v2: NewInt64(-2), + out: NewUint64(3), + }} + + for _, tcase := range tcases { + + got, err := Subtract(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Subtract(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Subtract(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestAdd(t *testing.T) { + tcases := []struct { + v1, v2 Value + out Value + err error + }{{ + // All Nulls + v1: NULL, + v2: NULL, + out: NULL, + }, { + // First value null. + v1: NewInt32(1), + v2: NULL, + out: NULL, + }, { + // Second value null. + v1: NULL, + v2: NewInt32(1), + out: NULL, + }, { // case with negatives v1: NewInt64(-1), v2: NewInt64(-2), out: NewInt64(-3), }, { - - // testing for overflow int64 + // testing for overflow int64, result will be unsigned int v1: NewInt64(math.MaxInt64), v2: NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 2 + 9223372036854775807"), + out: NewUint64(9223372036854775809), }, { - v1: NewInt64(-2), v2: NewUint64(1), - out: NewUint64(math.MaxUint64), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 1 + -2"), }, { - v1: NewInt64(math.MaxInt64), v2: NewInt64(-2), out: NewInt64(9223372036854775805), @@ -83,30 +217,25 @@ func TestAdd(t *testing.T) { v2: NewUint64(2), err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), }, { - // int64 underflow v1: NewInt64(math.MinInt64), v2: NewInt64(-2), err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), }, { - // checking int64 max value can be returned v1: NewInt64(math.MaxInt64), v2: NewUint64(0), out: NewUint64(9223372036854775807), }, { - // testing whether uint64 max value can be returned v1: NewUint64(math.MaxUint64), v2: NewInt64(0), out: NewUint64(math.MaxUint64), }, { - v1: NewUint64(math.MaxInt64), v2: NewInt64(1), out: NewUint64(9223372036854775808), }, { - v1: NewUint64(1), v2: TestValue(VarChar, "c"), out: NewUint64(1), @@ -114,6 +243,19 @@ func TestAdd(t *testing.T) { v1: NewUint64(1), v2: TestValue(VarChar, "1.2"), out: NewFloat64(2.2), + }, { + v1: TestValue(Int64, "1.2"), + v2: NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + v1: NewInt64(2), + v2: TestValue(Int64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint64 overflow with max uint64 + int value + v1: NewUint64(math.MaxUint64), + v2: NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), }} for _, tcase := range tcases { @@ -128,7 +270,7 @@ func TestAdd(t *testing.T) { } if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Addition(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) } } @@ -184,7 +326,7 @@ func TestNullsafeAdd(t *testing.T) { got := NullsafeAdd(tcase.v1, tcase.v2, Int64) if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + t.Errorf("NullsafeAdd(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) } } } @@ -456,6 +598,9 @@ func TestToFloat64(t *testing.T) { }, { v: NewFloat64(1.2), out: 1.2, + }, { + v: TestValue(Int64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), }} for _, tcase := range tcases { got, err := ToFloat64(tcase.v)