diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index b9a1ea1b6d3..afa130eb279 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -991,10 +991,8 @@ func (node *Select) AddHaving(expr Expr) { } return } - node.Having.Expr = &AndExpr{ - Left: node.Having.Expr, - Right: expr, - } + exprs := SplitAndExpression(nil, node.Having.Expr) + node.Having.Expr = AndExpressions(append(exprs, expr)...) } // AddGroupBy adds a grouping expression, unless it's already present diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 71c56594875..3c518ee7886 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -57,32 +57,22 @@ func TestSelect(t *testing.T) { sel.AddWhere(expr) buf := NewTrackedBuffer(nil) sel.Where.Format(buf) - want := " where a = 1" - if buf.String() != want { - t.Errorf("where: %q, want %s", buf.String(), want) - } + assert.Equal(t, " where a = 1", buf.String()) sel.AddWhere(expr) buf = NewTrackedBuffer(nil) sel.Where.Format(buf) - want = " where a = 1" - if buf.String() != want { - t.Errorf("where: %q, want %s", buf.String(), want) - } + assert.Equal(t, " where a = 1", buf.String()) + sel = &Select{} sel.AddHaving(expr) buf = NewTrackedBuffer(nil) sel.Having.Format(buf) - want = " having a = 1" - if buf.String() != want { - t.Errorf("having: %q, want %s", buf.String(), want) - } + assert.Equal(t, " having a = 1", buf.String()) + sel.AddHaving(expr) buf = NewTrackedBuffer(nil) sel.Having.Format(buf) - want = " having a = 1 and a = 1" - if buf.String() != want { - t.Errorf("having: %q, want %s", buf.String(), want) - } + assert.Equal(t, " having a = 1", buf.String()) tree, err = Parse("select * from t where a = 1 or b = 1") require.NoError(t, err) @@ -91,18 +81,14 @@ func TestSelect(t *testing.T) { sel.AddWhere(expr) buf = NewTrackedBuffer(nil) sel.Where.Format(buf) - want = " where a = 1 or b = 1" - if buf.String() != want { - t.Errorf("where: %q, want %s", buf.String(), want) - } + assert.Equal(t, " where a = 1 or b = 1", buf.String()) + sel = &Select{} sel.AddHaving(expr) buf = NewTrackedBuffer(nil) sel.Having.Format(buf) - want = " having a = 1 or b = 1" - if buf.String() != want { - t.Errorf("having: %q, want %s", buf.String(), want) - } + assert.Equal(t, " having a = 1 or b = 1", buf.String()) + } func TestUpdate(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/operator_to_query.go b/go/vt/vtgate/planbuilder/operator_to_query.go index 87dec2d8526..1e2d4ff54c4 100644 --- a/go/vt/vtgate/planbuilder/operator_to_query.go +++ b/go/vt/vtgate/planbuilder/operator_to_query.go @@ -89,7 +89,7 @@ func buildQuery(op abstract.PhysicalOperator, qb *queryBuilder) { sel.Limit = opQuery.Limit sel.OrderBy = opQuery.OrderBy sel.GroupBy = opQuery.GroupBy - sel.Having = opQuery.Having + sel.Having = mergeHaving(sel.Having, opQuery.Having) sel.SelectExprs = opQuery.SelectExprs qb.addTableExpr(op.Alias, op.Alias, op.TableID(), &sqlparser.DerivedTable{ Select: sel, @@ -161,12 +161,16 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { } sel := qb.sel.(*sqlparser.Select) - if sel.Where == nil { - sel.AddWhere(expr) - return + _, isSubQuery := expr.(*sqlparser.ExtractedSubquery) + var addPred func(sqlparser.Expr) + + if sqlparser.ContainsAggregation(expr) && !isSubQuery { + addPred = sel.AddHaving + } else { + addPred = sel.AddWhere } for _, exp := range sqlparser.SplitAndExpression(nil, expr) { - sel.AddWhere(exp) + addPred(exp) } } @@ -288,3 +292,17 @@ func (ts *tableSorter) Less(i, j int) bool { func (ts *tableSorter) Swap(i, j int) { ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i] } + +func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { + switch { + case h1 == nil && h2 == nil: + return nil + case h1 == nil: + return h2 + case h2 == nil: + return h1 + default: + h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr) + return h1 + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 1041fadba1d..d5b46f3eb12 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -4996,5 +4996,42 @@ "user.user" ] } + }, + { + "comment": "when pushing predicates into derived tables, make sure to put them in HAVING when they contain aggregations", + "query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as count from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where count >= :v2", + "v3-plan": { + "QueryType": "SELECT", + "Original": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as count from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where count >= :v2", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where 1 != 1 group by user_id, flowId) as t1 where 1 != 1", + "Query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where `count` >= :v2", + "Table": "user_extra" + } + }, + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as count from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where count >= :v2", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where 1 != 1 group by user_id, flowId) as t1 where 1 != 1", + "Query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where localDate > :v1 group by user_id, flowId having count(*) >= :v2 order by null) as t1", + "Table": "user_extra" + }, + "TablesUsed": [ + "user.user_extra" + ] + } } ] \ No newline at end of file