diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 24cf93eab57ef..8c87de117687f 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2748,6 +2748,18 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int return date, false, nil } +func (du *baseDateArithmitical) addDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { + dur, err := types.ExtractDurationValue(unit, interval) + if err != nil { + return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) + } + retDur, err := d.Add(dur) + if err != nil { + return types.ZeroDuration, true, err + } + return retDur, false, nil +} + func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { year, month, day, nano, err := types.ParseDurationValue(unit, interval) if err := handleInvalidTimeError(ctx, err); err != nil { @@ -2791,7 +2803,7 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres } dateEvalTp := args[0].GetType().EvalType() - if dateEvalTp != types.ETString && dateEvalTp != types.ETInt { + if dateEvalTp != types.ETString && dateEvalTp != types.ETInt && dateEvalTp != types.ETDuration { dateEvalTp = types.ETDatetime } @@ -2801,8 +2813,35 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres } argTps := []types.EvalType{dateEvalTp, intervalEvalTp, types.ETString} - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, argTps...) - bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeFullWidth, types.UnspecifiedLength + var bf baseBuiltinFunc + if dateEvalTp == types.ETDuration { + unit, _, err := args[2].EvalString(ctx, chunk.Row{}) + if err != nil { + return nil, err + } + internalFsp := 0 + switch unit { + // If the unit has micro second, then the fsp must be the MaxFsp. + case "MICROSECOND", "SECOND_MICROSECOND", "MINUTE_MICROSECOND", "HOUR_MICROSECOND", "DAY_MICROSECOND": + internalFsp = types.MaxFsp + // If the unit is second, the fsp is related with the arg[1]'s. + case "SECOND": + internalFsp = types.MaxFsp + if intervalEvalTp != types.ETString { + internalFsp = mathutil.Min(args[1].GetType().Decimal, types.MaxFsp) + } + // Otherwise, the fsp should be 0. + } + bf = newBaseBuiltinFuncWithTp(ctx, args, types.ETDuration, argTps...) + arg0Dec, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return nil, err + } + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDurationWidthWithFsp, mathutil.Max(arg0Dec, internalFsp) + } else { + bf = newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, argTps...) + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeFullWidth, types.UnspecifiedLength + } switch { case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: @@ -2865,6 +2904,21 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres baseBuiltinFunc: bf, baseDateArithmitical: newDateArighmeticalUtil(), } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: + sig = &builtinAddDateDurationStringSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: + sig = &builtinAddDateDurationIntSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: + sig = &builtinAddDateDurationDecimalSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } } return sig, nil } @@ -3265,6 +3319,79 @@ func (b *builtinAddDateDatetimeDecimalSig) evalTime(row chunk.Row) (types.Time, return result, isNull || err != nil, err } +type builtinAddDateDurationStringSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationStringSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + interval, isNull, err := b.getIntervalFromString(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinAddDateDurationIntSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationIntSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + interval, isNull, err := b.getIntervalFromInt(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinAddDateDurationDecimalSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + interval, isNull, err := b.getIntervalFromDecimal(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + type subDateFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index d6e40953ab037..2ac4b1e494293 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1764,6 +1764,33 @@ func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { c.Assert(err, IsNil) c.Assert(v.GetMysqlTime().String(), Equals, test.expected) } + testDurations := []struct { + dur string + fsp int + unit string + format interface{} + expected string + }{ + { + dur: "00:00:00", + fsp: 0, + unit: "MICROSECOND", + format: "100", + expected: "00:00:00.000100", + }, + } + for _, tt := range testDurations { + dur, _, ok, err := types.StrToDuration(nil, tt.dur, tt.fsp) + c.Assert(err, IsNil) + c.Assert(ok, IsTrue) + args = types.MakeDatums(dur, tt.format, tt.unit) + f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlDuration().String(), Equals, tt.expected) + } } func (s *testEvaluatorSuite) TestTimestamp(c *C) { diff --git a/types/const_test.go b/types/const_test.go index 2caa4c182b189..3f6573f733785 100644 --- a/types/const_test.go +++ b/types/const_test.go @@ -16,7 +16,6 @@ package types_test import ( "context" "flag" - "testing" . "github.com/pingcap/check" "github.com/pingcap/parser" @@ -30,11 +29,6 @@ import ( "github.com/pingcap/tidb/util/testleak" ) -func TestT(t *testing.T) { - CustomVerboseFlag = true - TestingT(t) -} - var _ = Suite(&testMySQLConstSuite{}) type testMySQLConstSuite struct { diff --git a/types/errors.go b/types/errors.go index f1f012cbea6e0..f58bbf34cf60e 100644 --- a/types/errors.go +++ b/types/errors.go @@ -14,6 +14,7 @@ package types import ( + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" parser_types "github.com/pingcap/parser/types" @@ -59,28 +60,43 @@ var ( ErrWarnDataOutOfRange = terror.ClassTypes.New(codeDataOutOfRange, mysql.MySQLErrName[mysql.ErrWarnDataOutOfRange]) // ErrDuplicatedValueInType is returned when enum column has duplicated value. ErrDuplicatedValueInType = terror.ClassTypes.New(codeDuplicatedValueInType, mysql.MySQLErrName[mysql.ErrDuplicatedValueInType]) + // ErrDatetimeFunctionOverflow is returned when the calculation in datetime function cause overflow. + ErrDatetimeFunctionOverflow = terror.ClassTypes.New(codeDatetimeFunctionOverflow, mysql.MySQLErrName[mysql.ErrDatetimeFunctionOverflow]) + // ErrInvalidTimeFormat is returned when the time format is not correct. + ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'") + // ErrInvalidWeekModeFormat is returned when the week mode is wrong. + ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'") + // ErrInvalidYearFormat is returned when the input is not a valid year format. + ErrInvalidYearFormat = errors.New("invalid year format") + // ErrInvalidYear is returned when the input value is not a valid year. + ErrInvalidYear = errors.New("invalid year") + // ErrIncorrectDatetimeValue is returned when the input is not valid date time value. + ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'") + // ErrTruncatedWrongValue is returned then + ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) const ( codeBadNumber terror.ErrCode = 1 - codeDataTooLong = terror.ErrCode(mysql.ErrDataTooLong) - codeIllegalValueForType = terror.ErrCode(mysql.ErrIllegalValueForType) - codeTruncated = terror.ErrCode(mysql.WarnDataTruncated) - codeOverflow = terror.ErrCode(mysql.ErrDataOutOfRange) - codeDivByZero = terror.ErrCode(mysql.ErrDivisionByZero) - codeTooBigDisplayWidth = terror.ErrCode(mysql.ErrTooBigDisplaywidth) - codeTooBigFieldLength = terror.ErrCode(mysql.ErrTooBigFieldlength) - codeTooBigSet = terror.ErrCode(mysql.ErrTooBigSet) - codeTooBigScale = terror.ErrCode(mysql.ErrTooBigScale) - codeTooBigPrecision = terror.ErrCode(mysql.ErrTooBigPrecision) - codeWrongFieldSpec = terror.ErrCode(mysql.ErrWrongFieldSpec) - codeTruncatedWrongValue = terror.ErrCode(mysql.ErrTruncatedWrongValue) - codeUnknown = terror.ErrCode(mysql.ErrUnknown) - codeInvalidDefault = terror.ErrCode(mysql.ErrInvalidDefault) - codeMBiggerThanD = terror.ErrCode(mysql.ErrMBiggerThanD) - codeDataOutOfRange = terror.ErrCode(mysql.ErrWarnDataOutOfRange) - codeDuplicatedValueInType = terror.ErrCode(mysql.ErrDuplicatedValueInType) + codeDataTooLong = terror.ErrCode(mysql.ErrDataTooLong) + codeIllegalValueForType = terror.ErrCode(mysql.ErrIllegalValueForType) + codeTruncated = terror.ErrCode(mysql.WarnDataTruncated) + codeOverflow = terror.ErrCode(mysql.ErrDataOutOfRange) + codeDivByZero = terror.ErrCode(mysql.ErrDivisionByZero) + codeTooBigDisplayWidth = terror.ErrCode(mysql.ErrTooBigDisplaywidth) + codeTooBigFieldLength = terror.ErrCode(mysql.ErrTooBigFieldlength) + codeTooBigSet = terror.ErrCode(mysql.ErrTooBigSet) + codeTooBigScale = terror.ErrCode(mysql.ErrTooBigScale) + codeTooBigPrecision = terror.ErrCode(mysql.ErrTooBigPrecision) + codeWrongFieldSpec = terror.ErrCode(mysql.ErrWrongFieldSpec) + codeTruncatedWrongValue = terror.ErrCode(mysql.ErrTruncatedWrongValue) + codeUnknown = terror.ErrCode(mysql.ErrUnknown) + codeInvalidDefault = terror.ErrCode(mysql.ErrInvalidDefault) + codeMBiggerThanD = terror.ErrCode(mysql.ErrMBiggerThanD) + codeDataOutOfRange = terror.ErrCode(mysql.ErrWarnDataOutOfRange) + codeDuplicatedValueInType = terror.ErrCode(mysql.ErrDuplicatedValueInType) + codeDatetimeFunctionOverflow = terror.ErrCode(mysql.ErrDatetimeFunctionOverflow) ) var ( @@ -92,23 +108,24 @@ var ( func init() { typesMySQLErrCodes := map[terror.ErrCode]uint16{ - codeDataTooLong: mysql.ErrDataTooLong, - codeIllegalValueForType: mysql.ErrIllegalValueForType, - codeTruncated: mysql.WarnDataTruncated, - codeOverflow: mysql.ErrDataOutOfRange, - codeDivByZero: mysql.ErrDivisionByZero, - codeTooBigDisplayWidth: mysql.ErrTooBigDisplaywidth, - codeTooBigFieldLength: mysql.ErrTooBigFieldlength, - codeTooBigSet: mysql.ErrTooBigSet, - codeTooBigScale: mysql.ErrTooBigScale, - codeTooBigPrecision: mysql.ErrTooBigPrecision, - codeWrongFieldSpec: mysql.ErrWrongFieldSpec, - codeTruncatedWrongValue: mysql.ErrTruncatedWrongValue, - codeUnknown: mysql.ErrUnknown, - codeInvalidDefault: mysql.ErrInvalidDefault, - codeMBiggerThanD: mysql.ErrMBiggerThanD, - codeDataOutOfRange: mysql.ErrWarnDataOutOfRange, - codeDuplicatedValueInType: mysql.ErrDuplicatedValueInType, + codeDataTooLong: mysql.ErrDataTooLong, + codeIllegalValueForType: mysql.ErrIllegalValueForType, + codeTruncated: mysql.WarnDataTruncated, + codeOverflow: mysql.ErrDataOutOfRange, + codeDivByZero: mysql.ErrDivisionByZero, + codeTooBigDisplayWidth: mysql.ErrTooBigDisplaywidth, + codeTooBigFieldLength: mysql.ErrTooBigFieldlength, + codeTooBigSet: mysql.ErrTooBigSet, + codeTooBigScale: mysql.ErrTooBigScale, + codeTooBigPrecision: mysql.ErrTooBigPrecision, + codeWrongFieldSpec: mysql.ErrWrongFieldSpec, + codeTruncatedWrongValue: mysql.ErrTruncatedWrongValue, + codeUnknown: mysql.ErrUnknown, + codeInvalidDefault: mysql.ErrInvalidDefault, + codeMBiggerThanD: mysql.ErrMBiggerThanD, + codeDataOutOfRange: mysql.ErrWarnDataOutOfRange, + codeDuplicatedValueInType: mysql.ErrDuplicatedValueInType, + codeDatetimeFunctionOverflow: mysql.ErrDatetimeFunctionOverflow, } terror.ErrClassToMySQLCodes[terror.ClassTypes] = typesMySQLErrCodes } diff --git a/types/fsp_test.go b/types/fsp_test.go index e4b73ee14e008..8802e87d5b3e4 100644 --- a/types/fsp_test.go +++ b/types/fsp_test.go @@ -15,15 +15,10 @@ package types import ( "strconv" - "testing" . "github.com/pingcap/check" ) -func Test(t *testing.T) { - TestingT(t) -} - var _ = Suite(&FspTest{}) type FspTest struct{} diff --git a/types/time.go b/types/time.go index b284b20ed382c..1915b01873629 100644 --- a/types/time.go +++ b/types/time.go @@ -29,18 +29,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/logutil" -) - -// Portable analogs of some common call errors. -var ( - ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'") - ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'") - ErrInvalidYearFormat = errors.New("invalid year format") - ErrInvalidYear = errors.New("invalid year") - ErrZeroDate = errors.New("datetime zero in date") - ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'") - ErrDatetimeFunctionOverflow = terror.ClassTypes.New(mysql.ErrDatetimeFunctionOverflow, mysql.MySQLErrName[mysql.ErrDatetimeFunctionOverflow]) - ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) + tidbMath "github.com/pingcap/tidb/util/math" ) // Time format without fractional seconds precision. @@ -196,6 +185,13 @@ var ( } ) +const ( + // GoDurationDay is the gotime.Duration which equals to a Day. + GoDurationDay = gotime.Hour * 24 + // GoDurationWeek is the gotime.Duration which equals to a Week. + GoDurationWeek = GoDurationDay * 7 +) + // FromGoTime translates time.Time to mysql time internal representation. func FromGoTime(t gotime.Time) MysqlTime { year, month, day := t.Date() @@ -1198,9 +1194,9 @@ func ParseDuration(sc *stmtctx.StatementContext, str string, fsp int) (Duration, // TruncateOverflowMySQLTime truncates d when it overflows, and return ErrTruncatedWrongVal. func TruncateOverflowMySQLTime(d gotime.Duration) (gotime.Duration, error) { if d > MaxTime { - return MaxTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d.String()) + return MaxTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d) } else if d < MinTime { - return MinTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d.String()) + return MinTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d) } return d, nil @@ -1615,7 +1611,10 @@ func ExtractDurationNum(d *Duration, unit string) (int64, error) { } } -func parseSingleTimeValue(unit string, format string) (int64, int64, int64, int64, error) { +// parseSingleTimeValue parse the format according the given unit. If we set strictCheck true, we'll check whether +// the converted value not exceed the range of MySQL's TIME type. +// The first four returned values are year, month, day and nanosecond. +func parseSingleTimeValue(unit string, format string, strictCheck bool) (int64, int64, int64, int64, error) { // Format is a preformatted number, it format should be A[.[B]]. decimalPointPos := strings.IndexRune(format, '.') if decimalPointPos == -1 { @@ -1652,33 +1651,59 @@ func parseSingleTimeValue(unit string, format string) (int64, int64, int64, int6 err = ErrTruncatedWrongValue.GenWithStackByArgs(format) } } - const gotimeDay = 24 * gotime.Hour switch strings.ToUpper(unit) { case "MICROSECOND": - dayCount := riv / int64(gotimeDay/gotime.Microsecond) - riv %= int64(gotimeDay / gotime.Microsecond) + if strictCheck && tidbMath.Abs(riv) > TimeMaxValueSeconds*1000 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dayCount := riv / int64(GoDurationDay/gotime.Microsecond) + riv %= int64(GoDurationDay / gotime.Microsecond) return 0, 0, dayCount, riv * int64(gotime.Microsecond), err case "SECOND": - dayCount := iv / int64(gotimeDay/gotime.Second) - iv %= int64(gotimeDay / gotime.Second) + if strictCheck && tidbMath.Abs(iv) > TimeMaxValueSeconds { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dayCount := iv / int64(GoDurationDay/gotime.Second) + iv %= int64(GoDurationDay / gotime.Second) return 0, 0, dayCount, iv*int64(gotime.Second) + dv*int64(gotime.Microsecond), err case "MINUTE": - dayCount := riv / int64(gotimeDay/gotime.Minute) - riv %= int64(gotimeDay / gotime.Minute) + if strictCheck && tidbMath.Abs(riv) > TimeMaxHour*60+TimeMaxMinute { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dayCount := riv / int64(GoDurationDay/gotime.Minute) + riv %= int64(GoDurationDay / gotime.Minute) return 0, 0, dayCount, riv * int64(gotime.Minute), err case "HOUR": + if strictCheck && tidbMath.Abs(riv) > TimeMaxHour { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } dayCount := riv / 24 riv %= 24 return 0, 0, dayCount, riv * int64(gotime.Hour), err case "DAY": + if strictCheck && tidbMath.Abs(riv) > TimeMaxHour/24 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, 0, riv, 0, err case "WEEK": + if strictCheck && 7*tidbMath.Abs(riv) > TimeMaxHour/24 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, 0, 7 * riv, 0, err case "MONTH": + if strictCheck && tidbMath.Abs(riv) > 1 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, riv, 0, 0, err case "QUARTER": + if strictCheck { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, 3 * riv, 0, 0, err case "YEAR": + if strictCheck { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return riv, 0, 0, 0, err } @@ -1749,13 +1774,28 @@ func parseTimeValue(format string, index, cnt int) (int64, int64, int64, int64, return years, months, days, seconds*int64(gotime.Second) + microseconds*int64(gotime.Microsecond), nil } +func parseAndValidateDurationValue(format string, index, cnt int) (int64, error) { + year, month, day, nano, err := parseTimeValue(format, index, cnt) + if err != nil { + return 0, err + } + if year != 0 || month != 0 || tidbMath.Abs(day) > TimeMaxHour/24 { + return 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dur := day*int64(GoDurationDay) + nano + if tidbMath.Abs(dur) > int64(MaxTime) { + return 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + return dur, nil +} + // ParseDurationValue parses time value from time unit and format. // Returns y years m months d days + n nanoseconds // Nanoseconds will no longer than one day. func ParseDurationValue(unit string, format string) (y int64, m int64, d int64, n int64, _ error) { switch strings.ToUpper(unit) { case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": - return parseSingleTimeValue(unit, format) + return parseSingleTimeValue(unit, format, false) case "SECOND_MICROSECOND": return parseTimeValue(format, MicrosecondIndex, SecondMicrosecondMaxCnt) case "MINUTE_MICROSECOND": @@ -1779,7 +1819,93 @@ func ParseDurationValue(unit string, format string) (y int64, m int64, d int64, case "YEAR_MONTH": return parseTimeValue(format, MonthIndex, YearMonthMaxCnt) default: - return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) + return 0, 0, 0, 0, errors.Errorf("invalid single timeunit - %s", unit) + } +} + +// ExtractDurationValue extract the value from format to Duration. +func ExtractDurationValue(unit string, format string) (Duration, error) { + unit = strings.ToUpper(unit) + switch unit { + case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": + _, month, day, nano, err := parseSingleTimeValue(unit, format, true) + if err != nil { + return ZeroDuration, err + } + dur := Duration{Duration: gotime.Duration((month*30+day)*int64(GoDurationDay) + nano)} + if unit == "MICROSECOND" { + dur.Fsp = MaxFsp + } + return dur, err + case "SECOND_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, SecondMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "MINUTE_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, MinuteMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "MINUTE_SECOND": + d, err := parseAndValidateDurationValue(format, SecondIndex, MinuteSecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "HOUR_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, HourMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "HOUR_SECOND": + d, err := parseAndValidateDurationValue(format, SecondIndex, HourSecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "HOUR_MINUTE": + d, err := parseAndValidateDurationValue(format, MinuteIndex, HourMinuteMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: 0}, nil + case "DAY_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, DayMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "DAY_SECOND": + d, err := parseAndValidateDurationValue(format, SecondIndex, DaySecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "DAY_MINUTE": + d, err := parseAndValidateDurationValue(format, MinuteIndex, DayMinuteMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: 0}, nil + case "DAY_HOUR": + d, err := parseAndValidateDurationValue(format, HourIndex, DayHourMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: 0}, nil + case "YEAR_MONTH": + _, err := parseAndValidateDurationValue(format, MonthIndex, YearMonthMaxCnt) + if err != nil { + return ZeroDuration, err + } + // MONTH must exceed the limit of mysql's duration. So just return overflow error. + return ZeroDuration, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + default: + return ZeroDuration, errors.Errorf("invalid single timeunit - %s", unit) } } diff --git a/types/time_test.go b/types/time_test.go index b7b5a17aa897b..57c6cca351d6f 100644 --- a/types/time_test.go +++ b/types/time_test.go @@ -18,6 +18,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -1106,6 +1107,141 @@ func (s *testTimeSuite) TestCheckTimestamp(c *C) { } } +func (s *testTimeSuite) TestExtractDurationValue(c *C) { + tests := []struct { + unit string + format string + ans string + failed bool + }{ + { + unit: "MICROSECOND", + format: "50", + ans: "00:00:00.000050", + }, + { + unit: "SECOND", + format: "50", + ans: "00:00:50", + }, + { + unit: "MINUTE", + format: "10", + ans: "00:10:00", + }, + { + unit: "HOUR", + format: "10", + ans: "10:00:00", + }, + { + unit: "DAY", + format: "1", + ans: "24:00:00", + }, + { + unit: "WEEK", + format: "2", + ans: "336:00:00", + }, + { + unit: "SECOND_MICROSECOND", + format: "61.01", + ans: "00:01:01.010000", + }, + { + unit: "MINUTE_MICROSECOND", + format: "01:61.01", + ans: "00:02:01.010000", + }, + { + unit: "MINUTE_SECOND", + format: "61:61", + ans: "01:02:01.000000", + }, + { + unit: "HOUR_MICROSECOND", + format: "01:61:01.01", + ans: "02:01:01.010000", + }, + { + unit: "HOUR_SECOND", + format: "01:61:01", + ans: "02:01:01.000000", + }, + { + unit: "HOUr_MINUTE", + format: "2:2", + ans: "02:02:00", + }, + { + unit: "DAY_MICRoSECOND", + format: "1 1:1:1.02", + ans: "25:01:01.020000", + }, + { + unit: "DAY_SeCOND", + format: "1 02:03:04", + ans: "26:03:04.000000", + }, + { + unit: "DAY_MINUTE", + format: "1 1:2", + ans: "25:02:00", + }, + { + unit: "DAY_HOUr", + format: "1 1", + ans: "25:00:00", + }, + { + unit: "DAY", + format: "-35", + failed: true, + }, + { + unit: "day", + format: "34", + ans: "816:00:00", + }, + { + unit: "SECOND", + format: "-3020400", + failed: true, + }, + { + unit: "MONTH", + format: "1", + ans: "720:00:00", + }, + { + unit: "MONTH", + format: "-2", + failed: true, + }, + { + unit: "DAY_second", + format: "34 23:59:59", + failed: true, + }, + { + unit: "DAY_hOUR", + format: "-34 23", + failed: true, + }, + } + failedComment := "failed at case %d, unit: %s, format: %s" + for i, tt := range tests { + dur, err := types.ExtractDurationValue(tt.unit, tt.format) + if tt.failed { + c.Assert(err, NotNil, Commentf(failedComment+", dur: %v", i, tt.unit, tt.format, dur.String())) + } else { + c.Assert(err, IsNil, Commentf(failedComment+", error stack", i, tt.unit, tt.format, errors.ErrorStack(err))) + c.Assert(dur.String(), Equals, tt.ans, Commentf(failedComment, i, tt.unit, tt.format)) + } + } +} + func (s *testTimeSuite) TestCurrentTime(c *C) { res := types.CurrentTime(mysql.TypeTimestamp) c.Assert(res.Time, NotNil) diff --git a/util/math/math.go b/util/math/math.go index 6349b799756c5..3a25178326261 100644 --- a/util/math/math.go +++ b/util/math/math.go @@ -15,8 +15,8 @@ package math import "math" -// http://cavaliercoder.com/blog/optimized-abs-for-int64-in-go.html -func abs(n int64) int64 { +// Abs implement the abs function according to http://cavaliercoder.com/blog/optimized-abs-for-int64-in-go.html +func Abs(n int64) int64 { y := n >> 63 return (n ^ y) - y } @@ -46,5 +46,5 @@ func StrLenOfInt64Fast(x int64) int { if x < 0 { size = 1 // add "-" sign on the length count } - return size + StrLenOfUint64Fast(uint64(abs(x))) + return size + StrLenOfUint64Fast(uint64(Abs(x))) }