diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index cf4bbaf9d6b..503effc2a33 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -428,3 +428,13 @@ func TestScalarAggregate(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)") mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`) } + +func TestAggregationRandomOnAnAggregatedValue(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t11(k, a, b) values (0, 100, 10), (10, 200, 20);") + + mcmp.AssertMatchesNoOrder("select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from t11 where a = 100) A;", + `[[DECIMAL(100) DECIMAL(10) DECIMAL(10.0000)]]`) +} diff --git a/go/test/endtoend/vtgate/queries/aggregation/schema.sql b/go/test/endtoend/vtgate/queries/aggregation/schema.sql index 944c3783048..6f388027c6d 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/schema.sql +++ b/go/test/endtoend/vtgate/queries/aggregation/schema.sql @@ -69,3 +69,9 @@ CREATE TABLE t2 ( shardKey bigint, PRIMARY KEY (id) ) ENGINE InnoDB; + +CREATE TABLE t11 ( + k BIGINT PRIMARY KEY, + a INT, + b INT +); diff --git a/go/test/endtoend/vtgate/queries/aggregation/vschema.json b/go/test/endtoend/vtgate/queries/aggregation/vschema.json index c2d3f133a35..727d6adc1d0 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/vschema.json +++ b/go/test/endtoend/vtgate/queries/aggregation/vschema.json @@ -123,6 +123,14 @@ "name": "hash" } ] + }, + "t11": { + "column_vindexes": [ + { + "column": "k", + "name": "hash" + } + ] } } } \ No newline at end of file diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index dd01106a2b6..3f9224fe003 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3707,3 +3707,37 @@ func TestSelectAggregationData(t *testing.T) { }) } } + +func TestSelectAggregationRandom(t *testing.T) { + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + createSandbox(KsTestSharded).VSchema = executorVSchema + getSandbox(KsTestUnsharded).VSchema = unshardedVSchema + serv := newSandboxForCells([]string{cell}) + resolver := newTestResolver(hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + var conns []*sandboxconn.SandboxConn + for _, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, KsTestSharded, shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + conns = append(conns, sbc) + + sbc.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields("a|b", "int64|int64"), + "null|null", + )}) + } + + conns[0].SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields("a|b", "int64|int64"), + "10|1", + )}) + + executor := createExecutor(serv, cell, resolver) + executor.pv = querypb.ExecuteOptions_Gen4 + session := NewAutocommitSession(&vtgatepb.Session{}) + + rs, err := executor.Execute(context.Background(), "TestSelectCFC", session, + "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A", nil) + require.NoError(t, err) + assert.Equal(t, `[[INT64(10) INT64(1) DECIMAL(10.0000)]]`, fmt.Sprintf("%v", rs.Rows)) +} diff --git a/go/vt/vtgate/planbuilder/abstract/queryprojection.go b/go/vt/vtgate/planbuilder/abstract/queryprojection.go index 180d23e00a5..9b7b9df34ce 100644 --- a/go/vt/vtgate/planbuilder/abstract/queryprojection.go +++ b/go/vt/vtgate/planbuilder/abstract/queryprojection.go @@ -22,6 +22,8 @@ import ( "strings" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" @@ -428,7 +430,87 @@ func (qp *QueryProjection) NeedsAggregation() bool { return qp.HasAggr || len(qp.groupByExprs) > 0 } -func (qp QueryProjection) onlyAggr() bool { +// NeedsProjecting returns true if we have projections that need to be evaluated at the vtgate level +// and can't be pushed down to MySQL +func (qp *QueryProjection) NeedsProjecting( + ctx *plancontext.PlanningContext, + pusher func(expr *sqlparser.AliasedExpr) (int, error), +) (needsVtGateEval bool, expressions []sqlparser.Expr, colNames []string, err error) { + for _, se := range qp.SelectExprs { + var ae *sqlparser.AliasedExpr + ae, err = se.GetAliasedExpr() + if err != nil { + return false, nil, nil, err + } + + expr := ae.Expr + colNames = append(colNames, ae.ColumnName()) + + if _, isCol := expr.(*sqlparser.ColName); isCol { + offset, err := pusher(ae) + if err != nil { + return false, nil, nil, err + } + expressions = append(expressions, sqlparser.NewOffset(offset, expr)) + continue + } + + rExpr := sqlparser.CloneExpr(expr) + sqlparser.Rewrite(rExpr, func(cursor *sqlparser.Cursor) bool { + col, isCol := cursor.Node().(*sqlparser.ColName) + if !isCol { + return true + } + var tableInfo semantics.TableInfo + tableInfo, err = ctx.SemTable.TableInfoForExpr(col) + if err != nil { + return true + } + _, isDT := tableInfo.(*semantics.DerivedTable) + if !isDT { + return true + } + + rewritten, tErr := semantics.RewriteDerivedExpression(col, tableInfo) + if tErr != nil { + err = tErr + return false + } + if sqlparser.ContainsAggregation(rewritten) { + offset, tErr := pusher(&sqlparser.AliasedExpr{Expr: col}) + if tErr != nil { + err = tErr + return false + } + + cursor.Replace(sqlparser.NewOffset(offset, col)) + } + return true + }, nil) + + if err != nil { + return + } + + if !sqlparser.EqualsExpr(rExpr, expr) { + // if we changed the expression, it means that we have to evaluate the rest at the vtgate level + expressions = append(expressions, rExpr) + needsVtGateEval = true + continue + } + + // we did not need to push any parts of this expression down. Let's check if we can push all of it + offset, err := pusher(ae) + if err != nil { + return false, nil, nil, err + } + expressions = append(expressions, sqlparser.NewOffset(offset, expr)) + } + + return +} + +func (qp *QueryProjection) onlyAggr() bool { if !qp.HasAggr { return false } diff --git a/go/vt/vtgate/planbuilder/gen4_planner.go b/go/vt/vtgate/planbuilder/gen4_planner.go index e524af11457..b54cf62894c 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner.go +++ b/go/vt/vtgate/planbuilder/gen4_planner.go @@ -209,6 +209,8 @@ func newBuildSelectPlan( return nil, err } + optimizePlan(plan) + sel, isSel := selStmt.(*sqlparser.Select) if isSel { if err := setMiscFunc(plan, sel); err != nil { @@ -228,6 +230,27 @@ func newBuildSelectPlan( return plan, nil } +func optimizePlan(plan logicalPlan) { + for _, lp := range plan.Inputs() { + optimizePlan(lp) + } + + this, ok := plan.(*simpleProjection) + if !ok { + return + } + + input, ok := this.input.(*simpleProjection) + if !ok { + return + } + + for i, col := range this.eSimpleProj.Cols { + this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col] + } + this.input = input.input +} + func gen4UpdateStmtPlanner( version querypb.ExecuteOptions_PlannerVersion, updStmt *sqlparser.Update, diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 0527b8d70b5..0f3b25a70b4 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -60,7 +60,8 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo // a simpleProjection. We create a new Route that contains the derived table in the // FROM clause. Meaning that, when we push expressions to the select list of this // new Route, we do not want them to rewrite them. - if _, isSimpleProj := plan.(*simpleProjection); isSimpleProj { + sp, derivedTable := plan.(*simpleProjection) + if derivedTable { oldRewriteDerivedExpr := ctx.RewriteDerivedExpr defer func() { ctx.RewriteDerivedExpr = oldRewriteDerivedExpr @@ -75,10 +76,11 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo } needsOrdering := len(hp.qp.OrderExprs) > 0 - canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering // If we still have a HAVING clause, it's because it could not be pushed to the WHERE, // so it probably has aggregations + canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering + switch { case hp.qp.NeedsAggregation() || hp.sel.Having != nil: plan, err = hp.planAggregations(ctx, plan) @@ -92,6 +94,26 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo if err != nil { return nil, err } + case derivedTable: + pusher := func(ae *sqlparser.AliasedExpr) (int, error) { + offset, _, err := pushProjection(ctx, ae, sp.input, true, true, false) + return offset, err + } + needsVtGate, projections, colNames, err := hp.qp.NeedsProjecting(ctx, pusher) + if err != nil { + return nil, err + } + if !needsVtGate { + break + } + + // there were some expressions we could not push down entirely, + // so replace the simpleProjection with a real projection + plan = &projection{ + source: sp.input, + columns: projections, + columnNames: colNames, + } default: err = pushProjections(ctx, plan, hp.qp.SelectExprs) if err != nil { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 70cb6294a3b..8768d337f9a 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -5273,5 +5273,42 @@ ] } } + }, + { + "comment": "Aggregations from derived table used in arithmetic outside derived table", + "query": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A", + "v3-plan": "unsupported: expression on results of a cross-shard subquery", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] as a", + "[COLUMN 1] as b", + "[COLUMN 0] / [COLUMN 1] as d" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS a, sum(1) AS b", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(a) as a, sum(b) as b from `user` where 1 != 1", + "Query": "select sum(a) as a, sum(b) as b from `user`", + "Table": "`user`" + } + ] + } + ] + } + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 667f974a1c7..1c9c0ddfc0a 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -4655,7 +4655,7 @@ { "comment": "Mergeable subquery with multiple levels of derived statements, using a single value `IN` predicate", "query": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music WHERE music.user_id IN (5) LIMIT 10) subquery_for_limit) subquery_for_limit)", - "plan": { + "v3-plan": { "QueryType": "SELECT", "Original": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music WHERE music.user_id IN (5) LIMIT 10) subquery_for_limit) subquery_for_limit)", "Instructions": { @@ -4720,12 +4720,71 @@ } ] } + }, + "gen4-plan": { + "QueryType": "SELECT", + "Original": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music WHERE music.user_id IN (5) LIMIT 10) subquery_for_limit) subquery_for_limit)", + "Instructions": { + "OperatorType": "Subquery", + "Variant": "PulloutIn", + "PulloutVars": [ + "__sq_has_values1", + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "INT64(10)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.id from music where 1 != 1", + "Query": "select music.id from music where music.user_id in ::__vals limit :__upper_limit", + "Table": "music", + "Values": [ + "(INT64(5))" + ], + "Vindex": "user_index" + } + ] + } + ] + }, + { + "OperatorType": "Route", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.id from music where 1 != 1", + "Query": "select music.id from music where :__sq_has_values1 = 1 and music.id in ::__vals", + "Table": "music", + "Values": [ + ":__sq1" + ], + "Vindex": "music_user_map" + } + ] + } } + }, { "comment": "Unmergeable subquery with multiple levels of derived statements, using a multi value `IN` predicate", "query": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music WHERE music.user_id IN (5, 6) LIMIT 10) subquery_for_limit) subquery_for_limit)", - "plan": { + "v3-plan": { "QueryType": "SELECT", "Original": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music WHERE music.user_id IN (5, 6) LIMIT 10) subquery_for_limit) subquery_for_limit)", "Instructions": { @@ -4790,12 +4849,70 @@ } ] } + }, + "gen4-plan": { + "QueryType": "SELECT", + "Original": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music WHERE music.user_id IN (5, 6) LIMIT 10) subquery_for_limit) subquery_for_limit)", + "Instructions": { + "OperatorType": "Subquery", + "Variant": "PulloutIn", + "PulloutVars": [ + "__sq_has_values1", + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "INT64(10)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.id from music where 1 != 1", + "Query": "select music.id from music where music.user_id in ::__vals limit :__upper_limit", + "Table": "music", + "Values": [ + "(INT64(5), INT64(6))" + ], + "Vindex": "user_index" + } + ] + } + ] + }, + { + "OperatorType": "Route", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.id from music where 1 != 1", + "Query": "select music.id from music where :__sq_has_values1 = 1 and music.id in ::__vals", + "Table": "music", + "Values": [ + ":__sq1" + ], + "Vindex": "music_user_map" + } + ] + } } }, { "comment": "Unmergeable subquery with multiple levels of derived statements", "query": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music LIMIT 10) subquery_for_limit) subquery_for_limit)", - "plan": { + "v3-plan": { "QueryType": "SELECT", "Original": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music LIMIT 10) subquery_for_limit) subquery_for_limit)", "Instructions": { @@ -4856,6 +4973,60 @@ } ] } + }, + "gen4-plan": { + "QueryType": "SELECT", + "Original": "SELECT music.id FROM music WHERE music.id IN (SELECT * FROM (SELECT * FROM (SELECT music.id FROM music LIMIT 10) subquery_for_limit) subquery_for_limit)", + "Instructions": { + "OperatorType": "Subquery", + "Variant": "PulloutIn", + "PulloutVars": [ + "__sq_has_values1", + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "INT64(10)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.id from music where 1 != 1", + "Query": "select music.id from music limit :__upper_limit", + "Table": "music" + } + ] + } + ] + }, + { + "OperatorType": "Route", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.id from music where 1 != 1", + "Query": "select music.id from music where :__sq_has_values1 = 1 and music.id in ::__vals", + "Table": "music", + "Values": [ + ":__sq1" + ], + "Vindex": "music_user_map" + } + ] + } } }, { diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 72e0835ad12..6283ca64c0a 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -307,7 +307,7 @@ func (d ExprDependencies) dependencies(expr sqlparser.Expr) (deps TableSet) { // We need `foo` to be translated to `id+42` on the inside of the derived table func RewriteDerivedExpression(expr sqlparser.Expr, vt TableInfo) (sqlparser.Expr, error) { newExpr := sqlparser.CloneExpr(expr) - sqlparser.Rewrite(newExpr, func(cursor *sqlparser.Cursor) bool { + n := sqlparser.Rewrite(newExpr, func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.ColName: exp, err := vt.getExprFor(node.Name.String()) @@ -323,7 +323,7 @@ func RewriteDerivedExpression(expr sqlparser.Expr, vt TableInfo) (sqlparser.Expr } return true }, nil) - return newExpr, nil + return n.(sqlparser.Expr), nil } // FindSubqueryReference goes over the sub queries and searches for it by value equality instead of reference equality