From 32bd4a7f4bfa635df8c8fc12e56e506e85212778 Mon Sep 17 00:00:00 2001 From: Yiding Cui <winoros@gmail.com> Date: Thu, 10 Oct 2019 17:38:52 +0800 Subject: [PATCH] planner: fix unexpected behavior of UPDATE (#12597) --- executor/update_test.go | 11 +++++++++++ planner/core/logical_plan_builder.go | 9 ++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/executor/update_test.go b/executor/update_test.go index 410ba65fc4246..7d33c2357e2ad 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -212,3 +212,14 @@ func (s *testUpdateSuite) TestUpdateWithAutoidSchema(c *C) { tk.MustQuery(tt.query).Check(tt.result) } } + +func (s *testUpdateSuite) TestUpdateWithSubquery(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t1(id varchar(30) not null, status varchar(1) not null default 'N', id2 varchar(30))") + tk.MustExec("create table t2(id varchar(30) not null, field varchar(4) not null)") + tk.MustExec("insert into t1 values('abc', 'F', 'abc')") + tk.MustExec("insert into t2 values('abc', 'MAIN')") + tk.MustExec("update t1 set status = 'N' where status = 'F' and (id in (select id from t2 where field = 'MAIN') or id2 in (select id from t2 where field = 'main'))") + tk.MustQuery("select * from t1").Check(testkit.Rows("abc N abc")) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 26ffc3dcb260c..c7e464a125d33 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2584,7 +2584,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) } - oldSchemaLen := p.Schema().Len() + oldSchema := p.Schema().Clone() if sel.Where != nil { p, err = b.buildSelection(ctx, p, update.Where, nil) if err != nil { @@ -2592,10 +2592,9 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( } } // TODO: expression rewriter should not change the output columns. We should cut the columns here. - if p.Schema().Len() != oldSchemaLen { - proj := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldSchemaLen])}.Init(b.ctx) - proj.SetSchema(expression.NewSchema(make([]*expression.Column, oldSchemaLen)...)) - copy(proj.schema.Columns, p.Schema().Columns[:oldSchemaLen]) + if p.Schema().Len() != oldSchema.Len() { + proj := LogicalProjection{Exprs: expression.Column2Exprs(oldSchema.Columns)}.Init(b.ctx) + proj.SetSchema(oldSchema) proj.SetChildren(p) p = proj }