Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner operators refactoring #11680

Merged
merged 6 commits into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import (
"strconv"
"strings"

"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite"

"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"

"vitess.io/vitess/go/sqltypes"

"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand All @@ -39,7 +43,7 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Operator, isRoot bool) (logicalPlan, error) {
func transformToLogicalPlan(ctx *plancontext.PlanningContext, op ops.Operator, isRoot bool) (logicalPlan, error) {
switch op := op.(type) {
case *operators.Route:
return transformRoutePlan(ctx, op)
Expand Down Expand Up @@ -167,7 +171,10 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) (
values = op.Selected.Values
}
condition := getVindexPredicate(ctx, op)
sel := toSQL(ctx, op.Source)
sel, err := operators.ToSQL(ctx, op.Source)
if err != nil {
return nil, err
}
replaceSubQuery(ctx, sel)
return &routeGen4{
eroute: &engine.Route{
Expand Down Expand Up @@ -323,7 +330,7 @@ func getVindexPredicate(ctx *plancontext.PlanningContext, op *operators.Route) s

func getAllTableNames(op *operators.Route) ([]string, error) {
tableNameMap := map[string]any{}
err := operators.VisitTopDown(op, func(op operators.Operator) error {
err := rewrite.Visit(op, func(op ops.Operator) error {
tbl, isTbl := op.(*operators.Table)
var name string
if isTbl {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2021 The Vitess Authors.
Copyright 2022 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,106 +14,38 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package planbuilder
package operators

import (
"fmt"
"sort"
"strings"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"

"vitess.io/vitess/go/vt/log"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"

"vitess.io/vitess/go/vt/vtgate/semantics"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators"
)

type queryBuilder struct {
ctx *plancontext.PlanningContext
sel sqlparser.SelectStatement
tableNames []string
}
type (
queryBuilder struct {
ctx *plancontext.PlanningContext
sel sqlparser.SelectStatement
tableNames []string
}
)

func toSQL(ctx *plancontext.PlanningContext, op operators.Operator) sqlparser.SelectStatement {
func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.SelectStatement, error) {
q := &queryBuilder{ctx: ctx}
buildQuery(op, q)
q.sortTables()
return q.sel
}

func buildQuery(op operators.Operator, qb *queryBuilder) {
switch op := op.(type) {
case *operators.Table:
dbName := ""

if op.QTable.IsInfSchema {
dbName = op.QTable.Table.Qualifier.String()
}
qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), operators.TableID(op), op.QTable.Alias.Hints)
for _, pred := range op.QTable.Predicates {
qb.addPredicate(pred)
}
for _, name := range op.Columns {
qb.addProjection(&sqlparser.AliasedExpr{Expr: name})
}
case *operators.ApplyJoin:
buildQuery(op.LHS, qb)
// If we are going to add the predicate used in join here
// We should not add the predicate's copy of when it was split into
// two parts. To avoid this, we use the SkipPredicates map.
for _, expr := range qb.ctx.JoinPredicates[op.Predicate] {
qb.ctx.SkipPredicates[expr] = nil
}
qbR := &queryBuilder{ctx: qb.ctx}
buildQuery(op.RHS, qbR)
if op.LeftJoin {
qb.joinOuterWith(qbR, op.Predicate)
} else {
qb.joinInnerWith(qbR, op.Predicate)
}
case *operators.Filter:
buildQuery(op.Source, qb)
for _, pred := range op.Predicates {
qb.addPredicate(pred)
}
case *operators.Derived:
buildQuery(op.Source, qb)
sel := qb.sel.(*sqlparser.Select) // we can only handle SELECT in derived tables at the moment
qb.sel = nil
opQuery := sqlparser.RemoveKeyspace(op.Query).(*sqlparser.Select)
sel.Limit = opQuery.Limit
sel.OrderBy = opQuery.OrderBy
sel.GroupBy = opQuery.GroupBy
sel.Having = opQuery.Having
sel.SelectExprs = opQuery.SelectExprs
qb.addTableExpr(op.Alias, op.Alias, operators.TableID(op), &sqlparser.DerivedTable{
Select: sel,
}, nil, op.ColumnAliases)
for _, col := range op.Columns {
qb.addProjection(&sqlparser.AliasedExpr{Expr: col})
}
default:
panic(fmt.Sprintf("%T", op))
err := buildQuery(op, q)
if err != nil {
return nil, err
}
}

func (qb *queryBuilder) sortTables() {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
sel, isSel := node.(*sqlparser.Select)
if !isSel {
return true, nil
}
ts := &tableSorter{
sel: sel,
tbl: qb.ctx.SemTable,
}
sort.Sort(ts)
return true, nil
}, qb.sel)

q.sortTables()
return q.sel, nil
}

func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) {
Expand Down Expand Up @@ -256,6 +188,22 @@ func (qb *queryBuilder) hasTable(tableName string) bool {
return false
}

func (qb *queryBuilder) sortTables() {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
sel, isSel := node.(*sqlparser.Select)
if !isSel {
return true, nil
}
ts := &tableSorter{
sel: sel,
tbl: qb.ctx.SemTable,
}
sort.Sort(ts)
return true, nil
}, qb.sel)

}

type tableSorter struct {
sel *sqlparser.Select
tbl *semantics.SemTable
Expand Down Expand Up @@ -286,3 +234,151 @@ func (ts *tableSorter) Less(i, j int) bool {
func (ts *tableSorter) Swap(i, j int) {
ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i]
}

func (h *Horizon) toSQL(qb *queryBuilder) error {
err := stripDownQuery(h.Select, qb.sel)
if err != nil {
return err
}
sqlparser.Rewrite(qb.sel, func(cursor *sqlparser.Cursor) bool {
if aliasedExpr, ok := cursor.Node().(sqlparser.SelectExpr); ok {
removeKeyspaceFromSelectExpr(aliasedExpr)
}
return true
}, nil)
return nil
}

func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) {
switch expr := expr.(type) {
case *sqlparser.AliasedExpr:
sqlparser.RemoveKeyspaceFromColName(expr.Expr)
case *sqlparser.StarExpr:
expr.TableName.Qualifier = sqlparser.NewIdentifierCS("")
}
}

func stripDownQuery(from, to sqlparser.SelectStatement) error {
var err error

switch node := from.(type) {
case *sqlparser.Select:
toNode, ok := to.(*sqlparser.Select)
if !ok {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "AST did not match")
}
toNode.Distinct = node.Distinct
toNode.GroupBy = node.GroupBy
toNode.Having = node.Having
toNode.OrderBy = node.OrderBy
toNode.Comments = node.Comments
toNode.SelectExprs = node.SelectExprs
for _, expr := range toNode.SelectExprs {
removeKeyspaceFromSelectExpr(expr)
}
case *sqlparser.Union:
toNode, ok := to.(*sqlparser.Union)
if !ok {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "AST did not match")
}
err = stripDownQuery(node.Left, toNode.Left)
if err != nil {
return err
}
err = stripDownQuery(node.Right, toNode.Right)
if err != nil {
return err
}
toNode.OrderBy = node.OrderBy
default:
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: this should not happen - we have covered all implementations of SelectStatement %T", from)
}
return nil
}

func buildQuery(op ops.Operator, qb *queryBuilder) error {
switch op := op.(type) {
case *Table:
dbName := ""

if op.QTable.IsInfSchema {
dbName = op.QTable.Table.Qualifier.String()
}
qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), TableID(op), op.QTable.Alias.Hints)
for _, pred := range op.QTable.Predicates {
qb.addPredicate(pred)
}
for _, name := range op.Columns {
qb.addProjection(&sqlparser.AliasedExpr{Expr: name})
}
case *ApplyJoin:
err := buildQuery(op.LHS, qb)
if err != nil {
return err
}
// If we are going to add the predicate used in join here
// We should not add the predicate's copy of when it was split into
// two parts. To avoid this, we use the SkipPredicates map.
for _, expr := range qb.ctx.JoinPredicates[op.Predicate] {
qb.ctx.SkipPredicates[expr] = nil
}
qbR := &queryBuilder{ctx: qb.ctx}
err = buildQuery(op.RHS, qbR)
if err != nil {
return err
}
if op.LeftJoin {
qb.joinOuterWith(qbR, op.Predicate)
} else {
qb.joinInnerWith(qbR, op.Predicate)
}
case *Filter:
err := buildQuery(op.Source, qb)
if err != nil {
return err
}
for _, pred := range op.Predicates {
qb.addPredicate(pred)
}
case *Derived:
err := buildQuery(op.Source, qb)
if err != nil {
return err
}
sel := qb.sel.(*sqlparser.Select) // we can only handle SELECT in derived tables at the moment
qb.sel = nil
opQuery := sqlparser.RemoveKeyspace(op.Query).(*sqlparser.Select)
sel.Limit = opQuery.Limit
sel.OrderBy = opQuery.OrderBy
sel.GroupBy = opQuery.GroupBy
sel.Having = opQuery.Having
sel.SelectExprs = opQuery.SelectExprs
qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{
Select: sel,
}, nil, op.ColumnAliases)
for _, col := range op.Columns {
qb.addProjection(&sqlparser.AliasedExpr{Expr: col})
}
case *Horizon:
err := buildQuery(op.Source, qb)
if err != nil {
return err
}

err = stripDownQuery(op.Select, qb.sel)
if err != nil {
return err
}
sqlparser.Rewrite(qb.sel, func(cursor *sqlparser.Cursor) bool {
if aliasedExpr, ok := cursor.Node().(sqlparser.SelectExpr); ok {
removeKeyspaceFromSelectExpr(aliasedExpr)
}
return true
}, nil)
return nil

default:
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to turn %T into SQL", op)
}
return nil
}
Loading