From ad849b3aee249081bd591b0fbb81632655a88499 Mon Sep 17 00:00:00 2001
From: Lingyu Song <songlingyu@pingcap.com>
Date: Tue, 13 Nov 2018 10:49:15 +0800
Subject: [PATCH] planner: fix expression rewriter wrong compare logic  (#8269)

---
 planner/core/expression_rewriter.go      | 33 ++++++-------
 planner/core/expression_rewriter_test.go | 60 ++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 19 deletions(-)
 create mode 100644 planner/core/expression_rewriter_test.go

diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go
index 0f804f6d430e6..4d5ceed793b33 100644
--- a/planner/core/expression_rewriter.go
+++ b/planner/core/expression_rewriter.go
@@ -169,15 +169,11 @@ type expressionRewriter struct {
 }
 
 // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2)
-// 2. If op are LE or GE, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to
-// `IF( (a0 op b0) EQ 0, 0,
-//      IF ( (a1 op b1) EQ 0, 0, a2 op b2))`
-// 3. If op are LT or GT, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to
+// 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to
 // `IF( a0 NE b0, a0 op b0,
-//      IF( a1 NE b1,
-//          a1 op b1,
-//          a2 op b2)
-// )`
+// 		IF ( isNull(a0 NE b0), Null,
+// 			IF ( a1 NE b1, a1 op b1,
+// 				IF ( isNull(a1 NE b1), Null, a2 op b2))))`
 func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
 	lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r)
 	if lLen == 1 && rLen == 1 {
@@ -198,15 +194,10 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression,
 		return expression.ComposeCNFCondition(er.ctx, funcs...), nil
 	default:
 		larg0, rarg0 := expression.GetFuncArg(l, 0), expression.GetFuncArg(r, 0)
-		var expr1, expr2, expr3 expression.Expression
-		if op == ast.LE || op == ast.GE {
-			expr1 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
-			expr1 = expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), expr1, expression.Zero)
-			expr2 = expression.Zero
-		} else if op == ast.LT || op == ast.GT {
-			expr1 = expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
-			expr2 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
-		}
+		var expr1, expr2, expr3, expr4, expr5 expression.Expression
+		expr1 = expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
+		expr2 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
+		expr3 = expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1)
 		var err error
 		l, err = expression.PopRowFirstArg(er.ctx, l)
 		if err != nil {
@@ -216,11 +207,15 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression,
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
-		expr3, err = er.constructBinaryOpFunction(l, r, op)
+		expr4, err = er.constructBinaryOpFunction(l, r, op)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		expr5, err = expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.Null, expr4)
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
-		return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr3)
+		return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5)
 	}
 }
 
diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go
new file mode 100644
index 0000000000000..bedd1328446ca
--- /dev/null
+++ b/planner/core/expression_rewriter_test.go
@@ -0,0 +1,60 @@
+// Copyright 2018 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package core_test
+
+import (
+	. "github.com/pingcap/check"
+	"github.com/pingcap/tidb/util/testkit"
+	"github.com/pingcap/tidb/util/testleak"
+)
+
+var _ = Suite(&testExpressionRewriterSuite{})
+
+type testExpressionRewriterSuite struct {
+}
+
+func (s *testExpressionRewriterSuite) TestIfNullEliminateColName(c *C) {
+	defer testleak.AfterTest(c)()
+	store, dom, err := newStoreWithBootstrap()
+	c.Assert(err, IsNil)
+	tk := testkit.NewTestKit(c, store)
+	defer func() {
+		dom.Close()
+		store.Close()
+	}()
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t(a int not null, b int not null)")
+	rs, err := tk.Exec("select ifnull(a,b) from t")
+	c.Assert(err, IsNil)
+	fields := rs.Fields()
+	c.Assert(fields[0].Column.Name.L, Equals, "ifnull(a,b)")
+}
+
+func (s *testExpressionRewriterSuite) TestBinaryOpFunction(c *C) {
+	defer testleak.AfterTest(c)()
+	store, dom, err := newStoreWithBootstrap()
+	c.Assert(err, IsNil)
+	tk := testkit.NewTestKit(c, store)
+	defer func() {
+		dom.Close()
+		store.Close()
+	}()
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("CREATE TABLE t(a int, b int, c int);")
+	tk.MustExec("INSERT INTO t VALUES (1, 2, 3), (NULL, 2, 3  ), (1, NULL, 3),(1, 2,   NULL),(NULL, 2, 3+1), (1, NULL, 3+1), (1, 2+1, NULL),(NULL, 2, 3-1), (1, NULL, 3-1), (1, 2-1, NULL)")
+	tk.MustQuery("SELECT * FROM t WHERE (a,b,c) <= (1,2,3) order by b").Check(testkit.Rows("1 1 <nil>", "1 2 3"))
+	tk.MustQuery("SELECT * FROM t WHERE (a,b,c) > (1,2,3) order by b").Check(testkit.Rows("1 3 <nil>"))
+}