diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 8c87de117687f..d9d659074dba7 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -5209,6 +5209,14 @@ func (b *builtinPeriodDiffSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull, err } + if !validPeriod(p1) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") + } + + if !validPeriod(p2) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") + } + return int64(period2Month(uint64(p1)) - period2Month(uint64(p2))), false, nil } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 2ac4b1e494293..bb837c1c0f6a4 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -2403,19 +2403,24 @@ func (s *testEvaluatorSuite) TestPeriodDiff(c *C) { }{ {201611, 201611, true, 0}, {200802, 200703, true, 11}, - {0, 999999999, true, -120000086}, - {9999999, 0, true, 1200086}, - {411, 200413, true, -2}, - {197000, 207700, true, -1284}, {201701, 201611, true, 2}, {201702, 201611, true, 3}, {201510, 201611, true, -13}, {201702, 1611, true, 3}, {197102, 7011, true, 3}, - {12509, 12323, true, 10}, - {12509, 12323, true, 10}, } + tests2 := []struct { + Period1 int64 + Period2 int64 + }{ + {0, 999999999}, + {9999999, 0}, + {411, 200413}, + {197000, 207700}, + {12509, 12323}, + {12509, 12323}, + } fc := funcs[ast.PeriodDiff] for _, test := range tests { period1 := types.NewIntDatum(test.Period1) @@ -2433,6 +2438,18 @@ func (s *testEvaluatorSuite) TestPeriodDiff(c *C) { value := result.GetInt64() c.Assert(value, Equals, test.Expect) } + + for _, test := range tests2 { + period1 := types.NewIntDatum(test.Period1) + period2 := types.NewIntDatum(test.Period2) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{period1, period2})) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_diff") + } + // nil args := []types.Datum{types.NewDatum(nil), types.NewIntDatum(0)} f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) diff --git a/expression/integration_test.go b/expression/integration_test.go index 2094f7e5dc3d3..6740daaf1079b 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1523,10 +1523,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { } // for period_diff - result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`) - result.Check(testkit.Rows("101 -2213609288845122103 0 0")) - result = tk.MustQuery(`SELECT period_diff(NULL, 2), period_diff(-191, NULL), period_diff(NULL, NULL), period_diff(12.09, 2), period_diff("21aa", "11aa"), period_diff("", "");`) - result.Check(testkit.Rows(" 10 10 0")) + result = tk.MustQuery(`SELECT period_diff(200807, 200705), period_diff(200807, 200908);`) + result.Check(testkit.Rows("14 -13")) + result = tk.MustQuery(`SELECT period_diff(NULL, 2), period_diff(-191, NULL), period_diff(NULL, NULL), period_diff(12.09, 2), period_diff("12aa", "11aa");`) + result.Check(testkit.Rows(" 10 1")) + for _, errPeriod := range []string{ + "period_diff(-00013,1)", "period_diff(00013,1)", "period_diff(0, 0)", "period_diff(200013, 1)", "period_diff(5612, 4513)", "period_diff('', '')", + } { + err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod)) + c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_diff") + } // TODO: fix `CAST(xx as duration)` and release the test below: // result = tk.MustQuery(`SELECT hour("aaa"), hour(123456), hour(1234567);`)