diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 9587635fe90..6a3a65ea583 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -718,6 +718,15 @@ func NewComparisonExpr(operator ComparisonExprOperator, left, right, escape Expr } } +// NewCaseExpr makes a new CaseExpr +func NewCaseExpr(expr Expr, whens []*When, elseExpr Expr) *CaseExpr { + return &CaseExpr{ + Expr: expr, + Whens: whens, + Else: elseExpr, + } +} + // NewLimit makes a new Limit func NewLimit(offset, rowCount int) *Limit { return &Limit{ diff --git a/go/vt/vtgate/planbuilder/simplifier_test.go b/go/vt/vtgate/planbuilder/simplifier_test.go index 509564225d7..0854934ee7f 100644 --- a/go/vt/vtgate/planbuilder/simplifier_test.go +++ b/go/vt/vtgate/planbuilder/simplifier_test.go @@ -37,7 +37,9 @@ import ( // TestSimplifyBuggyQuery should be used to whenever we get a planner bug reported // It will try to minimize the query to make it easier to understand and work with the bug. func TestSimplifyBuggyQuery(t *testing.T) { - query := "(select id from unsharded union select id from unsharded_auto) union (select id from user union select name from unsharded)" + query := "select distinct count(distinct a), count(distinct 4) from user left join unsharded on 0 limit 5" + // select 0 from unsharded union select 0 from `user` union select 0 from unsharded + // select 0 from unsharded union (select 0 from `user` union select 0 from unsharded) vschema := &vschemaWrapper{ v: loadSchema(t, "vschemas/schema.json", true), version: Gen4, diff --git a/go/vt/vtgate/simplifier/expression_simplifier.go b/go/vt/vtgate/simplifier/expression_simplifier.go index 279cb1ac7dd..4537a137e76 100644 --- a/go/vt/vtgate/simplifier/expression_simplifier.go +++ b/go/vt/vtgate/simplifier/expression_simplifier.go @@ -21,104 +21,59 @@ import ( "strconv" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/sqlparser" ) // CheckF is used to see if the given expression exhibits the sought after issue type CheckF = func(sqlparser.Expr) bool -func SimplifyExpr(in sqlparser.Expr, test CheckF) (smallestKnown sqlparser.Expr) { - var maxDepth, level int - resetTo := func(e sqlparser.Expr) { - smallestKnown = e - maxDepth = depth(e) - level = 0 +func SimplifyExpr(in sqlparser.Expr, test CheckF) sqlparser.Expr { + // since we can't rewrite the top level, wrap the expr in an Exprs object + smallestKnown := sqlparser.Exprs{sqlparser.CloneExpr(in)} + + alwaysVisit := func(node, parent sqlparser.SQLNode) bool { + return true } - resetTo(in) - for level <= maxDepth { - current := sqlparser.CloneExpr(smallestKnown) - nodes, replaceF := getNodesAtLevel(current, level) - replace := func(e sqlparser.Expr, idx int) { - // if we are at the first level, we are replacing the root, - // not rewriting something deep in the tree - if level == 0 { - current = e + + up := func(cursor *sqlparser.Cursor) bool { + node := sqlparser.CloneSQLNode(cursor.Node()) + s := &shrinker{orig: node} + expr := s.Next() + for expr != nil { + cursor.Replace(expr) + + valid := test(smallestKnown[0]) + log.Errorf("test: %t: simplified %s to %s, full expr: %s", valid, sqlparser.String(node), sqlparser.String(expr), sqlparser.String(smallestKnown)) + if valid { + break // we will still continue trying to simplify other expressions at this level } else { - // replace `node` in current with the simplified expression - replaceF[idx](e) + // undo the change + cursor.Replace(node) } + expr = s.Next() } - simplified := false - for idx, node := range nodes { - // simplify each element and create a new expression with the node replaced by the simplification - // this means that we not only need the node, but also a way to replace the node - s := &shrinker{orig: node} - expr := s.Next() - for expr != nil { - replace(expr, idx) - - valid := test(current) - log.Errorf("test: %t - %s", valid, sqlparser.String(current)) - if valid { - simplified = true - break // we will still continue trying to simplify other expressions at this level - } else { - // undo the change - replace(node, idx) - } - expr = s.Next() - } - } - if simplified { - resetTo(current) - } else { - level++ - } - } - return smallestKnown -} - -func getNodesAtLevel(e sqlparser.Expr, level int) (result []sqlparser.Expr, replaceF []func(node sqlparser.SQLNode)) { - lvl := 0 - pre := func(cursor *sqlparser.Cursor) bool { - if expr, isExpr := cursor.Node().(sqlparser.Expr); level == lvl && isExpr { - result = append(result, expr) - replaceF = append(replaceF, cursor.ReplacerF()) - } - lvl++ - return true - } - post := func(cursor *sqlparser.Cursor) bool { - lvl-- return true } - sqlparser.Rewrite(e, pre, post) - return -} -func depth(e sqlparser.Expr) (depth int) { - lvl := 0 - pre := func(cursor *sqlparser.Cursor) bool { - lvl++ - if lvl > depth { - depth = lvl + // loop until rewriting introduces no more changes + for { + prevSmallest := sqlparser.CloneExprs(smallestKnown) + sqlparser.SafeRewrite(smallestKnown, alwaysVisit, up) + if sqlparser.Equals.Exprs(prevSmallest, smallestKnown) { + break } - return true } - post := func(cursor *sqlparser.Cursor) bool { - lvl-- - return true - } - sqlparser.Rewrite(e, pre, post) - return + + return smallestKnown[0] } type shrinker struct { - orig sqlparser.Expr - queue []sqlparser.Expr + orig sqlparser.SQLNode + queue []sqlparser.SQLNode } -func (s *shrinker) Next() sqlparser.Expr { +func (s *shrinker) Next() sqlparser.SQLNode { for { // first we check if there is already something in the queue. // note that we are doing a nil check and not a length check here. @@ -142,6 +97,10 @@ func (s *shrinker) Next() sqlparser.Expr { func (s *shrinker) fillQueue() bool { before := len(s.queue) switch e := s.orig.(type) { + case *sqlparser.AndExpr: + s.queue = append(s.queue, e.Left, e.Right) + case *sqlparser.OrExpr: + s.queue = append(s.queue, e.Left, e.Right) case *sqlparser.ComparisonExpr: s.queue = append(s.queue, e.Left, e.Right) case *sqlparser.BinaryExpr: @@ -228,9 +187,39 @@ func (s *shrinker) fillQueue() bool { for _, ae := range e.GetArgs() { s.queue = append(s.queue, ae) } + + clone := sqlparser.CloneAggrFunc(e) + if da, ok := clone.(sqlparser.DistinctableAggr); ok { + if da.IsDistinct() { + da.SetDistinct(false) + s.queue = append(s.queue, clone) + } + } case *sqlparser.ColName: // we can try to replace the column with a literal value - s.queue = []sqlparser.Expr{sqlparser.NewIntLiteral("0")} + s.queue = append(s.queue, sqlparser.NewIntLiteral("0")) + case *sqlparser.CaseExpr: + s.queue = append(s.queue, e.Expr, e.Else) + for _, when := range e.Whens { + s.queue = append(s.queue, when.Cond, when.Val) + } + + if len(e.Whens) > 1 { + for i := range e.Whens { + whensCopy := sqlparser.CloneSliceOfRefOfWhen(e.Whens) + // replace ith element with last element, then truncate last element + whensCopy[i] = whensCopy[len(whensCopy)-1] + whensCopy = whensCopy[:len(whensCopy)-1] + s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, whensCopy, e.Else)) + } + } + + if e.Else != nil { + s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, e.Whens, nil)) + } + if e.Expr != nil { + s.queue = append(s.queue, sqlparser.NewCaseExpr(nil, e.Whens, e.Else)) + } default: return false } diff --git a/go/vt/vtgate/simplifier/simplifier.go b/go/vt/vtgate/simplifier/simplifier.go index ef7be4e30e5..971224112f8 100644 --- a/go/vt/vtgate/simplifier/simplifier.go +++ b/go/vt/vtgate/simplifier/simplifier.go @@ -40,12 +40,12 @@ func SimplifyStatement( return testF(sqlparser.CloneSelectStatement(s)) } + // first we try to simplify the query by removing any unions if success := trySimplifyUnions(sqlparser.CloneSelectStatement(in), test); success != nil { return SimplifyStatement(success, currentDB, si, testF) } - // first we try to simplify the query by removing any table. - // If we can remove a table and all uses of it, that's a good start + // then we try to remove a table and all uses of it if success := tryRemoveTable(tables, sqlparser.CloneSelectStatement(in), currentDB, si, testF); success != nil { return SimplifyStatement(success, currentDB, si, testF) } @@ -55,54 +55,92 @@ func SimplifyStatement( return SimplifyStatement(success, currentDB, si, testF) } - // we try to remove select expressions next + // we try to remove/replace any expressions next if success := trySimplifyExpressions(sqlparser.CloneSelectStatement(in), test); success != nil { return SimplifyStatement(success, currentDB, si, testF) } + + // we try to remove distinct last + if success := trySimplifyDistinct(sqlparser.CloneSelectStatement(in), test); success != nil { + return SimplifyStatement(success, currentDB, si, testF) + } + return in } +func trySimplifyDistinct(in sqlparser.SelectStatement, test func(statement sqlparser.SelectStatement) bool) sqlparser.SelectStatement { + simplified := false + alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { + return true + } + + up := func(cursor *sqlparser.Cursor) bool { + if sel, ok := cursor.Node().(*sqlparser.Select); ok { + if sel.Distinct { + sel.Distinct = false + if test(sel) { + log.Errorf("removed distinct to yield: %s", sqlparser.String(sel)) + simplified = true + } else { + sel.Distinct = true + } + } + } + + return true + } + + sqlparser.SafeRewrite(in, alwaysVisitChildren, up) + + if simplified { + + return in + } + // we found no simplifications + return nil +} + func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { simplified := false - visitAllExpressionsInAST(in, func(cursor expressionCursor) bool { + visit := func(cursor expressionCursor) bool { // first - let's try to remove the expression if cursor.remove() { if test(in) { log.Errorf("removed expression: %s", sqlparser.String(cursor.expr)) simplified = true - return false + return true } cursor.restore() } // ok, we seem to need this expression. let's see if we can find a simpler version - s := &shrinker{orig: cursor.expr} - newExpr := s.Next() - for newExpr != nil { - cursor.replace(newExpr) + newExpr := SimplifyExpr(cursor.expr, func(expr sqlparser.Expr) bool { + cursor.replace(expr) if test(in) { - log.Errorf("simplified expression: %s -> %s", sqlparser.String(cursor.expr), sqlparser.String(newExpr)) + log.Errorf("simplified expression: %s -> %s", sqlparser.String(cursor.expr), sqlparser.String(expr)) + cursor.restore() simplified = true - return false + return true } - newExpr = s.Next() - } - // if we get here, we failed to simplify this expression, - // so we put back in the original expression - cursor.restore() + cursor.restore() + return false + }) + + cursor.replace(newExpr) return true - }) + } + + visitAllExpressionsInAST(in, visit) if simplified { return in } - + // we found no simplifications return nil } func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) (res sqlparser.SelectStatement) { - if union, ok := in.(*sqlparser.Union); ok { // the root object is an UNION if test(sqlparser.CloneSelectStatement(union.Left)) { @@ -113,9 +151,12 @@ func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectS } } - abort := false + simplified := false + alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { + return true + } - sqlparser.Rewrite(in, func(cursor *sqlparser.Cursor) bool { + up := func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.Union: if _, ok := cursor.Parent().(*sqlparser.RootNode); ok { @@ -125,29 +166,30 @@ func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectS cursor.Replace(node.Left) clone := sqlparser.CloneSelectStatement(in) if test(clone) { - log.Errorf("replaced UNION with one of its children") - abort = true + log.Errorf("replaced UNION with its left child: %s -> %s", sqlparser.String(node), sqlparser.String(node.Left)) + simplified = true return true } cursor.Replace(node.Right) clone = sqlparser.CloneSelectStatement(in) if test(clone) { - log.Errorf("replaced UNION with one of its children") - abort = true + log.Errorf("replaced UNION with its right child: %s -> %s", sqlparser.String(node), sqlparser.String(node.Right)) + simplified = true return true } cursor.Replace(node) } return true - }, func(*sqlparser.Cursor) bool { - return !abort - }) + } - if !abort { - // we found no simplifications - return nil + sqlparser.SafeRewrite(in, alwaysVisitChildren, up) + + if simplified { + + return in } - return in + // we found no simplifications + return nil } func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { @@ -158,11 +200,11 @@ func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.SelectStatement, simplified := removeTable(clone, searchedTS, currentDB, si) name, _ := tbl.Name() if simplified && test(clone) { - log.Errorf("removed table %s", sqlparser.String(name)) + log.Errorf("removed table %s: %s -> %s", sqlparser.String(name), sqlparser.String(in), sqlparser.String(clone)) return clone } } - + // we found no simplifications return nil } @@ -178,7 +220,11 @@ func getTables(in sqlparser.SelectStatement, currentDB string, si semantics.Sche func simplifyStarExpr(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { simplified := false - sqlparser.Rewrite(in, func(cursor *sqlparser.Cursor) bool { + alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { + return true + } + + up := func(cursor *sqlparser.Cursor) bool { se, ok := cursor.Node().(*sqlparser.StarExpr) if !ok { return true @@ -189,15 +235,19 @@ func simplifyStarExpr(in sqlparser.SelectStatement, test func(sqlparser.SelectSt if test(in) { log.Errorf("replaced star with literal") simplified = true - return false + return true } cursor.Replace(se) return true - }, nil) + } + + sqlparser.SafeRewrite(in, alwaysVisitChildren, up) + if simplified { return in } + // we found no simplifications return nil } @@ -209,92 +259,148 @@ func removeTable(clone sqlparser.SelectStatement, searchedTS semantics.TableSet, panic(err) } - simplified := true + simplified, kontinue := false, true shouldKeepExpr := func(expr sqlparser.Expr) bool { + // why do we keep if the expr contains an aggregation? return !semTable.RecursiveDeps(expr).IsOverlapping(searchedTS) || sqlparser.ContainsAggregation(expr) } - sqlparser.Rewrite(clone, func(cursor *sqlparser.Cursor) bool { + checkSelect := func(node, parent sqlparser.SQLNode) bool { + if sel, ok := node.(*sqlparser.Select); ok { + // remove the table from the from clause on the way down + // so that it happens before removing it anywhere else + kontinue, simplified = removeTableinSelect(sel, searchedTS, semTable, simplified) + } + + return kontinue + } + + up := func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.JoinTableExpr: - lft, ok := node.LeftExpr.(*sqlparser.AliasedTableExpr) - if ok { - ts := semTable.TableSetFor(lft) - if searchedTS == ts { - cursor.Replace(node.RightExpr) - } - } - rgt, ok := node.RightExpr.(*sqlparser.AliasedTableExpr) - if ok { - ts := semTable.TableSetFor(rgt) - if searchedTS == ts { - cursor.Replace(node.LeftExpr) - } - } - case *sqlparser.Select: - if len(node.From) == 1 { - _, notJoin := node.From[0].(*sqlparser.AliasedTableExpr) - if notJoin { - simplified = false - return false - } - } - for i, tbl := range node.From { - lft, ok := tbl.(*sqlparser.AliasedTableExpr) - if ok { - ts := semTable.TableSetFor(lft) - if searchedTS == ts { - node.From = append(node.From[:i], node.From[i+1:]...) - return true - } - } - } + simplified = removeTableinJoinTableExpr(node, searchedTS, semTable, cursor, simplified) case *sqlparser.Where: - exprs := sqlparser.SplitAndExpression(nil, node.Expr) - var newPredicate sqlparser.Expr - for _, expr := range exprs { - if !semTable.RecursiveDeps(expr).IsOverlapping(searchedTS) { - newPredicate = sqlparser.AndExpressions(newPredicate, expr) - } - } - node.Expr = newPredicate + simplified = removeTableinWhere(node, shouldKeepExpr, simplified) case sqlparser.SelectExprs: - _, isSel := cursor.Parent().(*sqlparser.Select) - if !isSel { - return true - } - - var newExprs sqlparser.SelectExprs - for _, ae := range node { - expr, ok := ae.(*sqlparser.AliasedExpr) - if !ok { - newExprs = append(newExprs, ae) - continue - } - if shouldKeepExpr(expr.Expr) { - newExprs = append(newExprs, ae) - } - } - cursor.Replace(newExprs) + simplified = removeTableinSelectExprs(node, cursor, shouldKeepExpr, simplified) case sqlparser.GroupBy: - var newExprs sqlparser.GroupBy - for _, expr := range node { - if shouldKeepExpr(expr) { - newExprs = append(newExprs, expr) - } - } - cursor.Replace(newExprs) + simplified = removeTableinGroupBy(node, cursor, shouldKeepExpr, simplified) case sqlparser.OrderBy: - var newExprs sqlparser.OrderBy - for _, expr := range node { - if shouldKeepExpr(expr.Expr) { - newExprs = append(newExprs, expr) - } + simplified = removeTableinOrderBy(node, cursor, shouldKeepExpr, simplified) + } + return true + } + + sqlparser.SafeRewrite(clone, checkSelect, up) + return simplified +} + +func removeTableinJoinTableExpr(node *sqlparser.JoinTableExpr, searchedTS semantics.TableSet, semTable *semantics.SemTable, cursor *sqlparser.Cursor, simplified bool) bool { + lft, ok := node.LeftExpr.(*sqlparser.AliasedTableExpr) + if ok { + ts := semTable.TableSetFor(lft) + if searchedTS == ts { + cursor.Replace(node.RightExpr) + simplified = true + } + } + rgt, ok := node.RightExpr.(*sqlparser.AliasedTableExpr) + if ok { + ts := semTable.TableSetFor(rgt) + if searchedTS == ts { + cursor.Replace(node.LeftExpr) + simplified = true + } + } + + return simplified +} + +func removeTableinSelect(node *sqlparser.Select, searchedTS semantics.TableSet, semTable *semantics.SemTable, simplified bool) (bool, bool) { + if len(node.From) == 1 { + _, notJoin := node.From[0].(*sqlparser.AliasedTableExpr) + if notJoin { + return false, simplified + } + } + for i, tbl := range node.From { + lft, ok := tbl.(*sqlparser.AliasedTableExpr) + if ok { + ts := semTable.TableSetFor(lft) + if searchedTS == ts { + node.From = append(node.From[:i], node.From[i+1:]...) + simplified = true } + } + } + + return true, simplified +} - cursor.Replace(newExprs) +func removeTableinWhere(node *sqlparser.Where, shouldKeepExpr func(sqlparser.Expr) bool, simplified bool) bool { + exprs := sqlparser.SplitAndExpression(nil, node.Expr) + var newPredicate sqlparser.Expr + for _, expr := range exprs { + if shouldKeepExpr(expr) { + newPredicate = sqlparser.AndExpressions(newPredicate, expr) + } else { + simplified = true } - return true - }, nil) + } + node.Expr = newPredicate + + return simplified +} + +func removeTableinSelectExprs(node sqlparser.SelectExprs, cursor *sqlparser.Cursor, shouldKeepExpr func(sqlparser.Expr) bool, simplified bool) bool { + _, isSel := cursor.Parent().(*sqlparser.Select) + if !isSel { + return simplified + } + + var newExprs sqlparser.SelectExprs + for _, ae := range node { + expr, ok := ae.(*sqlparser.AliasedExpr) + if !ok { + newExprs = append(newExprs, ae) + continue + } + if shouldKeepExpr(expr.Expr) { + newExprs = append(newExprs, ae) + } else { + simplified = true + } + } + cursor.Replace(newExprs) + + return simplified +} + +func removeTableinGroupBy(node sqlparser.GroupBy, cursor *sqlparser.Cursor, shouldKeepExpr func(sqlparser.Expr) bool, simplified bool) bool { + var newExprs sqlparser.GroupBy + for _, expr := range node { + if shouldKeepExpr(expr) { + newExprs = append(newExprs, expr) + } else { + simplified = true + } + } + cursor.Replace(newExprs) + + return simplified +} + +func removeTableinOrderBy(node sqlparser.OrderBy, cursor *sqlparser.Cursor, shouldKeepExpr func(sqlparser.Expr) bool, simplified bool) bool { + var newExprs sqlparser.OrderBy + for _, expr := range node { + if shouldKeepExpr(expr.Expr) { + newExprs = append(newExprs, expr) + } else { + simplified = true + } + } + + cursor.Replace(newExprs) + return simplified } @@ -315,180 +421,223 @@ func newExprCursor(expr sqlparser.Expr, replace func(replaceWith sqlparser.Expr) } // visitAllExpressionsInAST will walk the AST and visit all expressions -// This cursor has a few extra capabilities that the normal sqlparser.Rewrite does not have, +// This cursor has a few extra capabilities that the normal sqlparser.SafeRewrite does not have, // such as visiting and being able to change individual expressions in a AND tree +// if visit returns true, then traversal continues, otherwise traversal stops func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expressionCursor) bool) { - abort := false - post := func(*sqlparser.Cursor) bool { - return !abort + alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { + return true } - pre := func(cursor *sqlparser.Cursor) bool { - if abort { - return true - } + up := func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case sqlparser.SelectExprs: - _, isSel := cursor.Parent().(*sqlparser.Select) - if !isSel { - return true - } - for idx := 0; idx < len(node); idx++ { - ae := node[idx] - expr, ok := ae.(*sqlparser.AliasedExpr) - if !ok { - continue - } - removed := false - original := sqlparser.CloneExpr(expr.Expr) - item := newExprCursor( - expr.Expr, - /*replace*/ func(replaceWith sqlparser.Expr) { - if removed { - panic("cant replace after remove without restore") - } - expr.Expr = replaceWith - }, - /*remove*/ func() bool { - if removed { - panic("can't remove twice, silly") - } - if len(node) == 1 { - // can't remove the last expressions - we'd end up with an empty SELECT clause - return false - } - withoutElement := append(node[:idx], node[idx+1:]...) - cursor.Replace(withoutElement) - node = withoutElement - removed = true - return true - }, - /*restore*/ func() { - if removed { - front := make(sqlparser.SelectExprs, idx) - copy(front, node[:idx]) - back := make(sqlparser.SelectExprs, len(node)-idx) - copy(back, node[idx:]) - frontWithRestoredExpr := append(front, ae) - node = append(frontWithRestoredExpr, back...) - cursor.Replace(node) - removed = false - return - } - expr.Expr = original - }, - ) - abort = !visit(item) - } + return visitSelectExprs(node, cursor, visit) case *sqlparser.Where: - exprs := sqlparser.SplitAndExpression(nil, node.Expr) - set := func(input []sqlparser.Expr) { - node.Expr = sqlparser.AndExpressions(input...) - exprs = input - } - abort = !visitExpressions(exprs, set, visit) + return visitWhere(node, visit) case *sqlparser.JoinCondition: - join, ok := cursor.Parent().(*sqlparser.JoinTableExpr) - if !ok { - return true - } - if join.Join != sqlparser.NormalJoinType || node.Using != nil { - return false - } - exprs := sqlparser.SplitAndExpression(nil, node.On) - set := func(input []sqlparser.Expr) { - node.On = sqlparser.AndExpressions(input...) - exprs = input - } - abort = !visitExpressions(exprs, set, visit) + return visitJoinCondition(node, cursor, visit) case sqlparser.GroupBy: - set := func(input []sqlparser.Expr) { - node = input - cursor.Replace(node) - } - abort = !visitExpressions(node, set, visit) + return visitGroupBy(node, cursor, visit) case sqlparser.OrderBy: - for idx := 0; idx < len(node); idx++ { - order := node[idx] - removed := false - original := sqlparser.CloneExpr(order.Expr) - item := newExprCursor( - order.Expr, - /*replace*/ func(replaceWith sqlparser.Expr) { - if removed { - panic("cant replace after remove without restore") - } - order.Expr = replaceWith - }, - /*remove*/ func() bool { - if removed { - panic("can't remove twice, silly") - } - withoutElement := append(node[:idx], node[idx+1:]...) - if len(withoutElement) == 0 { - var nilVal sqlparser.OrderBy // this is used to create a typed nil value - cursor.Replace(nilVal) - } else { - cursor.Replace(withoutElement) - } - node = withoutElement - removed = true - return true - }, - /*restore*/ func() { - if removed { - front := make(sqlparser.OrderBy, idx) - copy(front, node[:idx]) - back := make(sqlparser.OrderBy, len(node)-idx) - copy(back, node[idx:]) - frontWithRestoredExpr := append(front, order) - node = append(frontWithRestoredExpr, back...) - cursor.Replace(node) - removed = false - return - } - order.Expr = original - }, - ) - abort = visit(item) - if abort { - break - } - } + return visitOrderBy(node, cursor, visit) case *sqlparser.Limit: - if node.Offset != nil { - original := node.Offset - cursor := newExprCursor(node.Offset, - /*replace*/ func(replaceWith sqlparser.Expr) { - node.Offset = replaceWith - }, - /*remove*/ func() bool { - node.Offset = nil - return true - }, - /*restore*/ func() { - node.Offset = original - }) - abort = visit(cursor) - } - if !abort && node.Rowcount != nil { - original := node.Rowcount - cursor := newExprCursor(node.Rowcount, - /*replace*/ func(replaceWith sqlparser.Expr) { - node.Rowcount = replaceWith - }, - /*remove*/ func() bool { - // removing Rowcount is an invalid op - return false - }, - /*restore*/ func() { - node.Rowcount = original - }) - abort = visit(cursor) - } + return visitLimit(node, cursor, visit) } return true } - sqlparser.Rewrite(clone, pre, post) + sqlparser.SafeRewrite(clone, alwaysVisitChildren, up) +} + +func visitSelectExprs(node sqlparser.SelectExprs, cursor *sqlparser.Cursor, visit func(expressionCursor) bool) bool { + _, isSel := cursor.Parent().(*sqlparser.Select) + if !isSel { + return true + } + for idx := 0; idx < len(node); idx++ { + ae := node[idx] + expr, ok := ae.(*sqlparser.AliasedExpr) + if !ok { + continue + } + removed := false + original := sqlparser.CloneExpr(expr.Expr) + item := newExprCursor( + expr.Expr, + /*replace*/ func(replaceWith sqlparser.Expr) { + if removed { + panic("cant replace after remove without restore") + } + expr.Expr = replaceWith + }, + /*remove*/ func() bool { + if removed { + panic("can't remove twice, silly") + } + if len(node) == 1 { + // can't remove the last expressions - we'd end up with an empty SELECT clause + return false + } + withoutElement := append(node[:idx], node[idx+1:]...) + cursor.Replace(withoutElement) + node = withoutElement + removed = true + return true + }, + /*restore*/ func() { + if removed { + front := make(sqlparser.SelectExprs, idx) + copy(front, node[:idx]) + back := make(sqlparser.SelectExprs, len(node)-idx) + copy(back, node[idx:]) + frontWithRestoredExpr := append(front, ae) + node = append(frontWithRestoredExpr, back...) + cursor.Replace(node) + removed = false + return + } + expr.Expr = original + }, + ) + if !visit(item) { + return false + } + } + + return true +} + +func visitWhere(node *sqlparser.Where, visit func(expressionCursor) bool) bool { + exprs := sqlparser.SplitAndExpression(nil, node.Expr) + set := func(input []sqlparser.Expr) { + node.Expr = sqlparser.AndExpressions(input...) + exprs = input + } + return visitExpressions(exprs, set, visit, 0) +} + +func visitJoinCondition(node *sqlparser.JoinCondition, cursor *sqlparser.Cursor, visit func(expressionCursor) bool) bool { + join, ok := cursor.Parent().(*sqlparser.JoinTableExpr) + if !ok { + return true + } + + if node.Using != nil { + return true + } + + // for only left and right joins must the join condition be nonempty + minExprs := 0 + if join.Join == sqlparser.LeftJoinType || join.Join == sqlparser.RightJoinType { + minExprs = 1 + } + + exprs := sqlparser.SplitAndExpression(nil, node.On) + set := func(input []sqlparser.Expr) { + node.On = sqlparser.AndExpressions(input...) + exprs = input + } + return visitExpressions(exprs, set, visit, minExprs) +} + +func visitGroupBy(node sqlparser.GroupBy, cursor *sqlparser.Cursor, visit func(expressionCursor) bool) bool { + set := func(input []sqlparser.Expr) { + node = input + cursor.Replace(node) + } + return visitExpressions(node, set, visit, 0) +} + +func visitOrderBy(node sqlparser.OrderBy, cursor *sqlparser.Cursor, visit func(expressionCursor) bool) bool { + for idx := 0; idx < len(node); idx++ { + order := node[idx] + removed := false + original := sqlparser.CloneExpr(order.Expr) + item := newExprCursor( + order.Expr, + /*replace*/ func(replaceWith sqlparser.Expr) { + if removed { + panic("cant replace after remove without restore") + } + order.Expr = replaceWith + }, + /*remove*/ func() bool { + if removed { + panic("can't remove twice, silly") + } + withoutElement := append(node[:idx], node[idx+1:]...) + if len(withoutElement) == 0 { + var nilVal sqlparser.OrderBy // this is used to create a typed nil value + cursor.Replace(nilVal) + } else { + cursor.Replace(withoutElement) + } + node = withoutElement + removed = true + return true + }, + /*restore*/ func() { + if removed { + front := make(sqlparser.OrderBy, idx) + copy(front, node[:idx]) + back := make(sqlparser.OrderBy, len(node)-idx) + copy(back, node[idx:]) + frontWithRestoredExpr := append(front, order) + node = append(frontWithRestoredExpr, back...) + cursor.Replace(node) + removed = false + return + } + order.Expr = original + }, + ) + if !visit(item) { + return false + } + } + + return true +} + +func visitLimit(node *sqlparser.Limit, cursor *sqlparser.Cursor, visit func(expressionCursor) bool) bool { + if node.Offset != nil { + original := node.Offset + item := newExprCursor(node.Offset, + /*replace*/ func(replaceWith sqlparser.Expr) { + node.Offset = replaceWith + }, + /*remove*/ func() bool { + node.Offset = nil + return true + }, + /*restore*/ func() { + node.Offset = original + }) + if !visit(item) { + return false + } + } + if node.Rowcount != nil { + original := node.Rowcount + item := newExprCursor(node.Rowcount, + /*replace*/ func(replaceWith sqlparser.Expr) { + node.Rowcount = replaceWith + }, + // this removes the whole limit clause + /*remove*/ + func() bool { + var nilVal *sqlparser.Limit // this is used to create a typed nil value + cursor.Replace(nilVal) + return true + }, + /*restore*/ func() { + node.Rowcount = original + }) + if !visit(item) { + return false + } + } + + return true } // visitExpressions allows the cursor to visit all expressions in a slice, @@ -497,6 +646,7 @@ func visitExpressions( exprs []sqlparser.Expr, set func(input []sqlparser.Expr), visit func(expressionCursor) bool, + minExprs int, ) bool { for idx := 0; idx < len(exprs); idx++ { expr := exprs[idx] @@ -513,6 +663,10 @@ func visitExpressions( if removed { panic("can't remove twice, silly") } + // need to keep at least minExprs + if len(exprs) <= minExprs { + return false + } exprs = append(exprs[:idx], exprs[idx+1:]...) set(exprs) removed = true diff --git a/go/vt/vtgate/simplifier/simplifier_test.go b/go/vt/vtgate/simplifier/simplifier_test.go index 63f43a7febb..569f5adfab0 100644 --- a/go/vt/vtgate/simplifier/simplifier_test.go +++ b/go/vt/vtgate/simplifier/simplifier_test.go @@ -20,8 +20,9 @@ import ( "fmt" "testing" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/log" + + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" "github.com/stretchr/testify/require" @@ -82,26 +83,35 @@ func TestAbortExpressionCursor(t *testing.T) { func TestSimplifyEvalEngineExpr(t *testing.T) { // ast struct for L0 + - // L1 + + - // L2 + + + + - // L3 1 2 3 4 5 6 7 8 + // L1 + + + // L2 + + + + + // L3 1 2 3 4 5 6 + + + // L4 7 8 9 10 + + // L4 + i7, i8, i9, i10 := + sqlparser.NewIntLiteral("7"), + sqlparser.NewIntLiteral("8"), + sqlparser.NewIntLiteral("9"), + sqlparser.NewIntLiteral("10") // L3 - i1, i2, i3, i4, i5, i6, i7, i8 := + i1, i2, i3, i4, i5, i6, p31, p32 := sqlparser.NewIntLiteral("1"), sqlparser.NewIntLiteral("2"), sqlparser.NewIntLiteral("3"), sqlparser.NewIntLiteral("4"), sqlparser.NewIntLiteral("5"), sqlparser.NewIntLiteral("6"), - sqlparser.NewIntLiteral("7"), - sqlparser.NewIntLiteral("8") + plus(i7, i8), + plus(i9, i10) + // L2 p21, p22, p23, p24 := plus(i1, i2), plus(i3, i4), plus(i5, i6), - plus(i7, i8) + plus(p31, p32) // L1 p11, p12 := @@ -126,7 +136,7 @@ func TestSimplifyEvalEngineExpr(t *testing.T) { } return toInt64 >= 8 }) - log.Infof("simplest expr to evaluate to >= 8: [%s], started from: [%s]", sqlparser.String(expr), sqlparser.String(p0)) + log.Errorf("simplest expr to evaluate to >= 8: [%s], started from: [%s]", sqlparser.String(expr), sqlparser.String(p0)) } func plus(a, b sqlparser.Expr) sqlparser.Expr {