diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index a714550fb50..0951900982d 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -452,3 +452,18 @@ func TestMinMaxAcrossJoins(t *testing.T) { `SELECT /*vt+ PLANNER=gen4 */ t1.name, max(t1.shardKey), t2.shardKey, min(t2.id) FROM t1 JOIN t2 ON t1.t1_id != t2.shardKey GROUP BY t1.name, t2.shardKey`, `[[VARCHAR("name 2") INT64(2) INT64(10) INT64(1)] [VARCHAR("name 1") INT64(1) INT64(10) INT64(1)] [VARCHAR("name 2") INT64(2) INT64(20) INT64(2)] [VARCHAR("name 1") INT64(1) INT64(20) INT64(2)]]`) } + +func TestComplexAggregation(t *testing.T) { + mcmp, closer := start(t) + defer closer() + mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)") + + mcmp.Exec("set @@sql_mode = ' '") + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ 1+COUNT(t1_id) FROM t1`) + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(t1_id)+1 FROM t1`) + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(t1_id)+MAX(shardkey) FROM t1`) + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`) + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`) + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ name+COUNT(t1_id)+1 FROM t1 GROUP BY name`) + mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`) +} diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 1375f5ff690..7de3a7b97b5 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -297,7 +297,7 @@ func (hp *horizonPlanning) planAggrUsingOA( } } - aggregationExprs, err := hp.qp.AggregationExpressions(ctx) + aggregationExprs, _, err := hp.qp.AggregationExpressions(ctx, false) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index ae730277a54..2342d6edb27 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -112,7 +112,7 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo return nil, err } - aggregations, err := qp.AggregationExpressions(ctx) + aggregations, complexAggr, err := qp.AggregationExpressions(ctx, true) if err != nil { return nil, err } @@ -135,6 +135,13 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo a.Alias = derived.Alias } + if complexAggr { + return createProjectionForComplexAggregation(a, qp) + } + return createProjectionForSimpleAggregation(ctx, a, qp) +} + +func createProjectionForSimpleAggregation(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) (ops.Operator, error) { outer: for colIdx, expr := range qp.SelectExprs { ae, err := expr.GetAliasedExpr() @@ -165,10 +172,35 @@ outer: } return nil, vterrors.VT13001(fmt.Sprintf("Could not find the %s in aggregation in the original query", sqlparser.String(ae))) } - return a, nil } +func createProjectionForComplexAggregation(a *Aggregator, qp *QueryProjection) (ops.Operator, error) { + p := &Projection{ + Source: a, + Alias: a.Alias, + TableID: a.TableID, + } + + for _, expr := range qp.SelectExprs { + ae, err := expr.GetAliasedExpr() + if err != nil { + return nil, err + } + p.Columns = append(p.Columns, ae) + p.Projections = append(p.Projections, UnexploredExpression{E: ae.Expr}) + } + for i, by := range a.Grouping { + a.Grouping[i].ColOffset = len(a.Columns) + a.Columns = append(a.Columns, aeWrap(by.SimplifiedExpr)) + } + for i, aggregation := range a.Aggregations { + a.Aggregations[i].ColOffset = len(a.Columns) + a.Columns = append(a.Columns, aggregation.Original) + } + return p, nil +} + func createProjectionWithoutAggr(qp *QueryProjection, src ops.Operator) (*Projection, error) { proj := &Projection{ Source: src, diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 99bfbe9b3c2..4764f26d58b 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -91,8 +91,11 @@ type ( OriginalOpCode opcode.AggregateOpcode Alias string + // The index at which the user expects to see this aggregated function. Set to nil, if the user does not ask for it - Index *int + // Only used in the old Horizon Planner + Index *int + Distinct bool // the offsets point to columns on the same aggregator @@ -444,13 +447,9 @@ func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error { }, exp.Expr) } -func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr SelectExpr) bool { +func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { for _, groupByExpr := range qp.groupByExprs { - exp, err := expr.GetExpr() - if err != nil { - return false - } - if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, exp) { + if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) { return true } } @@ -623,7 +622,7 @@ func (qp *QueryProjection) NeedsDistinct() bool { return true } -func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext) (out []Aggr, err error) { +func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext, allowComplexExpression bool) (out []Aggr, complex bool, err error) { orderBy: for _, orderExpr := range qp.OrderExprs { orderExpr := orderExpr.SimplifiedExpr @@ -649,27 +648,54 @@ orderBy: for idx, expr := range qp.SelectExprs { aliasedExpr, err := expr.GetAliasedExpr() if err != nil { - return nil, err + return nil, false, err } idxCopy := idx if !sqlparser.ContainsAggregation(expr.Col) { - if !qp.isExprInGroupByExprs(ctx, expr) { + getExpr, err := expr.GetExpr() + if err != nil { + return nil, false, err + } + if !qp.isExprInGroupByExprs(ctx, getExpr) { aggr := NewAggr(opcode.AggregateRandom, nil, aliasedExpr, aliasedExpr.ColumnName()) aggr.Index = &idxCopy out = append(out, aggr) } continue } - fnc, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc) - if !isAggregate { - return nil, vterrors.VT12001("in scatter query: complex aggregate expression") + _, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc) + if !isAggregate && !allowComplexExpression { + return nil, false, vterrors.VT12001("in scatter query: complex aggregate expression") } - aggr := createAggrFromAggrFunc(fnc, aliasedExpr) - aggr.Index = &idxCopy - out = append(out, aggr) + sqlparser.CopyOnRewrite(aliasedExpr.Expr, func(node, parent sqlparser.SQLNode) bool { + ex, isExpr := node.(sqlparser.Expr) + if !isExpr { + return true + } + if aggr, isAggr := node.(sqlparser.AggrFunc); isAggr { + ae := aeWrap(aggr) + if aggr == aliasedExpr.Expr { + ae = aliasedExpr + } + aggrFunc := createAggrFromAggrFunc(aggr, ae) + aggrFunc.Index = &idxCopy + out = append(out, aggrFunc) + return false + } + if sqlparser.ContainsAggregation(node) { + complex = true + return true + } + if !qp.isExprInGroupByExprs(ctx, ex) { + aggr := NewAggr(opcode.AggregateRandom, nil, aeWrap(ex), "") + aggr.Index = &idxCopy + out = append(out, aggr) + } + return false + }, nil, nil) } return } @@ -683,9 +709,7 @@ func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.Alias } } - aggrF, _ := aliasedExpr.Expr.(sqlparser.AggrFunc) - - if aggrF.IsDistinct() { + if fnc.IsDistinct() { switch code { case opcode.AggregateCount: code = opcode.AggregateCountDistinct @@ -694,8 +718,8 @@ func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.Alias } } - aggr := NewAggr(code, aggrF, aliasedExpr, aliasedExpr.ColumnName()) - aggr.Distinct = aggrF.IsDistinct() + aggr := NewAggr(code, fnc, aliasedExpr, aliasedExpr.ColumnName()) + aggr.Distinct = fnc.IsDistinct() return aggr } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index fdf43c8266d..fdf88942bbc 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -6626,5 +6626,113 @@ "user.user_extra" ] } + }, + { + "comment": "Complex aggregate expression on scatter", + "query": "select 1+count(*) from user", + "v3-plan": "VT12001: unsupported: in scatter query: complex aggregate expression", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select 1+count(*) from user", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] + [COLUMN 1] as 1 + count(*)" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "random(0), sum_count_star(1) AS count(*)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1, count(*) from `user` where 1 != 1", + "Query": "select 1, count(*) from `user`", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "combine the output of two aggregations in the final result", + "query": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "GREATEST([COLUMN 0], [COLUMN 1]) as greatest(sum(`user`.foo), sum(user_extra.bar))" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS sum(`user`.foo), sum(1) AS sum(user_extra.bar)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as sum(`user`.foo)", + "[COLUMN 3] * [COLUMN 2] as sum(user_extra.bar)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,R:1,L:1", + "JoinVars": { + "user_col": 2 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(`user`.foo), count(*), `user`.col from `user` where 1 != 1 group by `user`.col", + "Query": "select sum(`user`.foo), count(*), `user`.col from `user` group by `user`.col", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), sum(user_extra.bar) from user_extra where 1 != 1 group by .0", + "Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col group by .0", + "Table": "user_extra" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 7ef69afad01..bc30cc728f5 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -1059,7 +1059,68 @@ "comment": "TPC-H query 14", "query": "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month", "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", - "gen4-plan": "VT12001: unsupported: in scatter query: complex aggregate expression" + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "([COLUMN 0] * [COLUMN 1]) / [COLUMN 2] as promo_revenue" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "random(0), sum(1) AS sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end), sum(2) AS sum(l_extendedprice * (1 - l_discount))", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:3", + "JoinVars": { + "l_discount": 2, + "l_extendedprice": 1, + "l_partkey": 4 + }, + "TableName": "lineitem_part", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "main", + "Sharded": true + }, + "FieldQuery": "select 100.00, l_extendedprice, l_discount, l_extendedprice * (1 - l_discount), l_partkey from lineitem where 1 != 1", + "Query": "select 100.00, l_extendedprice, l_discount, l_extendedprice * (1 - l_discount), l_partkey from lineitem where l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month", + "Table": "lineitem" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "main", + "Sharded": true + }, + "FieldQuery": "select case when p_type like 'PROMO%' then :l_extendedprice * (1 - :l_discount) else 0 end from part where 1 != 1", + "Query": "select case when p_type like 'PROMO%' then :l_extendedprice * (1 - :l_discount) else 0 end from part where p_partkey = :l_partkey", + "Table": "part", + "Values": [ + ":l_partkey" + ], + "Vindex": "hash" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "main.lineitem", + "main.part" + ] + } }, { "comment": "TPC-H query 15 view\n#\"with revenue0(supplier_no, total_revenue) as (select l_suppkey, sum(l_extendedprice * (1 - l_discount)) from lineitem where l_shipdate >= date('1996-01-01') and l_shipdate < date('1996-01-01') + interval '3' month group by l_suppkey )\"\n#\"syntax error at position 236\"\n#Gen4 plan same as above\n# TPC-H query 15", diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 0d32c2f5b14..a47a7ba2c5d 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -68,11 +68,6 @@ "v3-plan": "VT12001: unsupported: '*' expression in cross-shard query", "gen4-plan": "cannot use column offsets in group statement when using `*`" }, - { - "comment": "Complex aggregate expression on scatter", - "query": "select 1+count(*) from user", - "plan": "VT12001: unsupported: in scatter query: complex aggregate expression" - }, { "comment": "Multi-value aggregates not supported", "query": "select count(a,b) from user", @@ -466,12 +461,6 @@ "query": "delete from user where x = (@val := 42)", "plan": "VT12001: unsupported: Assignment expression" }, - { - "comment": "combine the output of two aggregations in the final result", - "query": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col", - "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", - "gen4-plan": "VT12001: unsupported: in scatter query: complex aggregate expression" - }, { "comment": "extremum on input from both sides", "query": "insert into music(user_id, id) select foo, bar from music on duplicate key update id = id+1",