From 7b8e1744175d3e55333bb1436320e4b6835c365a Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 20 Jun 2019 13:47:46 +0800 Subject: [PATCH] *: add_date can return mysql.Time (#9830) (#10718) --- expression/builtin_time.go | 133 +++++++++++++++++++++++- expression/builtin_time_test.go | 27 +++++ mysql/const_test.go | 6 -- types/errors.go | 85 ++++++++++------ types/time.go | 174 +++++++++++++++++++++++++++----- types/time_test.go | 136 +++++++++++++++++++++++++ util/math/math.go | 20 ++++ util/math/math_test.go | 36 +++++++ 8 files changed, 552 insertions(+), 65 deletions(-) create mode 100644 util/math/math.go create mode 100644 util/math/math_test.go diff --git a/expression/builtin_time.go b/expression/builtin_time.go index ffe238a07b17a..2841dfd5ca839 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2727,6 +2727,18 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int return date, false, nil } +func (du *baseDateArithmitical) addDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { + dur, err := types.ExtractDurationValue(unit, interval) + if err != nil { + return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) + } + retDur, err := d.Add(dur) + if err != nil { + return types.ZeroDuration, true, err + } + return retDur, false, nil +} + func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { year, month, day, nano, err := types.ParseDurationValue(unit, interval) if err := handleInvalidTimeError(ctx, err); err != nil { @@ -2770,7 +2782,7 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres } dateEvalTp := args[0].GetType().EvalType() - if dateEvalTp != types.ETString && dateEvalTp != types.ETInt { + if dateEvalTp != types.ETString && dateEvalTp != types.ETInt && dateEvalTp != types.ETDuration { dateEvalTp = types.ETDatetime } @@ -2780,8 +2792,35 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres } argTps := []types.EvalType{dateEvalTp, intervalEvalTp, types.ETString} - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, argTps...) - bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeFullWidth, types.UnspecifiedLength + var bf baseBuiltinFunc + if dateEvalTp == types.ETDuration { + unit, _, err := args[2].EvalString(ctx, chunk.Row{}) + if err != nil { + return nil, err + } + internalFsp := 0 + switch unit { + // If the unit has micro second, then the fsp must be the MaxFsp. + case "MICROSECOND", "SECOND_MICROSECOND", "MINUTE_MICROSECOND", "HOUR_MICROSECOND", "DAY_MICROSECOND": + internalFsp = types.MaxFsp + // If the unit is second, the fsp is related with the arg[1]'s. + case "SECOND": + internalFsp = types.MaxFsp + if intervalEvalTp != types.ETString { + internalFsp = mathutil.Min(args[1].GetType().Decimal, types.MaxFsp) + } + // Otherwise, the fsp should be 0. + } + bf = newBaseBuiltinFuncWithTp(ctx, args, types.ETDuration, argTps...) + arg0Dec, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return nil, err + } + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDurationWidthWithFsp, mathutil.Max(arg0Dec, internalFsp) + } else { + bf = newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, argTps...) + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeFullWidth, types.UnspecifiedLength + } switch { case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: @@ -2844,6 +2883,21 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres baseBuiltinFunc: bf, baseDateArithmitical: newDateArighmeticalUtil(), } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: + sig = &builtinAddDateDurationStringSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: + sig = &builtinAddDateDurationIntSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: + sig = &builtinAddDateDurationDecimalSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } } return sig, nil } @@ -3244,6 +3298,79 @@ func (b *builtinAddDateDatetimeDecimalSig) evalTime(row chunk.Row) (types.Time, return result, isNull || err != nil, errors.Trace(err) } +type builtinAddDateDurationStringSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationStringSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + interval, isNull, err := b.getIntervalFromString(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinAddDateDurationIntSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationIntSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + interval, isNull, err := b.getIntervalFromInt(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinAddDateDurationDecimalSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + interval, isNull, err := b.getIntervalFromDecimal(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + type subDateFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 14c532763dcee..a96c1c2412df6 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1742,6 +1742,33 @@ func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { c.Assert(err, IsNil) c.Assert(v.GetMysqlTime().String(), Equals, test.expected) } + testDurations := []struct { + dur string + fsp int + unit string + format interface{} + expected string + }{ + { + dur: "00:00:00", + fsp: 0, + unit: "MICROSECOND", + format: "100", + expected: "00:00:00.000100", + }, + } + for _, tt := range testDurations { + dur, _, ok, err := types.StrToDuration(nil, tt.dur, tt.fsp) + c.Assert(err, IsNil) + c.Assert(ok, IsTrue) + args = types.MakeDatums(dur, tt.format, tt.unit) + f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlDuration().String(), Equals, tt.expected) + } } func (s *testEvaluatorSuite) TestTimestamp(c *C) { diff --git a/mysql/const_test.go b/mysql/const_test.go index 11c61157b7cd9..dd2076502867f 100644 --- a/mysql/const_test.go +++ b/mysql/const_test.go @@ -15,7 +15,6 @@ package mysql_test import ( "flag" - "testing" . "github.com/pingcap/check" "github.com/pingcap/parser" @@ -30,11 +29,6 @@ import ( "golang.org/x/net/context" ) -func TestT(t *testing.T) { - CustomVerboseFlag = true - TestingT(t) -} - var _ = Suite(&testMySQLConstSuite{}) type testMySQLConstSuite struct { diff --git a/types/errors.go b/types/errors.go index 9e1919b45d97e..f58bbf34cf60e 100644 --- a/types/errors.go +++ b/types/errors.go @@ -14,6 +14,7 @@ package types import ( + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" parser_types "github.com/pingcap/parser/types" @@ -57,27 +58,45 @@ var ( // ErrWarnDataOutOfRange is returned when the value in a numeric column that is outside the permissible range of the column data type. // See https://dev.mysql.com/doc/refman/5.5/en/out-of-range-and-overflow.html for details ErrWarnDataOutOfRange = terror.ClassTypes.New(codeDataOutOfRange, mysql.MySQLErrName[mysql.ErrWarnDataOutOfRange]) + // ErrDuplicatedValueInType is returned when enum column has duplicated value. + ErrDuplicatedValueInType = terror.ClassTypes.New(codeDuplicatedValueInType, mysql.MySQLErrName[mysql.ErrDuplicatedValueInType]) + // ErrDatetimeFunctionOverflow is returned when the calculation in datetime function cause overflow. + ErrDatetimeFunctionOverflow = terror.ClassTypes.New(codeDatetimeFunctionOverflow, mysql.MySQLErrName[mysql.ErrDatetimeFunctionOverflow]) + // ErrInvalidTimeFormat is returned when the time format is not correct. + ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'") + // ErrInvalidWeekModeFormat is returned when the week mode is wrong. + ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'") + // ErrInvalidYearFormat is returned when the input is not a valid year format. + ErrInvalidYearFormat = errors.New("invalid year format") + // ErrInvalidYear is returned when the input value is not a valid year. + ErrInvalidYear = errors.New("invalid year") + // ErrIncorrectDatetimeValue is returned when the input is not valid date time value. + ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'") + // ErrTruncatedWrongValue is returned then + ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) const ( codeBadNumber terror.ErrCode = 1 - codeDataTooLong = terror.ErrCode(mysql.ErrDataTooLong) - codeIllegalValueForType = terror.ErrCode(mysql.ErrIllegalValueForType) - codeTruncated = terror.ErrCode(mysql.WarnDataTruncated) - codeOverflow = terror.ErrCode(mysql.ErrDataOutOfRange) - codeDivByZero = terror.ErrCode(mysql.ErrDivisionByZero) - codeTooBigDisplayWidth = terror.ErrCode(mysql.ErrTooBigDisplaywidth) - codeTooBigFieldLength = terror.ErrCode(mysql.ErrTooBigFieldlength) - codeTooBigSet = terror.ErrCode(mysql.ErrTooBigSet) - codeTooBigScale = terror.ErrCode(mysql.ErrTooBigScale) - codeTooBigPrecision = terror.ErrCode(mysql.ErrTooBigPrecision) - codeWrongFieldSpec = terror.ErrCode(mysql.ErrWrongFieldSpec) - codeTruncatedWrongValue = terror.ErrCode(mysql.ErrTruncatedWrongValue) - codeUnknown = terror.ErrCode(mysql.ErrUnknown) - codeInvalidDefault = terror.ErrCode(mysql.ErrInvalidDefault) - codeMBiggerThanD = terror.ErrCode(mysql.ErrMBiggerThanD) - codeDataOutOfRange = terror.ErrCode(mysql.ErrWarnDataOutOfRange) + codeDataTooLong = terror.ErrCode(mysql.ErrDataTooLong) + codeIllegalValueForType = terror.ErrCode(mysql.ErrIllegalValueForType) + codeTruncated = terror.ErrCode(mysql.WarnDataTruncated) + codeOverflow = terror.ErrCode(mysql.ErrDataOutOfRange) + codeDivByZero = terror.ErrCode(mysql.ErrDivisionByZero) + codeTooBigDisplayWidth = terror.ErrCode(mysql.ErrTooBigDisplaywidth) + codeTooBigFieldLength = terror.ErrCode(mysql.ErrTooBigFieldlength) + codeTooBigSet = terror.ErrCode(mysql.ErrTooBigSet) + codeTooBigScale = terror.ErrCode(mysql.ErrTooBigScale) + codeTooBigPrecision = terror.ErrCode(mysql.ErrTooBigPrecision) + codeWrongFieldSpec = terror.ErrCode(mysql.ErrWrongFieldSpec) + codeTruncatedWrongValue = terror.ErrCode(mysql.ErrTruncatedWrongValue) + codeUnknown = terror.ErrCode(mysql.ErrUnknown) + codeInvalidDefault = terror.ErrCode(mysql.ErrInvalidDefault) + codeMBiggerThanD = terror.ErrCode(mysql.ErrMBiggerThanD) + codeDataOutOfRange = terror.ErrCode(mysql.ErrWarnDataOutOfRange) + codeDuplicatedValueInType = terror.ErrCode(mysql.ErrDuplicatedValueInType) + codeDatetimeFunctionOverflow = terror.ErrCode(mysql.ErrDatetimeFunctionOverflow) ) var ( @@ -89,22 +108,24 @@ var ( func init() { typesMySQLErrCodes := map[terror.ErrCode]uint16{ - codeDataTooLong: mysql.ErrDataTooLong, - codeIllegalValueForType: mysql.ErrIllegalValueForType, - codeTruncated: mysql.WarnDataTruncated, - codeOverflow: mysql.ErrDataOutOfRange, - codeDivByZero: mysql.ErrDivisionByZero, - codeTooBigDisplayWidth: mysql.ErrTooBigDisplaywidth, - codeTooBigFieldLength: mysql.ErrTooBigFieldlength, - codeTooBigSet: mysql.ErrTooBigSet, - codeTooBigScale: mysql.ErrTooBigScale, - codeTooBigPrecision: mysql.ErrTooBigPrecision, - codeWrongFieldSpec: mysql.ErrWrongFieldSpec, - codeTruncatedWrongValue: mysql.ErrTruncatedWrongValue, - codeUnknown: mysql.ErrUnknown, - codeInvalidDefault: mysql.ErrInvalidDefault, - codeMBiggerThanD: mysql.ErrMBiggerThanD, - codeDataOutOfRange: mysql.ErrWarnDataOutOfRange, + codeDataTooLong: mysql.ErrDataTooLong, + codeIllegalValueForType: mysql.ErrIllegalValueForType, + codeTruncated: mysql.WarnDataTruncated, + codeOverflow: mysql.ErrDataOutOfRange, + codeDivByZero: mysql.ErrDivisionByZero, + codeTooBigDisplayWidth: mysql.ErrTooBigDisplaywidth, + codeTooBigFieldLength: mysql.ErrTooBigFieldlength, + codeTooBigSet: mysql.ErrTooBigSet, + codeTooBigScale: mysql.ErrTooBigScale, + codeTooBigPrecision: mysql.ErrTooBigPrecision, + codeWrongFieldSpec: mysql.ErrWrongFieldSpec, + codeTruncatedWrongValue: mysql.ErrTruncatedWrongValue, + codeUnknown: mysql.ErrUnknown, + codeInvalidDefault: mysql.ErrInvalidDefault, + codeMBiggerThanD: mysql.ErrMBiggerThanD, + codeDataOutOfRange: mysql.ErrWarnDataOutOfRange, + codeDuplicatedValueInType: mysql.ErrDuplicatedValueInType, + codeDatetimeFunctionOverflow: mysql.ErrDatetimeFunctionOverflow, } terror.ErrClassToMySQLCodes[terror.ClassTypes] = typesMySQLErrCodes } diff --git a/types/time.go b/types/time.go index a38dfa2359c73..0a1d1d847250e 100644 --- a/types/time.go +++ b/types/time.go @@ -29,18 +29,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/logutil" -) - -// Portable analogs of some common call errors. -var ( - ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'") - ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'") - ErrInvalidYearFormat = errors.New("invalid year format") - ErrInvalidYear = errors.New("invalid year") - ErrZeroDate = errors.New("datetime zero in date") - ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'") - ErrDatetimeFunctionOverflow = terror.ClassTypes.New(mysql.ErrDatetimeFunctionOverflow, mysql.MySQLErrName[mysql.ErrDatetimeFunctionOverflow]) - ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) + tidbMath "github.com/pingcap/tidb/util/math" ) // Time format without fractional seconds precision. @@ -196,6 +185,13 @@ var ( } ) +const ( + // GoDurationDay is the gotime.Duration which equals to a Day. + GoDurationDay = gotime.Hour * 24 + // GoDurationWeek is the gotime.Duration which equals to a Week. + GoDurationWeek = GoDurationDay * 7 +) + // FromGoTime translates time.Time to mysql time internal representation. func FromGoTime(t gotime.Time) MysqlTime { year, month, day := t.Date() @@ -1191,9 +1187,9 @@ func ParseDuration(sc *stmtctx.StatementContext, str string, fsp int) (Duration, // TruncateOverflowMySQLTime truncates d when it overflows, and return ErrTruncatedWrongVal. func TruncateOverflowMySQLTime(d gotime.Duration) (gotime.Duration, error) { if d > MaxTime { - return MaxTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d.String()) + return MaxTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d) } else if d < MinTime { - return MinTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d.String()) + return MinTime, ErrTruncatedWrongVal.GenWithStackByArgs("time", d) } return d, nil @@ -1608,7 +1604,10 @@ func ExtractDurationNum(d *Duration, unit string) (int64, error) { } } -func parseSingleTimeValue(unit string, format string) (int64, int64, int64, int64, error) { +// parseSingleTimeValue parse the format according the given unit. If we set strictCheck true, we'll check whether +// the converted value not exceed the range of MySQL's TIME type. +// The first four returned values are year, month, day and nanosecond. +func parseSingleTimeValue(unit string, format string, strictCheck bool) (int64, int64, int64, int64, error) { // Format is a preformatted number, it format should be A[.[B]]. decimalPointPos := strings.IndexRune(format, '.') if decimalPointPos == -1 { @@ -1645,33 +1644,59 @@ func parseSingleTimeValue(unit string, format string) (int64, int64, int64, int6 err = ErrTruncatedWrongValue.GenWithStackByArgs(format) } } - const gotimeDay = 24 * gotime.Hour switch strings.ToUpper(unit) { case "MICROSECOND": - dayCount := riv / int64(gotimeDay/gotime.Microsecond) - riv %= int64(gotimeDay / gotime.Microsecond) + if strictCheck && tidbMath.Abs(riv) > TimeMaxValueSeconds*1000 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dayCount := riv / int64(GoDurationDay/gotime.Microsecond) + riv %= int64(GoDurationDay / gotime.Microsecond) return 0, 0, dayCount, riv * int64(gotime.Microsecond), err case "SECOND": - dayCount := iv / int64(gotimeDay/gotime.Second) - iv %= int64(gotimeDay / gotime.Second) + if strictCheck && tidbMath.Abs(iv) > TimeMaxValueSeconds { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dayCount := iv / int64(GoDurationDay/gotime.Second) + iv %= int64(GoDurationDay / gotime.Second) return 0, 0, dayCount, iv*int64(gotime.Second) + dv*int64(gotime.Microsecond), err case "MINUTE": - dayCount := riv / int64(gotimeDay/gotime.Minute) - riv %= int64(gotimeDay / gotime.Minute) + if strictCheck && tidbMath.Abs(riv) > TimeMaxHour*60+TimeMaxMinute { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dayCount := riv / int64(GoDurationDay/gotime.Minute) + riv %= int64(GoDurationDay / gotime.Minute) return 0, 0, dayCount, riv * int64(gotime.Minute), err case "HOUR": + if strictCheck && tidbMath.Abs(riv) > TimeMaxHour { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } dayCount := riv / 24 riv %= 24 return 0, 0, dayCount, riv * int64(gotime.Hour), err case "DAY": + if strictCheck && tidbMath.Abs(riv) > TimeMaxHour/24 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, 0, riv, 0, err case "WEEK": + if strictCheck && 7*tidbMath.Abs(riv) > TimeMaxHour/24 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, 0, 7 * riv, 0, err case "MONTH": + if strictCheck && tidbMath.Abs(riv) > 1 { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, riv, 0, 0, err case "QUARTER": + if strictCheck { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return 0, 3 * riv, 0, 0, err case "YEAR": + if strictCheck { + return 0, 0, 0, 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } return riv, 0, 0, 0, err } @@ -1742,13 +1767,28 @@ func parseTimeValue(format string, index, cnt int) (int64, int64, int64, int64, return years, months, days, seconds*int64(gotime.Second) + microseconds*int64(gotime.Microsecond), nil } +func parseAndValidateDurationValue(format string, index, cnt int) (int64, error) { + year, month, day, nano, err := parseTimeValue(format, index, cnt) + if err != nil { + return 0, err + } + if year != 0 || month != 0 || tidbMath.Abs(day) > TimeMaxHour/24 { + return 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + dur := day*int64(GoDurationDay) + nano + if tidbMath.Abs(dur) > int64(MaxTime) { + return 0, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + } + return dur, nil +} + // ParseDurationValue parses time value from time unit and format. // Returns y years m months d days + n nanoseconds // Nanoseconds will no longer than one day. func ParseDurationValue(unit string, format string) (y int64, m int64, d int64, n int64, _ error) { switch strings.ToUpper(unit) { case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": - return parseSingleTimeValue(unit, format) + return parseSingleTimeValue(unit, format, false) case "SECOND_MICROSECOND": return parseTimeValue(format, MicrosecondIndex, SecondMicrosecondMaxCnt) case "MINUTE_MICROSECOND": @@ -1772,7 +1812,93 @@ func ParseDurationValue(unit string, format string) (y int64, m int64, d int64, case "YEAR_MONTH": return parseTimeValue(format, MonthIndex, YearMonthMaxCnt) default: - return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) + return 0, 0, 0, 0, errors.Errorf("invalid single timeunit - %s", unit) + } +} + +// ExtractDurationValue extract the value from format to Duration. +func ExtractDurationValue(unit string, format string) (Duration, error) { + unit = strings.ToUpper(unit) + switch unit { + case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": + _, month, day, nano, err := parseSingleTimeValue(unit, format, true) + if err != nil { + return ZeroDuration, err + } + dur := Duration{Duration: gotime.Duration((month*30+day)*int64(GoDurationDay) + nano)} + if unit == "MICROSECOND" { + dur.Fsp = MaxFsp + } + return dur, err + case "SECOND_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, SecondMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "MINUTE_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, MinuteMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "MINUTE_SECOND": + d, err := parseAndValidateDurationValue(format, SecondIndex, MinuteSecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "HOUR_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, HourMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "HOUR_SECOND": + d, err := parseAndValidateDurationValue(format, SecondIndex, HourSecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "HOUR_MINUTE": + d, err := parseAndValidateDurationValue(format, MinuteIndex, HourMinuteMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: 0}, nil + case "DAY_MICROSECOND": + d, err := parseAndValidateDurationValue(format, MicrosecondIndex, DayMicrosecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "DAY_SECOND": + d, err := parseAndValidateDurationValue(format, SecondIndex, DaySecondMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: MaxFsp}, nil + case "DAY_MINUTE": + d, err := parseAndValidateDurationValue(format, MinuteIndex, DayMinuteMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: 0}, nil + case "DAY_HOUR": + d, err := parseAndValidateDurationValue(format, HourIndex, DayHourMaxCnt) + if err != nil { + return ZeroDuration, err + } + return Duration{Duration: gotime.Duration(d), Fsp: 0}, nil + case "YEAR_MONTH": + _, err := parseAndValidateDurationValue(format, MonthIndex, YearMonthMaxCnt) + if err != nil { + return ZeroDuration, err + } + // MONTH must exceed the limit of mysql's duration. So just return overflow error. + return ZeroDuration, ErrDatetimeFunctionOverflow.GenWithStackByArgs("time") + default: + return ZeroDuration, errors.Errorf("invalid single timeunit - %s", unit) } } diff --git a/types/time_test.go b/types/time_test.go index caf177de996e6..0e9b0deeac789 100644 --- a/types/time_test.go +++ b/types/time_test.go @@ -18,6 +18,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -1081,6 +1082,141 @@ func (s *testTimeSuite) TestCheckTimestamp(c *C) { } } +func (s *testTimeSuite) TestExtractDurationValue(c *C) { + tests := []struct { + unit string + format string + ans string + failed bool + }{ + { + unit: "MICROSECOND", + format: "50", + ans: "00:00:00.000050", + }, + { + unit: "SECOND", + format: "50", + ans: "00:00:50", + }, + { + unit: "MINUTE", + format: "10", + ans: "00:10:00", + }, + { + unit: "HOUR", + format: "10", + ans: "10:00:00", + }, + { + unit: "DAY", + format: "1", + ans: "24:00:00", + }, + { + unit: "WEEK", + format: "2", + ans: "336:00:00", + }, + { + unit: "SECOND_MICROSECOND", + format: "61.01", + ans: "00:01:01.010000", + }, + { + unit: "MINUTE_MICROSECOND", + format: "01:61.01", + ans: "00:02:01.010000", + }, + { + unit: "MINUTE_SECOND", + format: "61:61", + ans: "01:02:01.000000", + }, + { + unit: "HOUR_MICROSECOND", + format: "01:61:01.01", + ans: "02:01:01.010000", + }, + { + unit: "HOUR_SECOND", + format: "01:61:01", + ans: "02:01:01.000000", + }, + { + unit: "HOUr_MINUTE", + format: "2:2", + ans: "02:02:00", + }, + { + unit: "DAY_MICRoSECOND", + format: "1 1:1:1.02", + ans: "25:01:01.020000", + }, + { + unit: "DAY_SeCOND", + format: "1 02:03:04", + ans: "26:03:04.000000", + }, + { + unit: "DAY_MINUTE", + format: "1 1:2", + ans: "25:02:00", + }, + { + unit: "DAY_HOUr", + format: "1 1", + ans: "25:00:00", + }, + { + unit: "DAY", + format: "-35", + failed: true, + }, + { + unit: "day", + format: "34", + ans: "816:00:00", + }, + { + unit: "SECOND", + format: "-3020400", + failed: true, + }, + { + unit: "MONTH", + format: "1", + ans: "720:00:00", + }, + { + unit: "MONTH", + format: "-2", + failed: true, + }, + { + unit: "DAY_second", + format: "34 23:59:59", + failed: true, + }, + { + unit: "DAY_hOUR", + format: "-34 23", + failed: true, + }, + } + failedComment := "failed at case %d, unit: %s, format: %s" + for i, tt := range tests { + dur, err := types.ExtractDurationValue(tt.unit, tt.format) + if tt.failed { + c.Assert(err, NotNil, Commentf(failedComment+", dur: %v", i, tt.unit, tt.format, dur.String())) + } else { + c.Assert(err, IsNil, Commentf(failedComment+", error stack", i, tt.unit, tt.format, errors.ErrorStack(err))) + c.Assert(dur.String(), Equals, tt.ans, Commentf(failedComment, i, tt.unit, tt.format)) + } + } +} + func (s *testTimeSuite) TestCurrentTime(c *C) { res := types.CurrentTime(mysql.TypeTimestamp) c.Assert(res.Time, NotNil) diff --git a/util/math/math.go b/util/math/math.go new file mode 100644 index 0000000000000..adce0f9fcd932 --- /dev/null +++ b/util/math/math.go @@ -0,0 +1,20 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package math + +// Abs implement the abs function according to http://cavaliercoder.com/blog/optimized-abs-for-int64-in-go.html +func Abs(n int64) int64 { + y := n >> 63 + return (n ^ y) - y +} diff --git a/util/math/math_test.go b/util/math/math_test.go new file mode 100644 index 0000000000000..34aaa8a0d01ad --- /dev/null +++ b/util/math/math_test.go @@ -0,0 +1,36 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package math + +import ( + "testing" + + . "github.com/pingcap/check" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testMath{}) + +type testMath struct{} + +func (s *testMath) TestAbs(c *C) { + c.Assert(Abs(1), Equals, int64(1)) + c.Assert(Abs(0), Equals, int64(0)) + c.Assert(Abs(1000), Equals, int64(1000)) + c.Assert(Abs(-100), Equals, int64(100)) + c.Assert(Abs(-1234), Equals, int64(1234)) +}