From 56a98449bad0b2e1578d68416f0500333592b2f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=99=8E?= <ichneumon.hu@foxmail.com>
Date: Sat, 25 Jan 2020 10:10:10 +0800
Subject: [PATCH 1/2] planner: Fix SMJ hint, support SMJ with descending order.
 (#14505)

cherry-pick
---
 executor/builder.go                    |  1 +
 executor/merge_join.go                 | 10 +++++---
 executor/merge_join_test.go            | 34 ++++++++++++++++++++++++++
 planner/core/exhaust_physical_plans.go | 11 ++++++---
 planner/core/physical_plans.go         |  2 ++
 5 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/executor/builder.go b/executor/builder.go
index af8dea29d13b8..7d5a2fd50c009 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -934,6 +934,7 @@ func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) Execu
 			retTypes(rightExec),
 		),
 		isOuterJoin: v.JoinType.IsOuterJoin(),
+		desc:        v.Desc,
 	}
 
 	leftKeys := v.LeftKeys
diff --git a/executor/merge_join.go b/executor/merge_join.go
index 4a0521740bc7c..29eabac783055 100644
--- a/executor/merge_join.go
+++ b/executor/merge_join.go
@@ -40,6 +40,7 @@ type MergeJoinExec struct {
 	compareFuncs []expression.CompareFunc
 	joiner       joiner
 	isOuterJoin  bool
+	desc         bool
 
 	prepared bool
 	outerIdx int
@@ -297,21 +298,24 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM
 		}
 
 		cmpResult := -1
-		if e.outerTable.selected[e.outerTable.row.Idx()] && len(e.innerRows) > 0 {
+		if e.desc {
+			cmpResult = 1
+		}
+		if e.outerTable.selected[e.outerTable.row.Idx()] && e.innerIter4Row.Len() > 0 {
 			cmpResult, err = e.compare(e.outerTable.row, e.innerIter4Row.Current())
 			if err != nil {
 				return false, err
 			}
 		}
 
-		if cmpResult > 0 {
+		if (cmpResult > 0 && !e.desc) || (cmpResult < 0 && e.desc) {
 			if err = e.fetchNextInnerRows(); err != nil {
 				return false, err
 			}
 			continue
 		}
 
-		if cmpResult < 0 {
+		if (cmpResult < 0 && !e.desc) || (cmpResult > 0 && e.desc) {
 			e.joiner.onMissMatch(false, e.outerTable.row, chk)
 			if err != nil {
 				return false, err
diff --git a/executor/merge_join_test.go b/executor/merge_join_test.go
index b6fa5e032ce2d..03831971a8362 100644
--- a/executor/merge_join_test.go
+++ b/executor/merge_join_test.go
@@ -360,6 +360,40 @@ func (s *testSuite1) TestMergeJoin(c *C) {
 		"1",
 		"0",
 	))
+
+	// Test TIDB_SMJ for join with order by desc, see https://github.com/pingcap/tidb/issues/14483
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t (a int, key(a))")
+	tk.MustExec("create table t1 (a int, key(a))")
+	tk.MustExec("insert into t values (1), (2), (3)")
+	tk.MustExec("insert into t1 values (1), (2), (3)")
+	tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t.a from t, t1 where t.a = t1.a order by t1.a desc").Check(testkit.Rows(
+		"3", "2", "1"))
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t (a int, b int, key(a), key(b))")
+	tk.MustExec("insert into t values (1,1),(1,2),(1,3),(2,1),(2,2),(3,1),(3,2),(3,3)")
+	tk.MustQuery("select /*+ TIDB_SMJ(t1, t2) */ t1.a from t t1, t t2 where t1.a = t2.b order by t1.a desc").Check(testkit.Rows(
+		"3", "3", "3", "3", "3", "3",
+		"2", "2", "2", "2", "2", "2",
+		"1", "1", "1", "1", "1", "1", "1", "1", "1"))
+
+	tk.MustExec("drop table if exists s")
+	tk.MustExec("create table s (a int)")
+	tk.MustExec("insert into s values (4), (1), (3), (2)")
+	tk.MustQuery("explain select s1.a1 from (select a as a1 from s order by s.a desc) as s1 join (select a as a2 from s order by s.a desc) as s2 on s1.a1 = s2.a2 order by s1.a1 desc").Check(testkit.Rows(
+		"Projection_27 12487.50 root test.s.a",
+		"└─MergeJoin_28 12487.50 root inner join, left key:test.s.a, right key:test.s.a",
+		"  ├─Sort_29 9990.00 root test.s.a:desc",
+		"  │ └─TableReader_21 9990.00 root data:Selection_20",
+		"  │   └─Selection_20 9990.00 cop[tikv] not(isnull(test.s.a))",
+		"  │     └─TableScan_19 10000.00 cop[tikv] table:s, range:[-inf,+inf], keep order:false, stats:pseudo",
+		"  └─Sort_31 9990.00 root test.s.a:desc",
+		"    └─TableReader_26 9990.00 root data:Selection_25",
+		"      └─Selection_25 9990.00 cop[tikv] not(isnull(test.s.a))",
+		"        └─TableScan_24 10000.00 cop[tikv] table:s, range:[-inf,+inf], keep order:false, stats:pseudo"))
+	tk.MustQuery("select s1.a1 from (select a as a1 from s order by s.a desc) as s1 join (select a as a2 from s order by s.a desc) as s2 on s1.a1 = s2.a2 order by s1.a1 desc").Check(testkit.Rows(
+		"4", "3", "2", "1"))
 }
 
 func (s *testSuite1) Test3WaysMergeJoin(c *C) {
diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go
index 02d06a76961f2..f32e566aeaa2c 100644
--- a/planner/core/exhaust_physical_plans.go
+++ b/planner/core/exhaust_physical_plans.go
@@ -95,12 +95,12 @@ func (p *LogicalJoin) moveEqualToOtherConditions(offsets []int) []expression.Exp
 
 // Only if the input required prop is the prefix fo join keys, we can pass through this property.
 func (p *PhysicalMergeJoin) tryToGetChildReqProp(prop *property.PhysicalProperty) ([]*property.PhysicalProperty, bool) {
-	lProp := property.NewPhysicalProperty(property.RootTaskType, p.LeftKeys, false, math.MaxFloat64, false)
-	rProp := property.NewPhysicalProperty(property.RootTaskType, p.RightKeys, false, math.MaxFloat64, false)
+	all, desc := prop.AllSameOrder()
+	lProp := property.NewPhysicalProperty(property.RootTaskType, p.LeftKeys, desc, math.MaxFloat64, false)
+	rProp := property.NewPhysicalProperty(property.RootTaskType, p.RightKeys, desc, math.MaxFloat64, false)
 	if !prop.IsEmpty() {
 		// sort merge join fits the cases of massive ordered data, so desc scan is always expensive.
-		all, desc := prop.AllSameOrder()
-		if !all || desc {
+		if !all {
 			return nil, false
 		}
 		if !prop.IsPrefix(lProp) && !prop.IsPrefix(rProp) {
@@ -156,6 +156,8 @@ func (p *LogicalJoin) getMergeJoin(prop *property.PhysicalProperty) []PhysicalPl
 				reqProps[1].ExpectedCnt = p.children[1].statsInfo().RowCount * expCntScale
 			}
 			mergeJoin.childrenReqProps = reqProps
+			_, desc := prop.AllSameOrder()
+			mergeJoin.Desc = desc
 			joins = append(joins, mergeJoin)
 		}
 	}
@@ -239,6 +241,7 @@ func (p *LogicalJoin) getEnforcedMergeJoin(prop *property.PhysicalProperty) []Ph
 		LeftKeys:        leftKeys,
 		RightKeys:       rightKeys,
 		OtherConditions: p.OtherConditions,
+		Desc:            desc,
 	}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt))
 	enforcedPhysicalMergeJoin.SetSchema(p.schema)
 	enforcedPhysicalMergeJoin.childrenReqProps = []*property.PhysicalProperty{lProp, rProp}
diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go
index b49992f301197..8b561e9bcc005 100644
--- a/planner/core/physical_plans.go
+++ b/planner/core/physical_plans.go
@@ -271,6 +271,8 @@ type PhysicalMergeJoin struct {
 
 	LeftKeys  []*expression.Column
 	RightKeys []*expression.Column
+	// Desc means whether inner child keep desc order.
+	Desc bool
 }
 
 // PhysicalLock is the physical operator of lock, which is used for `select ... for update` clause.

From 37d1339456bd641dea7072c89893b4c0658bcb44 Mon Sep 17 00:00:00 2001
From: ichn-hu <zfhu16@fudan.edu.cn>
Date: Thu, 6 Feb 2020 21:24:10 +0800
Subject: [PATCH 2/2] fix ut

---
 executor/merge_join_test.go | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/executor/merge_join_test.go b/executor/merge_join_test.go
index 03831971a8362..cf0db17e4107c 100644
--- a/executor/merge_join_test.go
+++ b/executor/merge_join_test.go
@@ -382,16 +382,17 @@ func (s *testSuite1) TestMergeJoin(c *C) {
 	tk.MustExec("create table s (a int)")
 	tk.MustExec("insert into s values (4), (1), (3), (2)")
 	tk.MustQuery("explain select s1.a1 from (select a as a1 from s order by s.a desc) as s1 join (select a as a2 from s order by s.a desc) as s2 on s1.a1 = s2.a2 order by s1.a1 desc").Check(testkit.Rows(
-		"Projection_27 12487.50 root test.s.a",
-		"└─MergeJoin_28 12487.50 root inner join, left key:test.s.a, right key:test.s.a",
-		"  ├─Sort_29 9990.00 root test.s.a:desc",
+		"Projection_27 12487.50 root test.s.a1",
+		"└─MergeJoin_28 12487.50 root inner join, left key:test.s.a1, right key:test.s.a2",
+		"  ├─Sort_29 9990.00 root test.s.a1:desc",
 		"  │ └─TableReader_21 9990.00 root data:Selection_20",
-		"  │   └─Selection_20 9990.00 cop[tikv] not(isnull(test.s.a))",
-		"  │     └─TableScan_19 10000.00 cop[tikv] table:s, range:[-inf,+inf], keep order:false, stats:pseudo",
-		"  └─Sort_31 9990.00 root test.s.a:desc",
+		"  │   └─Selection_20 9990.00 cop not(isnull(test.s.a))",
+		"  │     └─TableScan_19 10000.00 cop table:s, range:[-inf,+inf], keep order:false, stats:pseudo",
+		"  └─Sort_31 9990.00 root test.s.a2:desc",
 		"    └─TableReader_26 9990.00 root data:Selection_25",
-		"      └─Selection_25 9990.00 cop[tikv] not(isnull(test.s.a))",
-		"        └─TableScan_24 10000.00 cop[tikv] table:s, range:[-inf,+inf], keep order:false, stats:pseudo"))
+		"      └─Selection_25 9990.00 cop not(isnull(test.s.a))",
+		"        └─TableScan_24 10000.00 cop table:s, range:[-inf,+inf], keep order:false, stats:pseudo",
+	))
 	tk.MustQuery("select s1.a1 from (select a as a1 from s order by s.a desc) as s1 join (select a as a2 from s order by s.a desc) as s2 on s1.a1 = s2.a2 order by s1.a1 desc").Check(testkit.Rows(
 		"4", "3", "2", "1"))
 }