From 32bd4a7f4bfa635df8c8fc12e56e506e85212778 Mon Sep 17 00:00:00 2001
From: Yiding Cui <winoros@gmail.com>
Date: Thu, 10 Oct 2019 17:38:52 +0800
Subject: [PATCH] planner: fix unexpected behavior of UPDATE (#12597)

---
 executor/update_test.go              | 11 +++++++++++
 planner/core/logical_plan_builder.go |  9 ++++-----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/executor/update_test.go b/executor/update_test.go
index 410ba65fc4246..7d33c2357e2ad 100644
--- a/executor/update_test.go
+++ b/executor/update_test.go
@@ -212,3 +212,14 @@ func (s *testUpdateSuite) TestUpdateWithAutoidSchema(c *C) {
 		tk.MustQuery(tt.query).Check(tt.result)
 	}
 }
+
+func (s *testUpdateSuite) TestUpdateWithSubquery(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	tk.MustExec("create table t1(id varchar(30) not null, status varchar(1) not null default 'N', id2 varchar(30))")
+	tk.MustExec("create table t2(id varchar(30) not null, field varchar(4) not null)")
+	tk.MustExec("insert into t1 values('abc', 'F', 'abc')")
+	tk.MustExec("insert into t2 values('abc', 'MAIN')")
+	tk.MustExec("update t1 set status = 'N' where status = 'F' and (id in (select id from t2 where field = 'MAIN') or id2 in (select id from t2 where field = 'main'))")
+	tk.MustQuery("select * from t1").Check(testkit.Rows("abc N abc"))
+}
diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go
index 26ffc3dcb260c..c7e464a125d33 100644
--- a/planner/core/logical_plan_builder.go
+++ b/planner/core/logical_plan_builder.go
@@ -2584,7 +2584,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
 		b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil)
 	}
 
-	oldSchemaLen := p.Schema().Len()
+	oldSchema := p.Schema().Clone()
 	if sel.Where != nil {
 		p, err = b.buildSelection(ctx, p, update.Where, nil)
 		if err != nil {
@@ -2592,10 +2592,9 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
 		}
 	}
 	// TODO: expression rewriter should not change the output columns. We should cut the columns here.
-	if p.Schema().Len() != oldSchemaLen {
-		proj := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldSchemaLen])}.Init(b.ctx)
-		proj.SetSchema(expression.NewSchema(make([]*expression.Column, oldSchemaLen)...))
-		copy(proj.schema.Columns, p.Schema().Columns[:oldSchemaLen])
+	if p.Schema().Len() != oldSchema.Len() {
+		proj := LogicalProjection{Exprs: expression.Column2Exprs(oldSchema.Columns)}.Init(b.ctx)
+		proj.SetSchema(oldSchema)
 		proj.SetChildren(p)
 		p = proj
 	}