Skip to content

Commit

Permalink
Support complex aggregation in Gen4's Operators (#13326)
Browse files Browse the repository at this point in the history
  • Loading branch information
frouioui authored Jun 16, 2023
1 parent d78b5ab commit 509788b
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 36 deletions.
15 changes: 15 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,18 @@ func TestMinMaxAcrossJoins(t *testing.T) {
`SELECT /*vt+ PLANNER=gen4 */ t1.name, max(t1.shardKey), t2.shardKey, min(t2.id) FROM t1 JOIN t2 ON t1.t1_id != t2.shardKey GROUP BY t1.name, t2.shardKey`,
`[[VARCHAR("name 2") INT64(2) INT64(10) INT64(1)] [VARCHAR("name 1") INT64(1) INT64(10) INT64(1)] [VARCHAR("name 2") INT64(2) INT64(20) INT64(2)] [VARCHAR("name 1") INT64(1) INT64(20) INT64(2)]]`)
}

func TestComplexAggregation(t *testing.T) {
mcmp, closer := start(t)
defer closer()
mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)")

mcmp.Exec("set @@sql_mode = ' '")
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ 1+COUNT(t1_id) FROM t1`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(t1_id)+1 FROM t1`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(t1_id)+MAX(shardkey) FROM t1`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ shardkey, MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ shardkey + MIN(t1_id)+MAX(t1_id) FROM t1 GROUP BY shardkey`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`)
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/horizon_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func (hp *horizonPlanning) planAggrUsingOA(
}
}

aggregationExprs, err := hp.qp.AggregationExpressions(ctx)
aggregationExprs, _, err := hp.qp.AggregationExpressions(ctx, false)
if err != nil {
return nil, err
}
Expand Down
36 changes: 34 additions & 2 deletions go/vt/vtgate/planbuilder/operators/horizon_expanding.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo
return nil, err
}

aggregations, err := qp.AggregationExpressions(ctx)
aggregations, complexAggr, err := qp.AggregationExpressions(ctx, true)
if err != nil {
return nil, err
}
Expand All @@ -135,6 +135,13 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo
a.Alias = derived.Alias
}

if complexAggr {
return createProjectionForComplexAggregation(a, qp)
}
return createProjectionForSimpleAggregation(ctx, a, qp)
}

func createProjectionForSimpleAggregation(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) (ops.Operator, error) {
outer:
for colIdx, expr := range qp.SelectExprs {
ae, err := expr.GetAliasedExpr()
Expand Down Expand Up @@ -165,10 +172,35 @@ outer:
}
return nil, vterrors.VT13001(fmt.Sprintf("Could not find the %s in aggregation in the original query", sqlparser.String(ae)))
}

return a, nil
}

func createProjectionForComplexAggregation(a *Aggregator, qp *QueryProjection) (ops.Operator, error) {
p := &Projection{
Source: a,
Alias: a.Alias,
TableID: a.TableID,
}

for _, expr := range qp.SelectExprs {
ae, err := expr.GetAliasedExpr()
if err != nil {
return nil, err
}
p.Columns = append(p.Columns, ae)
p.Projections = append(p.Projections, UnexploredExpression{E: ae.Expr})
}
for i, by := range a.Grouping {
a.Grouping[i].ColOffset = len(a.Columns)
a.Columns = append(a.Columns, aeWrap(by.SimplifiedExpr))
}
for i, aggregation := range a.Aggregations {
a.Aggregations[i].ColOffset = len(a.Columns)
a.Columns = append(a.Columns, aggregation.Original)
}
return p, nil
}

func createProjectionWithoutAggr(qp *QueryProjection, src ops.Operator) (*Projection, error) {
proj := &Projection{
Source: src,
Expand Down
66 changes: 45 additions & 21 deletions go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ type (
OriginalOpCode opcode.AggregateOpcode

Alias string

// The index at which the user expects to see this aggregated function. Set to nil, if the user does not ask for it
Index *int
// Only used in the old Horizon Planner
Index *int

Distinct bool

// the offsets point to columns on the same aggregator
Expand Down Expand Up @@ -444,13 +447,9 @@ func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error {
}, exp.Expr)
}

func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr SelectExpr) bool {
func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool {
for _, groupByExpr := range qp.groupByExprs {
exp, err := expr.GetExpr()
if err != nil {
return false
}
if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, exp) {
if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, expr) {
return true
}
}
Expand Down Expand Up @@ -623,7 +622,7 @@ func (qp *QueryProjection) NeedsDistinct() bool {
return true
}

func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext) (out []Aggr, err error) {
func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext, allowComplexExpression bool) (out []Aggr, complex bool, err error) {
orderBy:
for _, orderExpr := range qp.OrderExprs {
orderExpr := orderExpr.SimplifiedExpr
Expand All @@ -649,27 +648,54 @@ orderBy:
for idx, expr := range qp.SelectExprs {
aliasedExpr, err := expr.GetAliasedExpr()
if err != nil {
return nil, err
return nil, false, err
}

idxCopy := idx

if !sqlparser.ContainsAggregation(expr.Col) {
if !qp.isExprInGroupByExprs(ctx, expr) {
getExpr, err := expr.GetExpr()
if err != nil {
return nil, false, err
}
if !qp.isExprInGroupByExprs(ctx, getExpr) {
aggr := NewAggr(opcode.AggregateRandom, nil, aliasedExpr, aliasedExpr.ColumnName())
aggr.Index = &idxCopy
out = append(out, aggr)
}
continue
}
fnc, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc)
if !isAggregate {
return nil, vterrors.VT12001("in scatter query: complex aggregate expression")
_, isAggregate := aliasedExpr.Expr.(sqlparser.AggrFunc)
if !isAggregate && !allowComplexExpression {
return nil, false, vterrors.VT12001("in scatter query: complex aggregate expression")
}

aggr := createAggrFromAggrFunc(fnc, aliasedExpr)
aggr.Index = &idxCopy
out = append(out, aggr)
sqlparser.CopyOnRewrite(aliasedExpr.Expr, func(node, parent sqlparser.SQLNode) bool {
ex, isExpr := node.(sqlparser.Expr)
if !isExpr {
return true
}
if aggr, isAggr := node.(sqlparser.AggrFunc); isAggr {
ae := aeWrap(aggr)
if aggr == aliasedExpr.Expr {
ae = aliasedExpr
}
aggrFunc := createAggrFromAggrFunc(aggr, ae)
aggrFunc.Index = &idxCopy
out = append(out, aggrFunc)
return false
}
if sqlparser.ContainsAggregation(node) {
complex = true
return true
}
if !qp.isExprInGroupByExprs(ctx, ex) {
aggr := NewAggr(opcode.AggregateRandom, nil, aeWrap(ex), "")
aggr.Index = &idxCopy
out = append(out, aggr)
}
return false
}, nil, nil)
}
return
}
Expand All @@ -683,9 +709,7 @@ func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.Alias
}
}

aggrF, _ := aliasedExpr.Expr.(sqlparser.AggrFunc)

if aggrF.IsDistinct() {
if fnc.IsDistinct() {
switch code {
case opcode.AggregateCount:
code = opcode.AggregateCountDistinct
Expand All @@ -694,8 +718,8 @@ func createAggrFromAggrFunc(fnc sqlparser.AggrFunc, aliasedExpr *sqlparser.Alias
}
}

aggr := NewAggr(code, aggrF, aliasedExpr, aliasedExpr.ColumnName())
aggr.Distinct = aggrF.IsDistinct()
aggr := NewAggr(code, fnc, aliasedExpr, aliasedExpr.ColumnName())
aggr.Distinct = fnc.IsDistinct()
return aggr
}

Expand Down
108 changes: 108 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -6626,5 +6626,113 @@
"user.user_extra"
]
}
},
{
"comment": "Complex aggregate expression on scatter",
"query": "select 1+count(*) from user",
"v3-plan": "VT12001: unsupported: in scatter query: complex aggregate expression",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select 1+count(*) from user",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] + [COLUMN 1] as 1 + count(*)"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "random(0), sum_count_star(1) AS count(*)",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select 1, count(*) from `user` where 1 != 1",
"Query": "select 1, count(*) from `user`",
"Table": "`user`"
}
]
}
]
},
"TablesUsed": [
"user.user"
]
}
},
{
"comment": "combine the output of two aggregations in the final result",
"query": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col",
"v3-plan": "VT12001: unsupported: cross-shard query with aggregates",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join user_extra on user.col = user_extra.col",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"GREATEST([COLUMN 0], [COLUMN 1]) as greatest(sum(`user`.foo), sum(user_extra.bar))"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "sum(0) AS sum(`user`.foo), sum(1) AS sum(user_extra.bar)",
"Inputs": [
{
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] * [COLUMN 1] as sum(`user`.foo)",
"[COLUMN 3] * [COLUMN 2] as sum(user_extra.bar)"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,R:1,L:1",
"JoinVars": {
"user_col": 2
},
"TableName": "`user`_user_extra",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select sum(`user`.foo), count(*), `user`.col from `user` where 1 != 1 group by `user`.col",
"Query": "select sum(`user`.foo), count(*), `user`.col from `user` group by `user`.col",
"Table": "`user`"
},
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*), sum(user_extra.bar) from user_extra where 1 != 1 group by .0",
"Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col group by .0",
"Table": "user_extra"
}
]
}
]
}
]
}
]
},
"TablesUsed": [
"user.user",
"user.user_extra"
]
}
}
]
63 changes: 62 additions & 1 deletion go/vt/vtgate/planbuilder/testdata/tpch_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,68 @@
"comment": "TPC-H query 14",
"query": "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month",
"v3-plan": "VT12001: unsupported: cross-shard query with aggregates",
"gen4-plan": "VT12001: unsupported: in scatter query: complex aggregate expression"
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"([COLUMN 0] * [COLUMN 1]) / [COLUMN 2] as promo_revenue"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "random(0), sum(1) AS sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end), sum(2) AS sum(l_extendedprice * (1 - l_discount))",
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,L:3",
"JoinVars": {
"l_discount": 2,
"l_extendedprice": 1,
"l_partkey": 4
},
"TableName": "lineitem_part",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "main",
"Sharded": true
},
"FieldQuery": "select 100.00, l_extendedprice, l_discount, l_extendedprice * (1 - l_discount), l_partkey from lineitem where 1 != 1",
"Query": "select 100.00, l_extendedprice, l_discount, l_extendedprice * (1 - l_discount), l_partkey from lineitem where l_shipdate >= date('1995-09-01') and l_shipdate < date('1995-09-01') + interval '1' month",
"Table": "lineitem"
},
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "main",
"Sharded": true
},
"FieldQuery": "select case when p_type like 'PROMO%' then :l_extendedprice * (1 - :l_discount) else 0 end from part where 1 != 1",
"Query": "select case when p_type like 'PROMO%' then :l_extendedprice * (1 - :l_discount) else 0 end from part where p_partkey = :l_partkey",
"Table": "part",
"Values": [
":l_partkey"
],
"Vindex": "hash"
}
]
}
]
}
]
},
"TablesUsed": [
"main.lineitem",
"main.part"
]
}
},
{
"comment": "TPC-H query 15 view\n#\"with revenue0(supplier_no, total_revenue) as (select l_suppkey, sum(l_extendedprice * (1 - l_discount)) from lineitem where l_shipdate >= date('1996-01-01') and l_shipdate < date('1996-01-01') + interval '3' month group by l_suppkey )\"\n#\"syntax error at position 236\"\n#Gen4 plan same as above\n# TPC-H query 15",
Expand Down
Loading

0 comments on commit 509788b

Please sign in to comment.