From 8477f727f9f646f6e82d33a85c31ba69e006f8b4 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Fri, 10 Mar 2023 17:45:07 +0200 Subject: [PATCH 1/5] Fix random aggregation to not select Null column Signed-off-by: Florent Poinsard --- .../queries/aggregation/aggregation_test.go | 10 ++++++ .../vtgate/queries/aggregation/schema.sql | 5 +++ .../vtgate/queries/aggregation/vschema.json | 8 +++++ go/vt/vtgate/engine/ordered_aggregate.go | 9 ++++- go/vt/vtgate/executor_select_test.go | 34 +++++++++++++++++++ 5 files changed, 65 insertions(+), 1 deletion(-) diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index b7ef4c4a78d..f07fb734df8 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -425,3 +425,13 @@ func TestScalarAggregate(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`) } + +func TestAggregationRandomOnAnAggregatedValue(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t10(k, a, b) values (0, 100, 10), (10, 200, 20);") + + mcmp.AssertMatchesNoOrder("select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from t10 where a = 100) A;", + `[[DECIMAL(100) DECIMAL(10) DECIMAL(10.0000)]]`) +} diff --git a/go/test/endtoend/vtgate/queries/aggregation/schema.sql b/go/test/endtoend/vtgate/queries/aggregation/schema.sql index a538a3dafed..0375bdb8499 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/schema.sql +++ b/go/test/endtoend/vtgate/queries/aggregation/schema.sql @@ -71,3 +71,8 @@ CREATE TABLE t2 ( PRIMARY KEY (id) ) ENGINE InnoDB; +CREATE TABLE t10 ( + k BIGINT PRIMARY KEY, + a INT, + b INT +); \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/aggregation/vschema.json b/go/test/endtoend/vtgate/queries/aggregation/vschema.json index c2d3f133a35..4d1623d5633 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/vschema.json +++ b/go/test/endtoend/vtgate/queries/aggregation/vschema.json @@ -123,6 +123,14 @@ "name": "hash" } ] + }, + "t10": { + "column_vindexes": [ + { + "column": "k", + "name": "hash" + } + ] } } } \ No newline at end of file diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index e5d3057a127..906a267b725 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -526,7 +526,14 @@ func merge( val, _ := sqltypes.NewValue(sqltypes.VarBinary, data) result[aggr.Col] = val case AggregateRandom: - // we just grab the first value per grouping. no need to do anything more complicated here + // we just grab the first value per grouping + // however, if the first row contains a Null value for this row we decide to ignore + // it and use the second row. there might some cases (i.e. `sum(a) / sum(b)`) on a sharded + // cluster for which MySQL will return Null on row1 and a value on row2. we want to return + // the computed value of row2. + if row1[aggr.Col].IsNull() { + result[aggr.Col] = row2[aggr.Col] + } default: return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index f76bd742d03..e95039d6825 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3856,6 +3856,40 @@ func TestSelectAggregationData(t *testing.T) { } } +func TestSelectAggregationRandom(t *testing.T) { + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + createSandbox(KsTestSharded).VSchema = executorVSchema + getSandbox(KsTestUnsharded).VSchema = unshardedVSchema + serv := newSandboxForCells([]string{cell}) + resolver := newTestResolver(hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + var conns []*sandboxconn.SandboxConn + for _, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, KsTestSharded, shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + conns = append(conns, sbc) + + sbc.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields("a|b|c", "int64|int64|int64"), + "null|null|null", + )}) + } + + conns[0].SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields("a|b|c", "int64|int64|int64"), + "10|1|10", + )}) + + executor := createExecutor(serv, cell, resolver) + executor.pv = querypb.ExecuteOptions_Gen4 + session := NewAutocommitSession(&vtgatepb.Session{}) + + rs, err := executor.Execute(context.Background(), "TestSelectCFC", session, + "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A", nil) + require.NoError(t, err) + assert.Equal(t, `[[INT64(10) INT64(1) INT64(10)]]`, fmt.Sprintf("%v", rs.Rows)) +} + func TestSelectHexAndBit(t *testing.T) { executor, _, _, _ := createExecutorEnv() executor.normalize = true From 21d379183ef97dd8a9d160eac6bf999993fc067c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 15 Mar 2023 15:32:13 +0100 Subject: [PATCH 2/5] stop pushing down projections that should be evaluated at the vtgate level Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/gen4_planner.go | 48 +++++++---- go/vt/vtgate/planbuilder/horizon_planning.go | 31 ++++++- .../planbuilder/operators/queryprojection.go | 80 ++++++++++++++++++- .../planbuilder/testdata/aggr_cases.json | 40 ++++++++++ 4 files changed, 178 insertions(+), 21 deletions(-) diff --git a/go/vt/vtgate/planbuilder/gen4_planner.go b/go/vt/vtgate/planbuilder/gen4_planner.go index dc49ae0a700..8de2ba02f1f 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner.go +++ b/go/vt/vtgate/planbuilder/gen4_planner.go @@ -216,7 +216,10 @@ func newBuildSelectPlan( return nil, nil, nil, err } - plan = optimizePlan(plan) + plan, err = optimizePlan(plan) + if err != nil { + return nil, nil, nil, err + } sel, isSel := selStmt.(*sqlparser.Select) if isSel { @@ -238,25 +241,36 @@ func newBuildSelectPlan( } // optimizePlan removes unnecessary simpleProjections that have been created while planning -func optimizePlan(plan logicalPlan) logicalPlan { - newPlan, _ := visit(plan, func(plan logicalPlan) (bool, logicalPlan, error) { - this, ok := plan.(*simpleProjection) - if !ok { - return true, plan, nil +func optimizePlan(plan logicalPlan) (output logicalPlan, err error) { + output = plan + inputs := make([]logicalPlan, len(plan.Inputs())) + for i, lp := range plan.Inputs() { + in, err := optimizePlan(lp) + if err != nil { + return nil, err } + inputs[i] = in + } + err = plan.Rewrite(inputs...) + if err != nil { + return + } - input, ok := this.input.(*simpleProjection) - if !ok { - return true, plan, nil - } + this, ok := plan.(*simpleProjection) + if !ok { + return + } - for i, col := range this.eSimpleProj.Cols { - this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col] - } - this.input = input.input - return true, this, nil - }) - return newPlan + input, ok := this.input.(*simpleProjection) + if !ok { + return + } + + for i, col := range this.eSimpleProj.Cols { + this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col] + } + this.input = input.input + return } func gen4UpdateStmtPlanner( diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index eea1400b916..664bce56b93 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -60,7 +60,8 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo // a simpleProjection. We create a new Route that contains the derived table in the // FROM clause. Meaning that, when we push expressions to the select list of this // new Route, we do not want them to rewrite them. - if _, isSimpleProj := plan.(*simpleProjection); isSimpleProj { + sp, isSimpleProj := plan.(*simpleProjection) + if isSimpleProj { oldRewriteDerivedExpr := ctx.RewriteDerivedExpr defer func() { ctx.RewriteDerivedExpr = oldRewriteDerivedExpr @@ -75,10 +76,11 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo } needsOrdering := len(hp.qp.OrderExprs) > 0 - canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering // If we still have a HAVING clause, it's because it could not be pushed to the WHERE, // so it probably has aggregations + canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering + switch { case hp.qp.NeedsAggregation() || hp.sel.Having != nil: plan, err = hp.planAggregations(ctx, plan) @@ -93,10 +95,33 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo return nil, err } default: - err = pushProjections(ctx, plan, hp.qp.SelectExprs) + if !isSimpleProj { + err = pushProjections(ctx, plan, hp.qp.SelectExprs) + if err != nil { + return nil, err + } + break + } + + pusher := func(ae *sqlparser.AliasedExpr) (int, error) { + offset, _, err := pushProjection(ctx, ae, sp.input, true, true, false) + return offset, err + } + needsVtGate, projections, colNames, err := hp.qp.NeedsProjecting(ctx, pusher) if err != nil { return nil, err } + if !needsVtGate { + break + } + + // there were some expressions we could not push down entirely, + // so replace the simpleProjection with a real projection + plan = &projection{ + source: sp.input, + columns: projections, + columnNames: colNames, + } } // If we didn't already take care of ORDER BY during aggregation planning, we need to handle it now diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 29e356c6650..8de53a762be 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -418,7 +418,85 @@ func (qp *QueryProjection) NeedsAggregation() bool { return qp.HasAggr || len(qp.groupByExprs) > 0 } -func (qp QueryProjection) onlyAggr() bool { +// NeedsProjecting returns true if we have projections that need to be evaluated at the vtgate level +// and can't be pushed down to MySQL +func (qp *QueryProjection) NeedsProjecting( + ctx *plancontext.PlanningContext, + pusher func(expr *sqlparser.AliasedExpr) (int, error), +) (needsVtGateEval bool, expressions []sqlparser.Expr, colNames []string, err error) { + for _, se := range qp.SelectExprs { + var ae *sqlparser.AliasedExpr + ae, err = se.GetAliasedExpr() + if err != nil { + return false, nil, nil, err + } + + expr := ae.Expr + colNames = append(colNames, ae.ColumnName()) + + if _, isCol := expr.(*sqlparser.ColName); isCol { + offset, err := pusher(ae) + if err != nil { + return false, nil, nil, err + } + expressions = append(expressions, sqlparser.NewOffset(offset, expr)) + continue + } + + stopOnError := func(sqlparser.SQLNode, sqlparser.SQLNode) bool { + return err == nil + } + rewriter := func(cursor *sqlparser.CopyOnWriteCursor) { + col, isCol := cursor.Node().(*sqlparser.ColName) + if !isCol { + return + } + var tableInfo semantics.TableInfo + tableInfo, err = ctx.SemTable.TableInfoForExpr(col) + if err != nil { + return + } + dt, isDT := tableInfo.(*semantics.DerivedTable) + if !isDT { + return + } + + rewritten := semantics.RewriteDerivedTableExpression(col, dt) + if sqlparser.ContainsAggregation(rewritten) { + offset, tErr := pusher(&sqlparser.AliasedExpr{Expr: col}) + if tErr != nil { + err = tErr + return + } + + cursor.Replace(sqlparser.NewOffset(offset, col)) + } + } + newExpr := sqlparser.CopyOnRewrite(expr, stopOnError, rewriter, nil) + + if err != nil { + return + } + + if newExpr != expr { + // if we changed the expression, it means that we have to evaluate the rest at the vtgate level + expressions = append(expressions, newExpr.(sqlparser.Expr)) + needsVtGateEval = true + continue + } + + // we did not need to push any parts of this expression down. Let's check if we can push all of it + offset, err := pusher(ae) + if err != nil { + return false, nil, nil, err + } + expressions = append(expressions, sqlparser.NewOffset(offset, expr)) + } + + return +} + +func (qp *QueryProjection) onlyAggr() bool { if !qp.HasAggr { return false } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index f8e6c7fcde1..e0cee664828 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -4964,5 +4964,45 @@ "user.user_extra" ] } + }, + { + "comment": "Aggregations from derived table used in arithmetic outside derived table", + "query": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A", + "v3-plan": "VT12001: unsupported: expression on results of a cross-shard subquery", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] as a", + "[COLUMN 1] as b", + "[COLUMN 0] / [COLUMN 1] as d" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS a, sum(1) AS b", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(a) as a, sum(b) as b from `user` where 1 != 1", + "Query": "select sum(a) as a, sum(b) as b from `user`", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } } ] From 22eb37b235413897ed3cd692a9c12208236cae34 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 15 Mar 2023 15:35:50 +0100 Subject: [PATCH 3/5] undo changes to AggregateRandom Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/ordered_aggregate.go | 9 +-------- go/vt/vtgate/executor_select_test.go | 2 +- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 906a267b725..e5d3057a127 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -526,14 +526,7 @@ func merge( val, _ := sqltypes.NewValue(sqltypes.VarBinary, data) result[aggr.Col] = val case AggregateRandom: - // we just grab the first value per grouping - // however, if the first row contains a Null value for this row we decide to ignore - // it and use the second row. there might some cases (i.e. `sum(a) / sum(b)`) on a sharded - // cluster for which MySQL will return Null on row1 and a value on row2. we want to return - // the computed value of row2. - if row1[aggr.Col].IsNull() { - result[aggr.Col] = row2[aggr.Col] - } + // we just grab the first value per grouping. no need to do anything more complicated here default: return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index e95039d6825..1906ba1d6b4 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3887,7 +3887,7 @@ func TestSelectAggregationRandom(t *testing.T) { rs, err := executor.Execute(context.Background(), "TestSelectCFC", session, "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A", nil) require.NoError(t, err) - assert.Equal(t, `[[INT64(10) INT64(1) INT64(10)]]`, fmt.Sprintf("%v", rs.Rows)) + assert.Equal(t, `[[INT64(10) INT64(1) DECIMAL(10.0000)]]`, fmt.Sprintf("%v", rs.Rows)) } func TestSelectHexAndBit(t *testing.T) { From d54da3562412c609ff6f34f79616c4078548743c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 15 Mar 2023 15:46:59 +0100 Subject: [PATCH 4/5] clean up code Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/gen4_planner.go | 22 ++++---------------- go/vt/vtgate/planbuilder/horizon_planning.go | 19 +++++++---------- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/go/vt/vtgate/planbuilder/gen4_planner.go b/go/vt/vtgate/planbuilder/gen4_planner.go index 8de2ba02f1f..6822dcff642 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner.go +++ b/go/vt/vtgate/planbuilder/gen4_planner.go @@ -216,10 +216,7 @@ func newBuildSelectPlan( return nil, nil, nil, err } - plan, err = optimizePlan(plan) - if err != nil { - return nil, nil, nil, err - } + optimizePlan(plan) sel, isSel := selStmt.(*sqlparser.Select) if isSel { @@ -241,19 +238,9 @@ func newBuildSelectPlan( } // optimizePlan removes unnecessary simpleProjections that have been created while planning -func optimizePlan(plan logicalPlan) (output logicalPlan, err error) { - output = plan - inputs := make([]logicalPlan, len(plan.Inputs())) - for i, lp := range plan.Inputs() { - in, err := optimizePlan(lp) - if err != nil { - return nil, err - } - inputs[i] = in - } - err = plan.Rewrite(inputs...) - if err != nil { - return +func optimizePlan(plan logicalPlan) { + for _, lp := range plan.Inputs() { + optimizePlan(lp) } this, ok := plan.(*simpleProjection) @@ -270,7 +257,6 @@ func optimizePlan(plan logicalPlan) (output logicalPlan, err error) { this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col] } this.input = input.input - return } func gen4UpdateStmtPlanner( diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 664bce56b93..4e33f62ebe5 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -60,8 +60,8 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo // a simpleProjection. We create a new Route that contains the derived table in the // FROM clause. Meaning that, when we push expressions to the select list of this // new Route, we do not want them to rewrite them. - sp, isSimpleProj := plan.(*simpleProjection) - if isSimpleProj { + sp, derivedTable := plan.(*simpleProjection) + if derivedTable { oldRewriteDerivedExpr := ctx.RewriteDerivedExpr defer func() { ctx.RewriteDerivedExpr = oldRewriteDerivedExpr @@ -94,15 +94,7 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo if err != nil { return nil, err } - default: - if !isSimpleProj { - err = pushProjections(ctx, plan, hp.qp.SelectExprs) - if err != nil { - return nil, err - } - break - } - + case derivedTable: pusher := func(ae *sqlparser.AliasedExpr) (int, error) { offset, _, err := pushProjection(ctx, ae, sp.input, true, true, false) return offset, err @@ -122,6 +114,11 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo columns: projections, columnNames: colNames, } + default: + err = pushProjections(ctx, plan, hp.qp.SelectExprs) + if err != nil { + return nil, err + } } // If we didn't already take care of ORDER BY during aggregation planning, we need to handle it now From e9a203f00f129f5fb1947dd08b91d5bc3a7a412f Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Thu, 16 Mar 2023 14:51:52 +0200 Subject: [PATCH 5/5] fix executor test mock Signed-off-by: Florent Poinsard --- go/vt/vtgate/executor_select_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 1906ba1d6b4..8ba046146a7 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3870,14 +3870,14 @@ func TestSelectAggregationRandom(t *testing.T) { conns = append(conns, sbc) sbc.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( - sqltypes.MakeTestFields("a|b|c", "int64|int64|int64"), - "null|null|null", + sqltypes.MakeTestFields("a|b", "int64|int64"), + "null|null", )}) } conns[0].SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( - sqltypes.MakeTestFields("a|b|c", "int64|int64|int64"), - "10|1|10", + sqltypes.MakeTestFields("a|b", "int64|int64"), + "10|1", )}) executor := createExecutor(serv, cell, resolver)