diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index 98b7677a1f3..33ad99a9f26 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -137,7 +137,7 @@ func ASTToStatementType(stmt Statement) StatementType { // CanNormalize takes Statement and returns if the statement can be normalized. func CanNormalize(stmt Statement) bool { switch stmt.(type) { - case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream, *VExplainStmt: // TODO: we could merge this logic into ASTrewriter + case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream, *VExplainStmt, *Show: // TODO: we could merge this logic into ASTrewriter return true } return false diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index 05e7e290fc1..b2e164bcf0f 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -54,41 +54,16 @@ func PrepareAST( fkChecksState *bool, views VSchemaViews, ) (*RewriteASTResult, error) { - if parameterize { - err := Normalize(in, reservedVars, bindVars) - if err != nil { - return nil, err - } - } - return RewriteAST(in, keyspace, selectLimit, setVarComment, sysVars, fkChecksState, views) -} - -// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries. -// SET_VAR comments are also added to the AST if required. -func RewriteAST( - in Statement, - keyspace string, - selectLimit int, - setVarComment string, - sysVars map[string]string, - fkChecksState *bool, - views VSchemaViews, -) (*RewriteASTResult, error) { - er := newASTRewriter(keyspace, selectLimit, setVarComment, sysVars, fkChecksState, views) - er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) - result := SafeRewrite(in, er.rewriteDown, er.rewriteUp) - if er.err != nil { - return nil, er.err - } - - out, ok := result.(Statement) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out)) + nz := newNormalizer(reservedVars, bindVars, keyspace, selectLimit, setVarComment, sysVars, fkChecksState, views, parameterize) + nz.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) + out := SafeRewrite(in, nz.walkDown, nz.walkUp) + if nz.err != nil { + return nil, nz.err } r := &RewriteASTResult{ - AST: out, - BindVarNeeds: er.bindVars, + AST: out.(Statement), + BindVarNeeds: nz.bindVarNeeds, } return r, nil } @@ -112,34 +87,6 @@ func shouldRewriteDatabaseFunc(in Statement) bool { return tableName.Name.String() == "dual" } -type astRewriter struct { - bindVars *BindVarNeeds - shouldRewriteDatabaseFunc bool - err error - - // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON - hasStarInSelect bool - - keyspace string - selectLimit int - setVarComment string - fkChecksState *bool - sysVars map[string]string - views VSchemaViews -} - -func newASTRewriter(keyspace string, selectLimit int, setVarComment string, sysVars map[string]string, fkChecksState *bool, views VSchemaViews) *astRewriter { - return &astRewriter{ - bindVars: &BindVarNeeds{}, - keyspace: keyspace, - selectLimit: selectLimit, - setVarComment: setVarComment, - fkChecksState: fkChecksState, - sysVars: sysVars, - views: views, - } -} - const ( // LastInsertIDName is a reserved bind var name for last_insert_id() LastInsertIDName = "__lastInsertId" @@ -157,76 +104,84 @@ const ( UserDefinedVariableName = "__vtudv" ) -func (er *astRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { - inner := newASTRewriter(er.keyspace, er.selectLimit, er.setVarComment, er.sysVars, nil, er.views) - inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := SafeRewrite(node.Expr, inner.rewriteDown, inner.rewriteUp) - newExpr, ok := tmp.(Expr) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) - } - node.Expr = newExpr - return inner.bindVars, nil -} - -func (er *astRewriter) rewriteDown(node SQLNode, _ SQLNode) bool { +func (nz *normalizer) rewriteDown(node SQLNode, _ SQLNode) bool { switch node := node.(type) { + case *StarExpr: + nz.hasStarInSelect = true + case *AliasedExpr: + if node.As.NotEmpty() { + break + } + buf := NewTrackedBuffer(nil) + node.Expr.Format(buf) + rewrites := nz.bindVarNeeds.NumberOfRewrites() + nz.onLeave[node] = func(newAliasedExpr *AliasedExpr) { + if nz.bindVarNeeds.NumberOfRewrites() > rewrites { + newAliasedExpr.As = NewIdentifierCI(buf.String()) + } + } case *Select: - er.visitSelect(node) + if nz.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(nz.selectLimit))} + } case *PrepareStmt, *ExecuteStmt: return false // nothing to rewrite here. } return true } -func (er *astRewriter) rewriteUp(cursor *Cursor) bool { +func (nz *normalizer) rewriteUp(cursor *Cursor) bool { // Add SET_VAR comment to this node if it supports it and is needed if supportOptimizerHint, supportsOptimizerHint := cursor.Node().(SupportOptimizerHint); supportsOptimizerHint { - if er.setVarComment != "" { - newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(er.setVarComment) + if nz.setVarComment != "" { + newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(nz.setVarComment) if err != nil { - er.err = err + nz.err = err return false } supportOptimizerHint.SetComments(newComments) } - if er.fkChecksState != nil { - newComments := supportOptimizerHint.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, FkChecksStateString(er.fkChecksState)) + if nz.fkChecksState != nil { + newComments := supportOptimizerHint.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, FkChecksStateString(nz.fkChecksState)) supportOptimizerHint.SetComments(newComments) } } switch node := cursor.Node().(type) { + case *AliasedExpr: + if onLeave, ok := nz.onLeave[node]; ok { + onLeave(node) + } case *Union: - er.rewriteUnion(node) + nz.rewriteUnion(node) case *FuncExpr: - er.funcRewrite(cursor, node) + nz.funcRewrite(cursor, node) case *Variable: - er.rewriteVariable(cursor, node) + nz.rewriteVariable(cursor, node) case *Subquery: - er.unnestSubQueries(cursor, node) + nz.unnestSubQueries(cursor, node) case *NotExpr: - er.rewriteNotExpr(cursor, node) + nz.rewriteNotExpr(cursor, node) case *AliasedTableExpr: - er.rewriteAliasedTable(cursor, node) + nz.rewriteAliasedTable(cursor, node) case *ShowBasic: - er.rewriteShowBasic(node) + nz.rewriteShowBasic(node) case *ExistsExpr: - er.existsRewrite(cursor, node) + nz.existsRewrite(cursor, node) case DistinctableAggr: - er.rewriteDistinctableAggr(cursor, node) + nz.rewriteDistinctableAggr(cursor, node) } return true } -func (er *astRewriter) rewriteUnion(node *Union) { +func (nz *normalizer) rewriteUnion(node *Union) { // set select limit if explicitly not set when sql_select_limit is set on the connection. - if er.selectLimit > 0 && node.Limit == nil { - node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} + if nz.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(nz.selectLimit))} } } -func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExpr) { +func (nz *normalizer) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExpr) { aliasTableName, ok := node.Expr.(TableName) if !ok { return @@ -238,9 +193,9 @@ func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExp return } - if SystemSchema(er.keyspace) { + if SystemSchema(nz.keyspace) { if aliasTableName.Qualifier.IsEmpty() { - aliasTableName.Qualifier = NewIdentifierCS(er.keyspace) + aliasTableName.Qualifier = NewIdentifierCS(nz.keyspace) node.Expr = aliasTableName cursor.Replace(node) } @@ -248,10 +203,10 @@ func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExp } // Could we be dealing with a view? - if er.views == nil { + if nz.views == nil { return } - view := er.views.FindView(aliasTableName) + view := nz.views.FindView(aliasTableName) if view == nil { return } @@ -263,16 +218,16 @@ func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExp } } -func (er *astRewriter) rewriteShowBasic(node *ShowBasic) { +func (nz *normalizer) rewriteShowBasic(node *ShowBasic) { if node.Command == VariableGlobal || node.Command == VariableSession { varsToAdd := sysvars.GetInterestingVariables() for _, sysVar := range varsToAdd { - er.bindVars.AddSysVar(sysVar) + nz.bindVarNeeds.AddSysVar(sysVar) } } } -func (er *astRewriter) rewriteNotExpr(cursor *Cursor, node *NotExpr) { +func (nz *normalizer) rewriteNotExpr(cursor *Cursor, node *NotExpr) { switch inner := node.Expr.(type) { case *ComparisonExpr: // not col = 42 => col != 42 @@ -293,7 +248,7 @@ func (er *astRewriter) rewriteNotExpr(cursor *Cursor, node *NotExpr) { } } -func (er *astRewriter) rewriteVariable(cursor *Cursor, node *Variable) { +func (nz *normalizer) rewriteVariable(cursor *Cursor, node *Variable) { // Iff we are in SET, we want to change the scope of variables if a modifier has been set // and only on the lhs of the assignment: // set session sql_mode = @someElse @@ -305,40 +260,9 @@ func (er *astRewriter) rewriteVariable(cursor *Cursor, node *Variable) { // this should be returned from the underlying database. switch node.Scope { case VariableScope: - er.udvRewrite(cursor, node) + nz.udvRewrite(cursor, node) case SessionScope, NextTxScope: - er.sysVarRewrite(cursor, node) - } -} - -func (er *astRewriter) visitSelect(node *Select) { - for _, col := range node.SelectExprs { - if _, hasStar := col.(*StarExpr); hasStar { - er.hasStarInSelect = true - continue - } - - aliasedExpr, ok := col.(*AliasedExpr) - if !ok || aliasedExpr.As.NotEmpty() { - continue - } - buf := NewTrackedBuffer(nil) - aliasedExpr.Expr.Format(buf) - // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` - innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) - if err != nil { - er.err = err - return - } - if innerBindVarNeeds.HasRewrites() { - aliasedExpr.As = NewIdentifierCI(buf.String()) - } - er.bindVars.MergeWith(innerBindVarNeeds) - - } - // set select limit if explicitly not set when sql_select_limit is set on the connection. - if er.selectLimit > 0 && node.Limit == nil { - node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} + nz.sysVarRewrite(cursor, node) } } @@ -373,12 +297,12 @@ func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) { return false, i } -func (er *astRewriter) sysVarRewrite(cursor *Cursor, node *Variable) { +func (nz *normalizer) sysVarRewrite(cursor *Cursor, node *Variable) { lowered := node.Name.Lowered() var found bool - if er.sysVars != nil { - _, found = er.sysVars[lowered] + if nz.sysVars != nil { + _, found = nz.sysVars[lowered] } switch lowered { @@ -406,14 +330,14 @@ func (er *astRewriter) sysVarRewrite(cursor *Cursor, node *Variable) { if found { cursor.Replace(bindVarExpression("__vt" + lowered)) - er.bindVars.AddSysVar(lowered) + nz.bindVarNeeds.AddSysVar(lowered) } } -func (er *astRewriter) udvRewrite(cursor *Cursor, node *Variable) { +func (nz *normalizer) udvRewrite(cursor *Cursor, node *Variable) { udv := strings.ToLower(node.Name.CompliantName()) cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) - er.bindVars.AddUserDefVar(udv) + nz.bindVarNeeds.AddUserDefVar(udv) } var funcRewrites = map[string]string{ @@ -424,7 +348,7 @@ var funcRewrites = map[string]string{ "row_count": RowCountName, } -func (er *astRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { +func (nz *normalizer) funcRewrite(cursor *Cursor, node *FuncExpr) { lowered := node.Name.Lowered() if lowered == "last_insert_id" && len(node.Exprs) > 0 { // if we are dealing with is LAST_INSERT_ID() with an argument, we don't need to rewrite it. @@ -433,18 +357,18 @@ func (er *astRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { return } bindVar, found := funcRewrites[lowered] - if !found || (bindVar == DBVarName && !er.shouldRewriteDatabaseFunc) { + if !found || (bindVar == DBVarName && !nz.shouldRewriteDatabaseFunc) { return } if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", lowered) + nz.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", lowered) return } cursor.Replace(bindVarExpression(bindVar)) - er.bindVars.AddFuncResult(bindVar) + nz.bindVarNeeds.AddFuncResult(bindVar) } -func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { +func (nz *normalizer) unnestSubQueries(cursor *Cursor, subquery *Subquery) { if _, isExists := cursor.Parent().(*ExistsExpr); isExists { return } @@ -483,10 +407,10 @@ func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { // is perfectly valid - any aliased columns to the left are available inside subquery scopes return } - er.bindVars.NoteRewrite() + nz.bindVarNeeds.NoteRewrite() // we need to make sure that the inner expression also gets rewritten, // so we fire off another rewriter traversal here - rewritten := SafeRewrite(expr.Expr, er.rewriteDown, er.rewriteUp) + rewritten := SafeRewrite(expr.Expr, nz.walkDown, nz.walkUp) // Here we need to handle the subquery rewrite in case in occurs in an IN clause // For example, SELECT id FROM user WHERE id IN (SELECT 1 FROM DUAL) @@ -507,7 +431,7 @@ func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { cursor.Replace(rewritten) } -func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) { +func (nz *normalizer) existsRewrite(cursor *Cursor, node *ExistsExpr) { sel, ok := node.Subquery.Select.(*Select) if !ok { return @@ -533,14 +457,14 @@ func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) { } // rewriteDistinctableAggr removed Distinct from Max and Min Aggregations as it does not impact the result. But, makes the plan simpler. -func (er *astRewriter) rewriteDistinctableAggr(cursor *Cursor, node DistinctableAggr) { +func (nz *normalizer) rewriteDistinctableAggr(cursor *Cursor, node DistinctableAggr) { if !node.IsDistinct() { return } switch aggr := node.(type) { case *Max, *Min: aggr.SetDistinct(false) - er.bindVars.NoteRewrite() + nz.bindVarNeeds.NoteRewrite() } } diff --git a/go/vt/sqlparser/ast_rewriting_test.go b/go/vt/sqlparser/ast_rewriting_test.go index 8b3e3d44c54..53c2ccaba4a 100644 --- a/go/vt/sqlparser/ast_rewriting_test.go +++ b/go/vt/sqlparser/ast_rewriting_test.go @@ -21,10 +21,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sysvars" - - "github.com/stretchr/testify/require" ) type testCaseSetVar struct { @@ -339,15 +339,18 @@ func TestRewrites(in *testing.T) { for _, tc := range tests { in.Run(tc.in, func(t *testing.T) { require := require.New(t) - stmt, err := parser.Parse(tc.in) + stmt, known, err := parser.Parse2(tc.in) require.NoError(err) - - result, err := RewriteAST( + vars := NewReservedVars("v", known) + result, err := PrepareAST( stmt, - "ks", // passing `ks` just to test that no rewriting happens as it is not system schema - SQLSelectLimitUnset, + vars, + map[string]*querypb.BindVariable{}, + false, + "ks", + 0, "", - nil, + map[string]string{}, nil, &fakeViews{}, ) @@ -441,8 +444,20 @@ func TestRewritesWithSetVarComment(in *testing.T) { require := require.New(t) stmt, err := parser.Parse(tc.in) require.NoError(err) + vars := NewReservedVars("v", nil) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "ks", + 0, + tc.setVarComment, + map[string]string{}, + nil, + &fakeViews{}, + ) - result, err := RewriteAST(stmt, "ks", SQLSelectLimitUnset, tc.setVarComment, nil, nil, &fakeViews{}) require.NoError(err) expected, err := parser.Parse(tc.expected) @@ -490,8 +505,20 @@ func TestRewritesSysVar(in *testing.T) { require := require.New(t) stmt, err := parser.Parse(tc.in) require.NoError(err) + vars := NewReservedVars("v", nil) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "ks", + 0, + "", + tc.sysVar, + nil, + &fakeViews{}, + ) - result, err := RewriteAST(stmt, "ks", SQLSelectLimitUnset, "", tc.sysVar, nil, &fakeViews{}) require.NoError(err) expected, err := parser.Parse(tc.expected) @@ -541,8 +568,20 @@ func TestRewritesWithDefaultKeyspace(in *testing.T) { require := require.New(t) stmt, err := parser.Parse(tc.in) require.NoError(err) + vars := NewReservedVars("v", nil) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "sys", + 0, + "", + map[string]string{}, + nil, + &fakeViews{}, + ) - result, err := RewriteAST(stmt, "sys", SQLSelectLimitUnset, "", nil, nil, &fakeViews{}) require.NoError(err) expected, err := parser.Parse(tc.expected) diff --git a/go/vt/sqlparser/bind_var_needs.go b/go/vt/sqlparser/bind_var_needs.go index 1b26919ca03..64e5c528e97 100644 --- a/go/vt/sqlparser/bind_var_needs.go +++ b/go/vt/sqlparser/bind_var_needs.go @@ -22,14 +22,7 @@ type BindVarNeeds struct { NeedSystemVariable, // NeedUserDefinedVariables keeps track of all user defined variables a query is using NeedUserDefinedVariables []string - otherRewrites bool -} - -// MergeWith adds bind vars needs coming from sub scopes -func (bvn *BindVarNeeds) MergeWith(other *BindVarNeeds) { - bvn.NeedFunctionResult = append(bvn.NeedFunctionResult, other.NeedFunctionResult...) - bvn.NeedSystemVariable = append(bvn.NeedSystemVariable, other.NeedSystemVariable...) - bvn.NeedUserDefinedVariables = append(bvn.NeedUserDefinedVariables, other.NeedUserDefinedVariables...) + otherRewrites int } // AddFuncResult adds a function bindvar need @@ -58,14 +51,11 @@ func (bvn *BindVarNeeds) NeedsSysVar(name string) bool { } func (bvn *BindVarNeeds) NoteRewrite() { - bvn.otherRewrites = true + bvn.otherRewrites++ } -func (bvn *BindVarNeeds) HasRewrites() bool { - return bvn.otherRewrites || - len(bvn.NeedFunctionResult) > 0 || - len(bvn.NeedUserDefinedVariables) > 0 || - len(bvn.NeedSystemVariable) > 0 +func (bvn *BindVarNeeds) NumberOfRewrites() int { + return len(bvn.NeedFunctionResult) + len(bvn.NeedUserDefinedVariables) + len(bvn.NeedSystemVariable) + bvn.otherRewrites } func contains(strings []string, name string) bool { diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index 02cb11e2a97..91603dff6ad 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -18,13 +18,14 @@ package sqlparser import ( "bytes" + "strconv" "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sysvars" "vitess.io/vitess/go/vt/vterrors" - - querypb "vitess.io/vitess/go/vt/proto/query" ) // BindVars is a set of reserved bind variables from a SQL statement @@ -37,11 +38,6 @@ type BindVars map[string]struct{} // Within Select constructs, bind vars are deduped. This allows // us to identify vindex equality. Otherwise, every value is // treated as distinct. -func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error { - nz := newNormalizer(reserved, bindVars) - _ = SafeRewrite(stmt, nz.walkStatementDown, nz.walkStatementUp) - return nz.err -} type normalizer struct { bindVars map[string]*querypb.BindVariable @@ -50,56 +46,80 @@ type normalizer struct { err error inDerived int inSelect int -} -func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) *normalizer { - return &normalizer{ - bindVars: bindVars, - reserved: reserved, - vals: make(map[Literal]string), - } + // old astRewriter fielda + bindVarNeeds *BindVarNeeds + shouldRewriteDatabaseFunc bool + + // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON + hasStarInSelect bool + + keyspace string + selectLimit int + setVarComment string + fkChecksState *bool + sysVars map[string]string + views VSchemaViews + + onLeave map[*AliasedExpr]func(*AliasedExpr) + shouldParameterize bool } -// walkStatementUp is one half of the top level walk function. -func (nz *normalizer) walkStatementUp(cursor *Cursor) bool { - if nz.err != nil { - return false - } - switch node := cursor.node.(type) { - case *DerivedTable: - nz.inDerived-- - case *Select: - nz.inSelect-- - case *Literal: - if nz.inSelect == 0 { - nz.convertLiteral(node, cursor) - return nz.err == nil - } - parent := cursor.Parent() - switch parent.(type) { - case *Order, *GroupBy: - return true - case *Limit: - nz.convertLiteral(node, cursor) - default: - nz.convertLiteralDedup(node, cursor) - } +func newNormalizer( + reserved *ReservedVars, + bindVars map[string]*querypb.BindVariable, + keyspace string, + selectLimit int, + setVarComment string, + sysVars map[string]string, + fkChecksState *bool, + views VSchemaViews, + parameterize bool, +) *normalizer { + return &normalizer{ + bindVars: bindVars, + reserved: reserved, + vals: make(map[Literal]string), + bindVarNeeds: &BindVarNeeds{}, + keyspace: keyspace, + selectLimit: selectLimit, + setVarComment: setVarComment, + fkChecksState: fkChecksState, + sysVars: sysVars, + views: views, + onLeave: make(map[*AliasedExpr]func(*AliasedExpr)), + shouldParameterize: parameterize, } - return nz.err == nil // only continue if we haven't found any errors } -// walkStatementDown is the top level walk function. +// walkDown is the top level walk function. // If it encounters a Select, it switches to a mode // where variables are deduped. -func (nz *normalizer) walkStatementDown(node, _ SQLNode) bool { +func (nz *normalizer) walkDown(node, _ SQLNode) bool { switch node := node.(type) { - // no need to normalize the statement types - case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, DDLStatement, *SRollback, *Release, *OtherAdmin, *Analyze: - return false + // no need to normalize the statement types TODO? + case *Begin, *Commit, *Rollback, *Savepoint, *SRollback, *Release, *OtherAdmin, *Analyze, *AssignmentExpr: + return false // TODO: we should normalize these statements case *DerivedTable: nz.inDerived++ case *Select: nz.inSelect++ + if nz.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(nz.selectLimit))} + } + case *AliasedExpr: + // we store the expression string, so we can rewrite the alias if the expression changes + if node.As.NotEmpty() { + break + } + buf := NewTrackedBuffer(nil) + node.Expr.Format(buf) + rewrites := nz.bindVarNeeds.NumberOfRewrites() + nz.onLeave[node] = func(newAliasedExpr *AliasedExpr) { + if nz.bindVarNeeds.NumberOfRewrites() > rewrites { + newAliasedExpr.As = NewIdentifierCI(buf.String()) + } + } case SelectExprs: return nz.inDerived == 0 case *ComparisonExpr: @@ -115,6 +135,87 @@ func (nz *normalizer) walkStatementDown(node, _ SQLNode) bool { case *FramePoint: // do not make a bind var for rows and range return false + case *StarExpr: + nz.hasStarInSelect = true + case *PrepareStmt, *ExecuteStmt: // TODO: is this OK to do for the normalizer? + return false // nothing to rewrite here. + case *ShowBasic: + if node.Command == VariableGlobal || node.Command == VariableSession { + varsToAdd := sysvars.GetInterestingVariables() + for _, sysVar := range varsToAdd { + nz.bindVarNeeds.AddSysVar(sysVar) + } + } + } + return nz.err == nil // only continue if we haven't found any errors +} + +// walkUp is one half of the top level walk function. +func (nz *normalizer) walkUp(cursor *Cursor) bool { + + // Add SET_VAR comment to this node if it supports it and is needed + if supportOptimizerHint, supportsOptimizerHint := cursor.Node().(SupportOptimizerHint); supportsOptimizerHint { + if nz.setVarComment != "" { + newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(nz.setVarComment) + if err != nil { + nz.err = err + return false + } + supportOptimizerHint.SetComments(newComments) + } + if nz.fkChecksState != nil { + newComments := supportOptimizerHint.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, FkChecksStateString(nz.fkChecksState)) + supportOptimizerHint.SetComments(newComments) + } + } + + if nz.err != nil { + return false + } + switch node := cursor.node.(type) { + case *DerivedTable: + nz.inDerived-- + case *Select: + nz.inSelect-- + case *AliasedExpr: + if onLeave, ok := nz.onLeave[node]; ok { + onLeave(node) + } + case *Union: + nz.rewriteUnion(node) + case *FuncExpr: + nz.funcRewrite(cursor, node) + case *Variable: + nz.rewriteVariable(cursor, node) + case *Subquery: + nz.unnestSubQueries(cursor, node) + case *NotExpr: + nz.rewriteNotExpr(cursor, node) + case *AliasedTableExpr: + nz.rewriteAliasedTable(cursor, node) + case *ShowBasic: + nz.rewriteShowBasic(node) + case *ExistsExpr: + nz.existsRewrite(cursor, node) + case DistinctableAggr: + nz.rewriteDistinctableAggr(cursor, node) + case *Literal: + if !nz.shouldParameterize { + break + } + if nz.inSelect == 0 { + nz.convertLiteral(node, cursor) + return nz.err == nil + } + parent := cursor.Parent() + switch parent.(type) { + case *Order, *GroupBy: + return true + case *Limit: + nz.convertLiteral(node, cursor) + default: + nz.convertLiteralDedup(node, cursor) + } } return nz.err == nil // only continue if we haven't found any errors } @@ -220,6 +321,9 @@ func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) { } func (nz *normalizer) parameterize(left, right Expr) Expr { + if !nz.shouldParameterize { + return nil + } col, ok := left.(*ColName) if !ok { return nil @@ -267,6 +371,9 @@ func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *queryp } func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) { + if !nz.shouldParameterize { + return + } tupleVals, ok := node.Right.(ValTuple) if !ok { return diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 7919a321c91..787341c3241 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -26,7 +26,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -450,8 +449,9 @@ func TestNormalize(t *testing.T) { require.NoError(t, err) known := GetBindvars(stmt) bv := make(map[string]*querypb.BindVariable) - require.NoError(t, Normalize(stmt, NewReservedVars(prefix, known), bv)) - assert.Equal(t, tc.outstmt, String(stmt)) + out, err := PrepareAST(stmt, NewReservedVars(prefix, known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) + require.NoError(t, err) + assert.Equal(t, tc.outstmt, String(out.AST)) assert.Equal(t, tc.outbv, bv) }) } @@ -478,7 +478,8 @@ func TestNormalizeInvalidDates(t *testing.T) { require.NoError(t, err) known := GetBindvars(stmt) bv := make(map[string]*querypb.BindVariable) - require.EqualError(t, Normalize(stmt, NewReservedVars("bv", known), bv), tc.err.Error()) + _, err = PrepareAST(stmt, NewReservedVars("bv", known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) + require.EqualError(t, err, tc.err.Error()) }) } } @@ -498,9 +499,10 @@ func TestNormalizeValidSQL(t *testing.T) { } bv := make(map[string]*querypb.BindVariable) known := make(BindVars) - err = Normalize(tree, NewReservedVars("vtg", known), bv) + + out, err := PrepareAST(tree, NewReservedVars("vtg", known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) require.NoError(t, err) - normalizerOutput := String(tree) + normalizerOutput := String(out.AST) if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" { return } @@ -529,9 +531,9 @@ func TestNormalizeOneCasae(t *testing.T) { } bv := make(map[string]*querypb.BindVariable) known := make(BindVars) - err = Normalize(tree, NewReservedVars("vtg", known), bv) + out, err := PrepareAST(tree, NewReservedVars("vtg", known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) require.NoError(t, err) - normalizerOutput := String(tree) + normalizerOutput := String(out.AST) require.EqualValues(t, testOne.output, normalizerOutput) if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" { return @@ -573,7 +575,8 @@ func BenchmarkNormalize(b *testing.B) { b.Fatal(err) } for i := 0; i < b.N; i++ { - require.NoError(b, Normalize(ast, NewReservedVars("", reservedVars), map[string]*querypb.BindVariable{})) + _, err := PrepareAST(ast, NewReservedVars("", reservedVars), map[string]*querypb.BindVariable{}, true, "ks", 0, "", map[string]string{}, nil, nil) + require.NoError(b, err) } } @@ -602,7 +605,8 @@ func BenchmarkNormalizeTraces(b *testing.B) { for i := 0; i < b.N; i++ { for i, query := range parsed { - _ = Normalize(query, NewReservedVars("", reservedVars[i]), map[string]*querypb.BindVariable{}) + _, err := PrepareAST(query, NewReservedVars("", reservedVars[i]), map[string]*querypb.BindVariable{}, true, "ks", 0, "", map[string]string{}, nil, nil) + require.NoError(b, err) } } }) diff --git a/go/vt/sqlparser/redact_query.go b/go/vt/sqlparser/redact_query.go index e6b8c009c68..2d018d7c0eb 100644 --- a/go/vt/sqlparser/redact_query.go +++ b/go/vt/sqlparser/redact_query.go @@ -28,10 +28,10 @@ func (p *Parser) RedactSQLQuery(sql string) (string, error) { return "", err } - err = Normalize(stmt, NewReservedVars("redacted", reservedVars), bv) + out, err := PrepareAST(stmt, NewReservedVars("redacted", reservedVars), bv, true, "ks", 0, "", map[string]string{}, nil, nil) if err != nil { return "", err } - return comments.Leading + String(stmt) + comments.Trailing, nil + return comments.Leading + String(out.AST) + comments.Trailing, nil } diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index b785128917f..c56e7740fc5 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -41,11 +41,12 @@ func (p *Parser) QueryMatchesTemplates(query string, queryTemplates []string) (m if err != nil { return "", err } - err = Normalize(stmt, NewReservedVars("", reservedVars), bv) + + out, err := PrepareAST(stmt, NewReservedVars("", reservedVars), bv, true, "ks", 0, "", map[string]string{}, nil, nil) if err != nil { return "", err } - normalized := CanonicalString(stmt) + normalized := CanonicalString(out.AST) return normalized, nil } diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 065c50a6dfa..85d9f5f94ea 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -73,7 +73,7 @@ func (staticConfig) DirectEnabled() bool { // TestBuilder builds a plan for a query based on the specified vschema. // This method is only used from tests func TestBuilder(query string, vschema plancontext.VSchema, keyspace string) (*engine.Plan, error) { - stmt, reserved, err := vschema.Environment().Parser().Parse2(query) + stmt, known, err := vschema.Environment().Parser().Parse2(query) if err != nil { return nil, err } @@ -93,12 +93,12 @@ func TestBuilder(query string, vschema plancontext.VSchema, keyspace string) (*e }() } } - result, err := sqlparser.RewriteAST(stmt, keyspace, sqlparser.SQLSelectLimitUnset, "", nil, vschema.GetForeignKeyChecksState(), vschema) + reservedVars := sqlparser.NewReservedVars("vtg", known) + result, err := sqlparser.PrepareAST(stmt, reservedVars, map[string]*querypb.BindVariable{}, false, keyspace, sqlparser.SQLSelectLimitUnset, "", nil, vschema.GetForeignKeyChecksState(), vschema) if err != nil { return nil, err } - reservedVars := sqlparser.NewReservedVars("vtg", reserved) return BuildFromStmt(context.Background(), query, result.AST, reservedVars, vschema, result.BindVarNeeds, staticConfig{}) } diff --git a/go/vt/vtgate/planbuilder/simplifier_test.go b/go/vt/vtgate/planbuilder/simplifier_test.go index 5aeb0565f9b..012475ba021 100644 --- a/go/vt/vtgate/planbuilder/simplifier_test.go +++ b/go/vt/vtgate/planbuilder/simplifier_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + querypb "vitess.io/vitess/go/vt/proto/query" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,8 +47,8 @@ func TestSimplifyBuggyQuery(t *testing.T) { stmt, reserved, err := sqlparser.NewTestParser().Parse2(query) require.NoError(t, err) - rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) reservedVars := sqlparser.NewReservedVars("vtg", reserved) + rewritten, _ := sqlparser.PrepareAST(sqlparser.Clone(stmt), reservedVars, map[string]*querypb.BindVariable{}, false, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.TableStatement), @@ -69,8 +71,8 @@ func TestSimplifyPanic(t *testing.T) { stmt, reserved, err := sqlparser.NewTestParser().Parse2(query) require.NoError(t, err) - rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) reservedVars := sqlparser.NewReservedVars("vtg", reserved) + rewritten, _ := sqlparser.PrepareAST(sqlparser.Clone(stmt), reservedVars, map[string]*querypb.BindVariable{}, false, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.TableStatement), @@ -100,12 +102,12 @@ func TestUnsupportedFile(t *testing.T) { t.Skip() return } - rewritten, err := sqlparser.RewriteAST(stmt, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + reservedVars := sqlparser.NewReservedVars("vtg", reserved) + rewritten, err := sqlparser.PrepareAST(stmt, reservedVars, map[string]*querypb.BindVariable{}, false, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) if err != nil { t.Skip() } - reservedVars := sqlparser.NewReservedVars("vtg", reserved) ast := rewritten.AST origQuery := sqlparser.String(ast) stmt, _, _ = sqlparser.NewTestParser().Parse2(tcase.Query) @@ -133,7 +135,7 @@ func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema * if err != nil { panic(err) } - rewritten, _ := sqlparser.RewriteAST(stmt, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + rewritten, _ := sqlparser.PrepareAST(stmt, reservedVars, map[string]*querypb.BindVariable{}, false, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) ast := rewritten.AST _, expected := BuildFromStmt(context.Background(), query, ast, reservedVars, vschema, rewritten.BindVarNeeds, staticConfig{}) if expected == nil {