Skip to content

Commit

Permalink
Merge pull request #8998 from planetscale/subquery-deps-fix
Browse files Browse the repository at this point in the history
gen4: Subquery dependencies update
  • Loading branch information
systay authored Oct 14, 2021
2 parents 4091720 + 438a96d commit 180ccfb
Show file tree
Hide file tree
Showing 26 changed files with 602 additions and 85 deletions.
2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1971,7 +1971,7 @@ type (
ArgName string
HasValuesArg string
OpCode int // this should really be engine.PulloutOpCode, but we cannot depend on engine :(
Subquery SelectStatement
Subquery *Subquery
OtherSide Expr // represents the side of the comparison, this field will be nil if Original is not a comparison
NeedsRewrite bool // tells whether we need to rewrite this subquery to Original or not
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast_clone.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast_equals.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/ast_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast_visit.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions go/vt/sqlparser/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go/vt/sqlparser/precedence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestNotInSubqueryPrecedence(t *testing.T) {
ArgName: "arg1",
HasValuesArg: "has_values1",
OpCode: 1,
Subquery: subq.Select,
Subquery: subq,
OtherSide: cmp.Left,
}
not.Expr = extracted
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/abstract/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func createOperatorFromSelect(sel *sqlparser.Select, semTable *semantics.SemTabl
if len(semTable.SubqueryMap[sel]) > 0 {
resultantOp = &SubQuery{}
for _, sq := range semTable.SubqueryMap[sel] {
subquerySelectStatement, isSel := sq.Subquery.(*sqlparser.Select)
subquerySelectStatement, isSel := sq.Subquery.Select.(*sqlparser.Select)
if !isSel {
return nil, semantics.Gen4NotSupportedF("UNION in subquery")
}
Expand Down
4 changes: 1 addition & 3 deletions go/vt/vtgate/planbuilder/abstract/operator_test_data.txt
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,7 @@ SubQuery: {
}]
Outer: QueryGraph: {
Tables:
TableSet{0}:`user` AS u
JoinPredicates:
TableSet{0,1} - u.id = (select id from user_extra where id = u.id)
TableSet{0}:`user` AS u where u.id = (select id from user_extra where id = u.id)
}
}

Expand Down
5 changes: 2 additions & 3 deletions go/vt/vtgate/planbuilder/abstract/querygraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ limitations under the License.
package abstract

import (
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

Expand Down Expand Up @@ -139,7 +137,8 @@ func (qg *QueryGraph) collectPredicate(predicate sqlparser.Expr, semTable *seman
case 1:
found := qg.addToSingleTable(deps, predicate)
if !found {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "table %v for predicate %v not found", deps, sqlparser.String(predicate))
// this could be a predicate that only has dependencies from outside this QG
qg.addJoinPredicates(deps, predicate)
}
default:
qg.addJoinPredicates(deps, predicate)
Expand Down
84 changes: 84 additions & 0 deletions go/vt/vtgate/planbuilder/gen4_planner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package planbuilder

import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/semantics"
"vitess.io/vitess/go/vt/vtgate/vindexes"
)

func TestBindingSubquery(t *testing.T) {
testcases := []struct {
query string
requiredTableSet semantics.TableSet
extractor func(p *sqlparser.Select) sqlparser.Expr
rewrite bool
}{
{
query: "select (select col from tabl limit 1) as a from foo join tabl order by a + 1",
requiredTableSet: semantics.EmptyTableSet(),
extractor: func(sel *sqlparser.Select) sqlparser.Expr {
return sel.OrderBy[0].Expr
},
rewrite: true,
}, {
query: "select t.a from (select (select col from tabl limit 1) as a from foo join tabl) t",
requiredTableSet: semantics.EmptyTableSet(),
extractor: func(sel *sqlparser.Select) sqlparser.Expr {
return extractExpr(sel, 0)
},
rewrite: true,
}, {
query: "select (select col from tabl where foo.id = 4 limit 1) as a from foo",
requiredTableSet: semantics.SingleTableSet(0),
extractor: func(sel *sqlparser.Select) sqlparser.Expr {
return extractExpr(sel, 0)
},
rewrite: false,
},
}
for _, testcase := range testcases {
t.Run(testcase.query, func(t *testing.T) {
parse, err := sqlparser.Parse(testcase.query)
require.NoError(t, err)
selStmt := parse.(*sqlparser.Select)
semTable, err := semantics.Analyze(selStmt, "d", &semantics.FakeSI{
Tables: map[string]*vindexes.Table{
"tabl": {Name: sqlparser.NewTableIdent("tabl")},
"foo": {Name: sqlparser.NewTableIdent("foo")},
},
})
require.NoError(t, err)
if testcase.rewrite {
err = queryRewrite(semTable, sqlparser.NewReservedVars("vt", make(sqlparser.BindVars)), selStmt)
require.NoError(t, err)
}
expr := testcase.extractor(selStmt)
tableset := semTable.RecursiveDeps(expr)
require.Equal(t, testcase.requiredTableSet, tableset)
})
}
}

func extractExpr(in *sqlparser.Select, idx int) sqlparser.Expr {
return in.SelectExprs[idx].(*sqlparser.AliasedExpr).Expr
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/querytree_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func transformSubqueryTree(ctx *planningContext, n *subqueryTree) (logicalPlan,
if err != nil {
return nil, err
}
innerPlan, err = planHorizon(ctx, innerPlan, n.extracted.Subquery)
innerPlan, err = planHorizon(ctx, innerPlan, n.extracted.Subquery.Select)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func TestHavingRewrite(t *testing.T) {
assert.Equal(t, len(tcase.sqs), len(squeries), "number of subqueries not matched")
}
for _, sq := range squeries {
assert.Equal(t, tcase.sqs[sq.ArgName], sqlparser.String(sq.Subquery))
assert.Equal(t, tcase.sqs[sq.ArgName], sqlparser.String(sq.Subquery.Select))
}
})
}
Expand Down
24 changes: 17 additions & 7 deletions go/vt/vtgate/planbuilder/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (c planningContext) isSubQueryToReplace(e sqlparser.Expr) bool {
return false
}
for _, extractedSubq := range c.semTable.GetSubqueryNeedingRewrite() {
if extractedSubq.NeedsRewrite && sqlparser.EqualsRefOfSubquery(&sqlparser.Subquery{Select: extractedSubq.Subquery}, ext) {
if extractedSubq.NeedsRewrite && sqlparser.EqualsRefOfSubquery(extractedSubq.Subquery, ext) {
return true
}
}
Expand Down Expand Up @@ -199,7 +199,7 @@ func tryMergeSubQuery(ctx *planningContext, outer, subq queryTree, subQueryInner
if err != nil {
return nil, err
}
return rt, rewriteSubqueryDependenciesForJoin(ctx, outerTree.rhs, outerTree, subQueryInner)
return rt, rewriteColumnsInSubqueryForJoin(ctx, outerTree.rhs, outerTree, subQueryInner)
}
merged, err = tryMergeSubQuery(ctx, outerTree.lhs, subq, subQueryInner, joinPredicates, newMergefunc)
if err != nil {
Expand All @@ -215,7 +215,7 @@ func tryMergeSubQuery(ctx *planningContext, outer, subq queryTree, subQueryInner
if err != nil {
return nil, err
}
return rt, rewriteSubqueryDependenciesForJoin(ctx, outerTree.lhs, outerTree, subQueryInner)
return rt, rewriteColumnsInSubqueryForJoin(ctx, outerTree.lhs, outerTree, subQueryInner)
}
merged, err = tryMergeSubQuery(ctx, outerTree.rhs, subq, subQueryInner, joinPredicates, newMergefunc)
if err != nil {
Expand All @@ -231,13 +231,15 @@ func tryMergeSubQuery(ctx *planningContext, outer, subq queryTree, subQueryInner
}
}

// rewriteColumnsInSubqueryForJoin rewrites the columns that appear from the other side
// of the join. For example, let's say we merged a subquery on the right side of a join tree
// If it was using any columns from the left side then they need to be replaced by bind variables supplied
// from that side.
// outerTree is the joinTree within whose children the subquery lives in
// the child of joinTree which does not contain the subquery is the otherTree
func rewriteSubqueryDependenciesForJoin(ctx *planningContext, otherTree queryTree, outerTree *joinTree, subQueryInner *abstract.SubQueryInner) error {
// first we find the other side of the tree by comparing the tableIDs
// other side is RHS if the subquery is in the LHS, otherwise it is LHS
func rewriteColumnsInSubqueryForJoin(ctx *planningContext, otherTree queryTree, outerTree *joinTree, subQueryInner *abstract.SubQueryInner) error {
var rewriteError error
// go over the entire where expression in the subquery
// go over the entire expression in the subquery
sqlparser.Rewrite(subQueryInner.ExtractedSubquery.Original, func(cursor *sqlparser.Cursor) bool {
sqlNode := cursor.Node()
switch node := sqlNode.(type) {
Expand Down Expand Up @@ -266,6 +268,14 @@ func rewriteSubqueryDependenciesForJoin(ctx *planningContext, otherTree queryTre
return true
}, nil)

// update the dependencies for the subquery by removing the dependencies from the otherTree
tableSet := ctx.semTable.Direct[subQueryInner.ExtractedSubquery.Subquery]
tableSet.RemoveInPlace(otherTree.tableID())
ctx.semTable.Direct[subQueryInner.ExtractedSubquery.Subquery] = tableSet
tableSet = ctx.semTable.Recursive[subQueryInner.ExtractedSubquery.Subquery]
tableSet.RemoveInPlace(otherTree.tableID())
ctx.semTable.Recursive[subQueryInner.ExtractedSubquery.Subquery] = tableSet

// return any error while rewriting
return rewriteError
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/routetree.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (rp *routeTree) searchForNewVindexes(ctx *planningContext, predicates []sql
// using the node.subquery which is the rewritten version of our subquery
cmp := &sqlparser.ComparisonExpr{
Left: node.OtherSide,
Right: &sqlparser.Subquery{Select: node.Subquery},
Right: &sqlparser.Subquery{Select: node.Subquery.Select},
Operator: originalCmp.Operator,
}
found, exitEarly, err := rp.planComparison(ctx, cmp)
Expand Down
Loading

0 comments on commit 180ccfb

Please sign in to comment.