diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 4aba634a193f9..80b1528e11a37 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -170,15 +170,11 @@ type expressionRewriter struct { } // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) -// 2. If op are LE or GE, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to -// `IF( (a0 op b0) EQ 0, 0, -// IF ( (a1 op b1) EQ 0, 0, a2 op b2))` -// 3. If op are LT or GT, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to +// 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to // `IF( a0 NE b0, a0 op b0, -// IF( a1 NE b1, -// a1 op b1, -// a2 op b2) -// )` +// IF ( isNull(a0 NE b0), Null, +// IF ( a1 NE b1, a1 op b1, +// IF ( isNull(a1 NE b1), Null, a2 op b2))))` func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) { lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r) if lLen == 1 && rLen == 1 { @@ -202,9 +198,10 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, return expression.ComposeCNFCondition(er.ctx, funcs...), nil default: larg0, rarg0 := expression.GetFuncArg(l, 0), expression.GetFuncArg(r, 0) - var expr1, expr2, expr3 expression.Expression + var expr1, expr2, expr3, expr4, expr5 expression.Expression expr1 = expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) expr2 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) + expr3 = expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1) var err error l, err = expression.PopRowFirstArg(er.ctx, l) if err != nil { @@ -214,23 +211,15 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, if err != nil { return nil, errors.Trace(err) } - if evalexpr, ok := expr1.(*expression.Constant); ok { - _, isNull, err1 := evalexpr.EvalInt(er.ctx, chunk.Row{}) - if err1 != nil || isNull { - return expr1, err1 - } - } - if evalexpr, ok := expr2.(*expression.Constant); ok { - _, isNull, err1 := evalexpr.EvalInt(er.ctx, chunk.Row{}) - if err1 != nil || isNull { - return expr2, err1 - } + expr4, err = er.constructBinaryOpFunction(l, r, op) + if err != nil { + return nil, errors.Trace(err) } - expr3, err = er.constructBinaryOpFunction(l, r, op) + expr5, err = expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.Null, expr4) if err != nil { return nil, errors.Trace(err) } - return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr3) + return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5) } } diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index 9d10a34150a29..bedd1328446ca 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -41,3 +41,20 @@ func (s *testExpressionRewriterSuite) TestIfNullEliminateColName(c *C) { fields := rs.Fields() c.Assert(fields[0].Column.Name.L, Equals, "ifnull(a,b)") } + +func (s *testExpressionRewriterSuite) TestBinaryOpFunction(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t(a int, b int, c int);") + tk.MustExec("INSERT INTO t VALUES (1, 2, 3), (NULL, 2, 3 ), (1, NULL, 3),(1, 2, NULL),(NULL, 2, 3+1), (1, NULL, 3+1), (1, 2+1, NULL),(NULL, 2, 3-1), (1, NULL, 3-1), (1, 2-1, NULL)") + tk.MustQuery("SELECT * FROM t WHERE (a,b,c) <= (1,2,3) order by b").Check(testkit.Rows("1 1 ", "1 2 3")) + tk.MustQuery("SELECT * FROM t WHERE (a,b,c) > (1,2,3) order by b").Check(testkit.Rows("1 3 ")) +}