From 86d15f99614033171dbe0b8fdb65703a2de9519d Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Sun, 17 Jan 2021 20:54:12 +0530 Subject: [PATCH 1/3] handling subqueries in query graph Signed-off-by: Harshit Gangal --- go/vt/vtgate/planbuilder/querygraph.go | 45 ++++++++++++++++++++- go/vt/vtgate/planbuilder/querygraph_test.go | 29 +++++++++---- go/vt/vtgate/semantics/analyzer.go | 2 +- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/go/vt/vtgate/planbuilder/querygraph.go b/go/vt/vtgate/planbuilder/querygraph.go index 28641f495b0..7931e952e56 100644 --- a/go/vt/vtgate/planbuilder/querygraph.go +++ b/go/vt/vtgate/planbuilder/querygraph.go @@ -41,6 +41,8 @@ type ( // noDeps contains the predicates that can be evaluated anywhere. noDeps sqlparser.Expr + + subqueries map[*sqlparser.Subquery][]*queryGraph } // queryTable is a single FROM table, including all predicates particular to this table @@ -79,9 +81,37 @@ func createQGFromSelect(sel *sqlparser.Select, semTable *semantics.SemTable) (*q return qg, nil } +func createQGFromSelectStatement(selStmt sqlparser.SelectStatement, semTable *semantics.SemTable) ([]*queryGraph, error) { + switch stmt := selStmt.(type) { + case *sqlparser.Select: + qg, err := createQGFromSelect(stmt, semTable) + if err != nil { + return nil, err + } + return []*queryGraph{qg}, err + case *sqlparser.Union: + qg, err := createQGFromSelectStatement(stmt.FirstStatement, semTable) + if err != nil { + return nil, err + } + for _, sel := range stmt.UnionSelects { + qgr, err := createQGFromSelectStatement(sel.Statement, semTable) + if err != nil { + return nil, err + } + qg = append(qg, qgr...) 
+ } + case *sqlparser.ParenSelect: + return createQGFromSelectStatement(stmt.Select, semTable) + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: not reachable %T", selStmt) +} + func newQueryGraph() *queryGraph { return &queryGraph{ crossTable: map[semantics.TableSet][]sqlparser.Expr{}, + subqueries: map[*sqlparser.Subquery][]*queryGraph{}, } } @@ -156,7 +186,20 @@ func (qg *queryGraph) collectPredicate(predicate sqlparser.Expr, semTable *seman } qg.crossTable[deps] = allPredicates } - return nil + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch subQuery := node.(type) { + case *sqlparser.Subquery: + + qgr, err := createQGFromSelectStatement(subQuery.Select, semTable) + if err != nil { + return false, err + } + qg.subqueries[subQuery] = qgr + } + return true, nil + }, predicate) + + return err } func (qg *queryGraph) addToSingleTable(table semantics.TableSet, predicate sqlparser.Expr) bool { diff --git a/go/vt/vtgate/planbuilder/querygraph_test.go b/go/vt/vtgate/planbuilder/querygraph_test.go index df16ab022e5..d3cd6dab51c 100644 --- a/go/vt/vtgate/planbuilder/querygraph_test.go +++ b/go/vt/vtgate/planbuilder/querygraph_test.go @@ -26,7 +26,6 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/semantics" - "vitess.io/vitess/go/vt/vtgate/vindexes" ) type tcase struct { @@ -126,6 +125,25 @@ var tcases = []tcase{{ Right: equals(literalInt(12), literalString("12")), }, }, +}, { + input: "select 1 from t where exists (select 1)", + output: &queryGraph{ + tables: []*queryTable{{ + tableID: 1, + alias: tableAlias("t"), + table: tableName("t"), + }}, + crossTable: map[semantics.TableSet][]sqlparser.Expr{}, + noDeps: &sqlparser.ExistsExpr{ + Subquery: &sqlparser.Subquery{ + Select: &sqlparser.Select{ + SelectExprs: sqlparser.SelectExprs{&sqlparser.AliasedExpr{Expr: newIntLiteral("1")}}, + From: sqlparser.TableExprs{&sqlparser.AliasedTableExpr{Expr: sqlparser.TableName{Name: 
sqlparser.NewTableIdent("dual")}}}, + }, + }, + }, + subqueries: map[*sqlparser.Subquery][]*queryGraph{}, + }, }} func literalInt(i int) *sqlparser.Literal { @@ -163,12 +181,6 @@ func tableName(name string) sqlparser.TableName { return sqlparser.TableName{Name: sqlparser.NewTableIdent(name)} } -type schemaInf struct{} - -func (node *schemaInf) FindTable(tablename sqlparser.TableName) (*vindexes.Table, error) { - return nil, nil -} - func TestQueryGraph(t *testing.T) { for _, tc := range tcases { sql := tc.input @@ -180,6 +192,7 @@ func TestQueryGraph(t *testing.T) { qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable) require.NoError(t, err) mustMatch(t, tc.output, qgraph, "incorrect query graph") + }) } } @@ -190,5 +203,5 @@ var mustMatch = utils.MustMatchFn( queryTable{}, sqlparser.TableIdent{}, }, - []string{}, // ignored fields + []string{".subqueries"}, // ignored fields ) diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 6ac1a998e07..ebb23e181b0 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -53,7 +53,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { a.err = err return false } - case *sqlparser.DerivedTable, *sqlparser.Subquery: + case *sqlparser.DerivedTable: a.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%T not supported", node) case *sqlparser.TableExprs: // this has already been visited when we encountered the SELECT struct From 62b97e56474cf440395a105df60a5e72c02332d8 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Sun, 17 Jan 2021 20:54:41 +0530 Subject: [PATCH 2/3] handling subqueries in query graph Signed-off-by: Harshit Gangal --- go/vt/vtgate/planbuilder/querygraph.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/vt/vtgate/planbuilder/querygraph.go b/go/vt/vtgate/planbuilder/querygraph.go index 7931e952e56..0f8d09436e4 100644 --- a/go/vt/vtgate/planbuilder/querygraph.go +++ 
b/go/vt/vtgate/planbuilder/querygraph.go @@ -101,6 +101,7 @@ func createQGFromSelectStatement(selStmt sqlparser.SelectStatement, semTable *se } qg = append(qg, qgr...) } + return qg, nil case *sqlparser.ParenSelect: return createQGFromSelectStatement(stmt.Select, semTable) } From 5dbda83a0f70a3e6f14af80ec1f9279db9a2a7df Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sun, 17 Jan 2021 19:40:34 +0100 Subject: [PATCH 3/3] added testString to queryGraph to make it easier to write unit tests Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/querygraph.go | 1 + go/vt/vtgate/planbuilder/querygraph_test.go | 296 +++++++++++--------- go/vt/vtgate/semantics/semantic_state.go | 12 + 3 files changed, 175 insertions(+), 134 deletions(-) diff --git a/go/vt/vtgate/planbuilder/querygraph.go b/go/vt/vtgate/planbuilder/querygraph.go index 0f8d09436e4..45b0e68ad30 100644 --- a/go/vt/vtgate/planbuilder/querygraph.go +++ b/go/vt/vtgate/planbuilder/querygraph.go @@ -42,6 +42,7 @@ type ( // noDeps contains the predicates that can be evaluated anywhere. 
noDeps sqlparser.Expr + // subqueries contains the subqueries that depend on this query graph subqueries map[*sqlparser.Subquery][]*queryGraph } diff --git a/go/vt/vtgate/planbuilder/querygraph_test.go b/go/vt/vtgate/planbuilder/querygraph_test.go index d3cd6dab51c..71c8c06989f 100644 --- a/go/vt/vtgate/planbuilder/querygraph_test.go +++ b/go/vt/vtgate/planbuilder/querygraph_test.go @@ -18,8 +18,12 @@ package planbuilder import ( "fmt" + "sort" + "strings" "testing" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/test/utils" "github.com/stretchr/testify/require" @@ -29,137 +33,69 @@ import ( ) type tcase struct { - input string - output *queryGraph + input, output string } -var threeWayJoin = &queryGraph{ - tables: []*queryTable{{ - tableID: 1, - alias: tableAlias("t"), - table: tableName("t"), - predicates: []sqlparser.Expr{equals(colName("t", "name"), literalString("foo"))}, - }, { - tableID: 2, - alias: tableAlias("y"), - table: tableName("y"), - predicates: []sqlparser.Expr{equals(colName("y", "col"), literalInt(42))}, - }, { - tableID: 4, - alias: tableAlias("z"), - table: tableName("z"), - predicates: []sqlparser.Expr{equals(colName("z", "baz"), literalInt(101))}, - }}, - crossTable: map[semantics.TableSet][]sqlparser.Expr{ - 1 | 2: { - equals( - colName("t", "id"), - colName("y", "t_id"))}, - 1 | 4: { - equals( - colName("t", "id"), - colName("z", "t_id"))}}} - var tcases = []tcase{{ input: "select * from t", - output: &queryGraph{ - tables: []*queryTable{{ - tableID: 1, - alias: tableAlias("t"), - table: tableName("t"), - }}, - crossTable: map[semantics.TableSet][]sqlparser.Expr{}, - }, + output: `{ +Tables: + 1:t +}`, }, { input: "select t.c from t,y,z where t.c = y.c and (t.a = z.a or t.a = y.a) and 1 < 2", - output: &queryGraph{ - tables: []*queryTable{{ - tableID: 1, - alias: tableAlias("t"), - table: tableName("t"), - }, { - tableID: 2, - alias: tableAlias("y"), - table: tableName("y"), - }, { - tableID: 4, - alias: tableAlias("z"), - 
table: tableName("z"), - }}, - crossTable: map[semantics.TableSet][]sqlparser.Expr{ - 1 | 2: { - equals( - colName("t", "c"), - colName("y", "c"))}, - 1 | 2 | 4: { - or( - equals( - colName("t", "a"), - colName("z", "a")), - equals( - colName("t", "a"), - colName("y", "a")))}, - }, - noDeps: &sqlparser.ComparisonExpr{ - Operator: sqlparser.LessThanOp, - Left: literalInt(1), - Right: literalInt(2)}, - }, + output: `{ +Tables: + 1:t + 2:y + 4:z +JoinPredicates: + 1:2 - t.c = y.c + 1:2:4 - t.a = z.a or t.a = y.a +ForAll: 1 < 2 +}`, }, { - input: "select t.c from t join y on t.id = y.t_id join z on t.id = z.t_id where t.name = 'foo' and y.col = 42 and z.baz = 101", - output: threeWayJoin, + input: "select t.c from t join y on t.id = y.t_id join z on t.id = z.t_id where t.name = 'foo' and y.col = 42 and z.baz = 101", + output: `{ +Tables: + 1:t where t.` + "`name`" + ` = 'foo' + 2:y where y.col = 42 + 4:z where z.baz = 101 +JoinPredicates: + 1:2 - t.id = y.t_id + 1:4 - t.id = z.t_id +}`, }, { - input: "select t.c from t,y,z where t.name = 'foo' and y.col = 42 and z.baz = 101 and t.id = y.t_id and t.id = z.t_id", - output: threeWayJoin, + input: "select t.c from t,y,z where t.name = 'foo' and y.col = 42 and z.baz = 101 and t.id = y.t_id and t.id = z.t_id", + output: `{ +Tables: + 1:t where t.` + "`name`" + ` = 'foo' + 2:y where y.col = 42 + 4:z where z.baz = 101 +JoinPredicates: + 1:2 - t.id = y.t_id + 1:4 - t.id = z.t_id +}`, }, { input: "select 1 from t where '1' = 1 and 12 = '12'", - output: &queryGraph{ - tables: []*queryTable{{ - tableID: 1, - alias: tableAlias("t"), - table: tableName("t"), - }}, - crossTable: map[semantics.TableSet][]sqlparser.Expr{}, - noDeps: &sqlparser.AndExpr{ - Left: equals(literalString("1"), literalInt(1)), - Right: equals(literalInt(12), literalString("12")), - }, - }, + output: `{ +Tables: + 1:t +ForAll: '1' = 1 and 12 = '12' +}`, }, { input: "select 1 from t where exists (select 1)", - output: &queryGraph{ - tables: []*queryTable{{ - 
tableID: 1, - alias: tableAlias("t"), - table: tableName("t"), - }}, - crossTable: map[semantics.TableSet][]sqlparser.Expr{}, - noDeps: &sqlparser.ExistsExpr{ - Subquery: &sqlparser.Subquery{ - Select: &sqlparser.Select{ - SelectExprs: sqlparser.SelectExprs{&sqlparser.AliasedExpr{Expr: newIntLiteral("1")}}, - From: sqlparser.TableExprs{&sqlparser.AliasedTableExpr{Expr: sqlparser.TableName{Name: sqlparser.NewTableIdent("dual")}}}, - }, - }, - }, - subqueries: map[*sqlparser.Subquery][]*queryGraph{}, - }, -}} - -func literalInt(i int) *sqlparser.Literal { - return &sqlparser.Literal{Type: sqlparser.IntVal, Val: []byte(fmt.Sprintf("%d", i))} -} - -func literalString(s string) *sqlparser.Literal { - return &sqlparser.Literal{Type: sqlparser.StrVal, Val: []byte(s)} -} - -func or(left, right sqlparser.Expr) sqlparser.Expr { - return &sqlparser.OrExpr{ - Left: left, - Right: right, + output: `{ +Tables: + 1:t +ForAll: exists (select 1 from dual) +SubQueries: +(select 1 from dual) - { + Tables: + 2:dual } -} +}`, +}} func equals(left, right sqlparser.Expr) sqlparser.Expr { return &sqlparser.ComparisonExpr{ @@ -173,35 +109,127 @@ func colName(table, column string) *sqlparser.ColName { return &sqlparser.ColName{Name: sqlparser.NewColIdent(column), Qualifier: tableName(table)} } -func tableAlias(name string) *sqlparser.AliasedTableExpr { - return &sqlparser.AliasedTableExpr{Expr: sqlparser.TableName{Name: sqlparser.NewTableIdent(name)}} -} - func tableName(name string) sqlparser.TableName { return sqlparser.TableName{Name: sqlparser.NewTableIdent(name)} } func TestQueryGraph(t *testing.T) { - for _, tc := range tcases { + for i, tc := range tcases { sql := tc.input - t.Run(sql, func(t *testing.T) { + t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) { tree, err := sqlparser.Parse(sql) require.NoError(t, err) semTable, err := semantics.Analyse(tree) require.NoError(t, err) qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable) require.NoError(t, err) - 
mustMatch(t, tc.output, qgraph, "incorrect query graph") - + fmt.Println(qgraph.testString()) + assert.Equal(t, tc.output, qgraph.testString()) + utils.MustMatch(t, tc.output, qgraph.testString(), "incorrect query graph") }) } } -var mustMatch = utils.MustMatchFn( - []interface{}{ // types with unexported fields - queryGraph{}, - queryTable{}, - sqlparser.TableIdent{}, - }, - []string{".subqueries"}, // ignored fields -) +func TestString(t *testing.T) { + tree, err := sqlparser.Parse("select * from a,b join c on b.id = c.id where a.id = b.id and b.col IN (select 42) and func() = 'foo'") + require.NoError(t, err) + semTable, err := semantics.Analyse(tree) + require.NoError(t, err) + qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable) + require.NoError(t, err) + utils.MustMatch(t, `{ +Tables: + 1:a + 2:b where b.col in (select 42 from dual) + 4:c +JoinPredicates: + 1:2 - a.id = b.id + 2:4 - b.id = c.id +ForAll: func() = 'foo' +SubQueries: +(select 42 from dual) - { + Tables: + 8:dual + } +}`, qgraph.testString()) +} + +func (qt *queryTable) testString() string { + var alias string + if !qt.alias.As.IsEmpty() { + alias = " AS " + sqlparser.String(qt.alias.As) + } + var preds []string + for _, predicate := range qt.predicates { + preds = append(preds, sqlparser.String(predicate)) + } + var where string + if len(preds) > 0 { + where = " where " + strings.Join(preds, " and ") + } + + return fmt.Sprintf("\t%d:%s%s%s", qt.tableID, sqlparser.String(qt.table), alias, where) +} + +func (qg *queryGraph) testString() string { + return fmt.Sprintf(`{ +Tables: +%s%s%s%s +}`, strings.Join(qg.tableNames(), "\n"), qg.crossPredicateString(), qg.noDepsString(), qg.subqueriesString()) +} + +func (qg *queryGraph) crossPredicateString() string { + if len(qg.crossTable) == 0 { + return "" + } + var joinPreds []string + for deps, predicates := range qg.crossTable { + var tables []string + for _, id := range deps.Constituents() { + tables = append(tables, fmt.Sprintf("%d", 
id)) + } + var expressions []string + for _, expr := range predicates { + expressions = append(expressions, sqlparser.String(expr)) + } + tableConcat := strings.Join(tables, ":") + exprConcat := strings.Join(expressions, " and ") + joinPreds = append(joinPreds, fmt.Sprintf("\t%s - %s", tableConcat, exprConcat)) + } + sort.Strings(joinPreds) + return fmt.Sprintf("\nJoinPredicates:\n%s", strings.Join(joinPreds, "\n")) +} + +func (qg *queryGraph) tableNames() []string { + var tables []string + for _, t := range qg.tables { + tables = append(tables, t.testString()) + } + return tables +} + +func (qg *queryGraph) subqueriesString() string { + if len(qg.subqueries) == 0 { + return "" + } + var graphs []string + for sq, qgraphs := range qg.subqueries { + key := sqlparser.String(sq) + for _, inner := range qgraphs { + str := inner.testString() + splitInner := strings.Split(str, "\n") + for i, s := range splitInner { + splitInner[i] = "\t" + s + } + graphs = append(graphs, fmt.Sprintf("%s - %s", key, strings.Join(splitInner, "\n"))) + } + } + return fmt.Sprintf("\nSubQueries:\n%s", strings.Join(graphs, "\n")) +} + +func (qg *queryGraph) noDepsString() string { + if qg.noDeps == nil { + return "" + } + return fmt.Sprintf("\nForAll: %s", sqlparser.String(qg.noDeps)) +} diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 555748d4f14..6e770a5c188 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -109,6 +109,18 @@ func (ts TableSet) NumberOfTables() int { return count } +// Constituents returns a slice with all the +// individual tables in their own TableSet identifier +func (ts TableSet) Constituents() (result []TableSet) { + for i := 0; i < 64; i++ { + i2 := TableSet(1 << i) + if ts&i2 == i2 { + result = append(result, i2) + } + } + return +} + // Merge creates a TableSet that contains both inputs func (ts TableSet) Merge(other TableSet) TableSet { return ts | other