Skip to content

Commit

Permalink
[planner bugfix] add expressions to HAVING (vitessio#12668)
Browse files Browse the repository at this point in the history
* [planner bugfix] add expressions to HAVING

When a predicate contains aggregation, it should not
be added to the WHERE clause. It should go to the

Signed-off-by: Andres Taylor <[email protected]>

* update test expecteations

Signed-off-by: Andres Taylor <[email protected]>

---------

Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay authored and frouioui committed Mar 22, 2023
1 parent dc0bd69 commit c86be70
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 33 deletions.
6 changes: 2 additions & 4 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1010,10 +1010,8 @@ func (node *Select) AddHaving(expr Expr) {
}
return
}
node.Having.Expr = &AndExpr{
Left: node.Having.Expr,
Right: expr,
}
exprs := SplitAndExpression(nil, node.Having.Expr)
node.Having.Expr = AndExpressions(append(exprs, expr)...)
}

// AddGroupBy adds a grouping expression, unless it's already present
Expand Down
34 changes: 10 additions & 24 deletions go/vt/sqlparser/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,22 @@ func TestSelect(t *testing.T) {
sel.AddWhere(expr)
buf := NewTrackedBuffer(nil)
sel.Where.Format(buf)
want := " where a = 1"
if buf.String() != want {
t.Errorf("where: %q, want %s", buf.String(), want)
}
assert.Equal(t, " where a = 1", buf.String())
sel.AddWhere(expr)
buf = NewTrackedBuffer(nil)
sel.Where.Format(buf)
want = " where a = 1"
if buf.String() != want {
t.Errorf("where: %q, want %s", buf.String(), want)
}
assert.Equal(t, " where a = 1", buf.String())

sel = &Select{}
sel.AddHaving(expr)
buf = NewTrackedBuffer(nil)
sel.Having.Format(buf)
want = " having a = 1"
if buf.String() != want {
t.Errorf("having: %q, want %s", buf.String(), want)
}
assert.Equal(t, " having a = 1", buf.String())

sel.AddHaving(expr)
buf = NewTrackedBuffer(nil)
sel.Having.Format(buf)
want = " having a = 1 and a = 1"
if buf.String() != want {
t.Errorf("having: %q, want %s", buf.String(), want)
}
assert.Equal(t, " having a = 1", buf.String())

tree, err = Parse("select * from t where a = 1 or b = 1")
require.NoError(t, err)
Expand All @@ -91,18 +81,14 @@ func TestSelect(t *testing.T) {
sel.AddWhere(expr)
buf = NewTrackedBuffer(nil)
sel.Where.Format(buf)
want = " where a = 1 or b = 1"
if buf.String() != want {
t.Errorf("where: %q, want %s", buf.String(), want)
}
assert.Equal(t, " where a = 1 or b = 1", buf.String())

sel = &Select{}
sel.AddHaving(expr)
buf = NewTrackedBuffer(nil)
sel.Having.Format(buf)
want = " having a = 1 or b = 1"
if buf.String() != want {
t.Errorf("having: %q, want %s", buf.String(), want)
}
assert.Equal(t, " having a = 1 or b = 1", buf.String())

}

func TestUpdate(t *testing.T) {
Expand Down
28 changes: 23 additions & 5 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,16 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {
}

sel := qb.sel.(*sqlparser.Select)
if sel.Where == nil {
sel.AddWhere(expr)
return
_, isSubQuery := expr.(*sqlparser.ExtractedSubquery)
var addPred func(sqlparser.Expr)

if sqlparser.ContainsAggregation(expr) && !isSubQuery {
addPred = sel.AddHaving
} else {
addPred = sel.AddWhere
}
for _, exp := range sqlparser.SplitAndExpression(nil, expr) {
sel.AddWhere(exp)
addPred(exp)
}
}

Expand Down Expand Up @@ -349,7 +353,7 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error {
sel.Limit = opQuery.Limit
sel.OrderBy = opQuery.OrderBy
sel.GroupBy = opQuery.GroupBy
sel.Having = opQuery.Having
sel.Having = mergeHaving(sel.Having, opQuery.Having)
sel.SelectExprs = opQuery.SelectExprs
qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{
Select: sel,
Expand Down Expand Up @@ -380,3 +384,17 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error {
}
return nil
}

func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where {
switch {
case h1 == nil && h2 == nil:
return nil
case h1 == nil:
return h2
case h2 == nil:
return h1
default:
h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr)
return h1
}
}
37 changes: 37 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -5004,5 +5004,42 @@
"user.user"
]
}
},
{
"comment": "when pushing predicates into derived tables, make sure to put them in HAVING when they contain aggregations",
"query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as count from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where count >= :v2",
"v3-plan": {
"QueryType": "SELECT",
"Original": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as count from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where count >= :v2",
"Instructions": {
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where 1 != 1 group by user_id, flowId) as t1 where 1 != 1",
"Query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where `count` >= :v2",
"Table": "user_extra"
}
},
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as count from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where count >= :v2",
"Instructions": {
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where 1 != 1 group by user_id, flowId) as t1 where 1 != 1",
"Query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where localDate > :v1 group by user_id, flowId having count(*) >= :v2 order by null) as t1",
"Table": "user_extra"
},
"TablesUsed": [
"user.user_extra"
]
}
}
]

0 comments on commit c86be70

Please sign in to comment.