Skip to content

Commit

Permalink
Refactor Expression and Statement Simplifier (#13636)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrés Taylor <[email protected]>
  • Loading branch information
arvind-murty and systay authored Aug 9, 2023
1 parent eee50e5 commit 87a4b16
Show file tree
Hide file tree
Showing 5 changed files with 525 additions and 361 deletions.
9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,15 @@ func NewComparisonExpr(operator ComparisonExprOperator, left, right, escape Expr
}
}

// NewCaseExpr returns a pointer to a freshly allocated CaseExpr whose
// Expr, Whens and Else fields are populated from expr, whens and elseExpr
// respectively. The arguments are stored as given; nothing is cloned.
func NewCaseExpr(expr Expr, whens []*When, elseExpr Expr) *CaseExpr {
	caseExpr := new(CaseExpr)
	caseExpr.Expr = expr
	caseExpr.Whens = whens
	caseExpr.Else = elseExpr
	return caseExpr
}

// NewLimit makes a new Limit
func NewLimit(offset, rowCount int) *Limit {
return &Limit{
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/simplifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ import (
// TestSimplifyBuggyQuery should be used whenever we get a planner bug reported.
// It will try to minimize the query to make it easier to understand and work with the bug.
func TestSimplifyBuggyQuery(t *testing.T) {
query := "(select id from unsharded union select id from unsharded_auto) union (select id from user union select name from unsharded)"
query := "select distinct count(distinct a), count(distinct 4) from user left join unsharded on 0 limit 5"
// select 0 from unsharded union select 0 from `user` union select 0 from unsharded
// select 0 from unsharded union (select 0 from `user` union select 0 from unsharded)
vschema := &vschemaWrapper{
v: loadSchema(t, "vschemas/schema.json", true),
version: Gen4,
Expand Down
147 changes: 68 additions & 79 deletions go/vt/vtgate/simplifier/expression_simplifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,104 +21,59 @@ import (
"strconv"

"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/vt/sqlparser"
)

// CheckF is used to see if the given expression exhibits the sought-after issue
type CheckF = func(sqlparser.Expr) bool

func SimplifyExpr(in sqlparser.Expr, test CheckF) (smallestKnown sqlparser.Expr) {
var maxDepth, level int
resetTo := func(e sqlparser.Expr) {
smallestKnown = e
maxDepth = depth(e)
level = 0
func SimplifyExpr(in sqlparser.Expr, test CheckF) sqlparser.Expr {
// since we can't rewrite the top level, wrap the expr in an Exprs object
smallestKnown := sqlparser.Exprs{sqlparser.CloneExpr(in)}

alwaysVisit := func(node, parent sqlparser.SQLNode) bool {
return true
}
resetTo(in)
for level <= maxDepth {
current := sqlparser.CloneExpr(smallestKnown)
nodes, replaceF := getNodesAtLevel(current, level)
replace := func(e sqlparser.Expr, idx int) {
// if we are at the first level, we are replacing the root,
// not rewriting something deep in the tree
if level == 0 {
current = e

up := func(cursor *sqlparser.Cursor) bool {
node := sqlparser.CloneSQLNode(cursor.Node())
s := &shrinker{orig: node}
expr := s.Next()
for expr != nil {
cursor.Replace(expr)

valid := test(smallestKnown[0])
log.Errorf("test: %t: simplified %s to %s, full expr: %s", valid, sqlparser.String(node), sqlparser.String(expr), sqlparser.String(smallestKnown))
if valid {
break // we will still continue trying to simplify other expressions at this level
} else {
// replace `node` in current with the simplified expression
replaceF[idx](e)
// undo the change
cursor.Replace(node)
}
expr = s.Next()
}
simplified := false
for idx, node := range nodes {
// simplify each element and create a new expression with the node replaced by the simplification
// this means that we not only need the node, but also a way to replace the node
s := &shrinker{orig: node}
expr := s.Next()
for expr != nil {
replace(expr, idx)

valid := test(current)
log.Errorf("test: %t - %s", valid, sqlparser.String(current))
if valid {
simplified = true
break // we will still continue trying to simplify other expressions at this level
} else {
// undo the change
replace(node, idx)
}
expr = s.Next()
}
}
if simplified {
resetTo(current)
} else {
level++
}
}
return smallestKnown
}

func getNodesAtLevel(e sqlparser.Expr, level int) (result []sqlparser.Expr, replaceF []func(node sqlparser.SQLNode)) {
lvl := 0
pre := func(cursor *sqlparser.Cursor) bool {
if expr, isExpr := cursor.Node().(sqlparser.Expr); level == lvl && isExpr {
result = append(result, expr)
replaceF = append(replaceF, cursor.ReplacerF())
}
lvl++
return true
}
post := func(cursor *sqlparser.Cursor) bool {
lvl--
return true
}
sqlparser.Rewrite(e, pre, post)
return
}

func depth(e sqlparser.Expr) (depth int) {
lvl := 0
pre := func(cursor *sqlparser.Cursor) bool {
lvl++
if lvl > depth {
depth = lvl
// loop until rewriting introduces no more changes
for {
prevSmallest := sqlparser.CloneExprs(smallestKnown)
sqlparser.SafeRewrite(smallestKnown, alwaysVisit, up)
if sqlparser.Equals.Exprs(prevSmallest, smallestKnown) {
break
}
return true
}
post := func(cursor *sqlparser.Cursor) bool {
lvl--
return true
}
sqlparser.Rewrite(e, pre, post)
return

return smallestKnown[0]
}

type shrinker struct {
orig sqlparser.Expr
queue []sqlparser.Expr
orig sqlparser.SQLNode
queue []sqlparser.SQLNode
}

func (s *shrinker) Next() sqlparser.Expr {
func (s *shrinker) Next() sqlparser.SQLNode {
for {
// first we check if there is already something in the queue.
// note that we are doing a nil check and not a length check here.
Expand All @@ -142,6 +97,10 @@ func (s *shrinker) Next() sqlparser.Expr {
func (s *shrinker) fillQueue() bool {
before := len(s.queue)
switch e := s.orig.(type) {
case *sqlparser.AndExpr:
s.queue = append(s.queue, e.Left, e.Right)
case *sqlparser.OrExpr:
s.queue = append(s.queue, e.Left, e.Right)
case *sqlparser.ComparisonExpr:
s.queue = append(s.queue, e.Left, e.Right)
case *sqlparser.BinaryExpr:
Expand Down Expand Up @@ -228,9 +187,39 @@ func (s *shrinker) fillQueue() bool {
for _, ae := range e.GetArgs() {
s.queue = append(s.queue, ae)
}

clone := sqlparser.CloneAggrFunc(e)
if da, ok := clone.(sqlparser.DistinctableAggr); ok {
if da.IsDistinct() {
da.SetDistinct(false)
s.queue = append(s.queue, clone)
}
}
case *sqlparser.ColName:
// we can try to replace the column with a literal value
s.queue = []sqlparser.Expr{sqlparser.NewIntLiteral("0")}
s.queue = append(s.queue, sqlparser.NewIntLiteral("0"))
case *sqlparser.CaseExpr:
s.queue = append(s.queue, e.Expr, e.Else)
for _, when := range e.Whens {
s.queue = append(s.queue, when.Cond, when.Val)
}

if len(e.Whens) > 1 {
for i := range e.Whens {
whensCopy := sqlparser.CloneSliceOfRefOfWhen(e.Whens)
// replace ith element with last element, then truncate last element
whensCopy[i] = whensCopy[len(whensCopy)-1]
whensCopy = whensCopy[:len(whensCopy)-1]
s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, whensCopy, e.Else))
}
}

if e.Else != nil {
s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, e.Whens, nil))
}
if e.Expr != nil {
s.queue = append(s.queue, sqlparser.NewCaseExpr(nil, e.Whens, e.Else))
}
default:
return false
}
Expand Down
Loading

0 comments on commit 87a4b16

Please sign in to comment.