From c11184addb8d0ce24798eae71f6af1885379d163 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 19 Mar 2021 16:29:48 +0530 Subject: [PATCH 01/15] rewriter gen interface and struct methods implementation Signed-off-by: Harshit Gangal --- go/tools/asthelpergen/asthelpergen.go | 4 + .../asthelpergen/integration/ast_helper.go | 287 ++++++++++++++++ go/tools/asthelpergen/rewrite_gen.go | 312 ++++++++++++++++++ go/vt/sqlparser/rewriter_api.go | 3 + 4 files changed, 606 insertions(+) create mode 100644 go/tools/asthelpergen/rewrite_gen.go diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index 9b1768a660a..014675ea244 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -310,6 +310,7 @@ func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType str generator.gens2 = append(generator.gens2, &equalsGen{}) generator.gens2 = append(generator.gens2, newCloneGen(exceptCloneType)) generator.gens2 = append(generator.gens2, &visitGen{}) + generator.gens2 = append(generator.gens2, &rewriteGen{}) it, err := generator.GenerateCode() if err != nil { @@ -335,6 +336,7 @@ const ( clone methodType = iota equals visit + rewrite ) func (gen *astHelperGen) addFunc(name string, typ methodType, code jen.Code) { @@ -346,6 +348,8 @@ func (gen *astHelperGen) addFunc(name string, typ methodType, code jen.Code) { comment = " does deep equals between the two objects." case visit: comment = " will visit all parts of the AST" + case rewrite: + comment = " is part of the Rewrite implementation" } gen.methods = append(gen.methods, jen.Comment(name+comment), code) } diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index c65490df329..fd4f1179861 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -176,6 +176,42 @@ func VisitAST(in AST, f Visit) error { } } +// rewriteAST is part of the Rewrite implementation +func rewriteAST(parent AST, node AST, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + switch node := node.(type) { + case BasicType: + return rewriteBasicType(parent, node, replacer, pre, post) + case Bytes: + return rewriteBytes(parent, node, replacer, pre, post) + case InterfaceContainer: + return rewriteInterfaceContainer(parent, node, replacer, pre, post) + case InterfaceSlice: + return rewriteInterfaceSlice(parent, node, replacer, pre, post) + case *Leaf: + return rewriteRefOfLeaf(parent, node, replacer, pre, post) + case LeafSlice: + return rewriteLeafSlice(parent, node, replacer, pre, post) + case *NoCloneType: + return rewriteRefOfNoCloneType(parent, node, replacer, pre, post) + case *RefContainer: + return rewriteRefOfRefContainer(parent, node, replacer, pre, post) + case *RefSliceContainer: + return rewriteRefOfRefSliceContainer(parent, node, replacer, pre, post) + case *SubImpl: + return rewriteRefOfSubImpl(parent, node, replacer, pre, post) + case ValueContainer: + return rewriteValueContainer(parent, node, replacer, pre, post) + case ValueSliceContainer: + return rewriteValueSliceContainer(parent, node, replacer, pre, post) + default: + // this should never happen + return true + } +} + // EqualsBytes does deep equals between the two objects. func EqualsBytes(a, b Bytes) bool { if len(a) != len(b) { @@ -202,6 +238,11 @@ func VisitBytes(in Bytes, f Visit) error { return err } +// rewriteBytes is part of the Rewrite implementation +func rewriteBytes(parent AST, node Bytes, replacer replacerFunc, pre, post ApplyFunc) error { + // ptrToStructMethod +} + // EqualsInterfaceContainer does deep equals between the two objects. func EqualsInterfaceContainer(a, b InterfaceContainer) bool { return true @@ -220,6 +261,10 @@ func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { return nil } +// rewriteInterfaceContainer is part of the Rewrite implementation +func rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc, pre, post ApplyFunc) error { +} + // EqualsInterfaceSlice does deep equals between the two objects. func EqualsInterfaceSlice(a, b InterfaceSlice) bool { if len(a) != len(b) { @@ -258,6 +303,11 @@ func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { return nil } +// rewriteInterfaceSlice is part of the Rewrite implementation +func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc, pre, post ApplyFunc) error { + // ptrToStructMethod +} + // EqualsRefOfLeaf does deep equals between the two objects. func EqualsRefOfLeaf(a, b *Leaf) bool { if a == b { @@ -289,6 +339,22 @@ func VisitRefOfLeaf(in *Leaf, f Visit) error { return nil } +// rewriteRefOfLeaf is part of the Rewrite implementation +func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + return post(&cur) +} + // EqualsLeafSlice does deep equals between the two objects. func EqualsLeafSlice(a, b LeafSlice) bool { if len(a) != len(b) { @@ -327,6 +393,11 @@ func VisitLeafSlice(in LeafSlice, f Visit) error { return nil } +// rewriteLeafSlice is part of the Rewrite implementation +func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, post ApplyFunc) error { + // ptrToStructMethod +} + // EqualsRefOfNoCloneType does deep equals between the two objects. func EqualsRefOfNoCloneType(a, b *NoCloneType) bool { if a == b { @@ -354,6 +425,22 @@ func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { return nil } +// rewriteRefOfNoCloneType is part of the Rewrite implementation +func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + return post(&cur) +} + // EqualsRefOfRefContainer does deep equals between the two objects. func EqualsRefOfRefContainer(a, b *RefContainer) bool { if a == b { @@ -395,6 +482,32 @@ func VisitRefOfRefContainer(in *RefContainer, f Visit) error { return nil } +// rewriteRefOfRefContainer is part of the Rewrite implementation +func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }, pre, post); !cont { + return false + } + if cont := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) + }, pre, post); !cont { + return false + } + return post(&cur) +} + // EqualsRefOfRefSliceContainer does deep equals between the two objects. func EqualsRefOfRefSliceContainer(a, b *RefSliceContainer) bool { if a == b { @@ -441,6 +554,36 @@ func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { return nil } +// rewriteRefOfRefSliceContainer is part of the Rewrite implementation +func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + for i, el := range node.ASTElements { + if cont := rewriteAST(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) + }, pre, post); !cont { + return false + } + } + for i, el := range node.ASTImplementationElements { + if cont := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }, pre, post); !cont { + return false + } + } + return post(&cur) +} + // EqualsRefOfSubImpl does deep equals between the two objects. func EqualsRefOfSubImpl(a, b *SubImpl) bool { if a == b { @@ -478,6 +621,27 @@ func VisitRefOfSubImpl(in *SubImpl, f Visit) error { return nil } +// rewriteRefOfSubImpl is part of the Rewrite implementation +func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + if cont := rewriteSubIface(node, node.inner, func(newNode, parent AST) { + parent.(*SubImpl).inner = newNode.(SubIface) + }, pre, post); !cont { + return false + } + return post(&cur) +} + // EqualsValueContainer does deep equals between the two objects. func EqualsValueContainer(a, b ValueContainer) bool { return a.NotASTType == b.NotASTType && @@ -504,6 +668,20 @@ func VisitValueContainer(in ValueContainer, f Visit) error { return nil } +// rewriteValueContainer is part of the Rewrite implementation +func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { + if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(ValueContainer).ASTType = newNode.(AST) + }, pre, post); !cont { + return false + } + if cont := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(ValueContainer).ASTImplementationType = newNode.(*Leaf) + }, pre, post); !cont { + return false + } +} + // EqualsValueSliceContainer does deep equals between the two objects. func EqualsValueSliceContainer(a, b ValueSliceContainer) bool { return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && @@ -534,6 +712,24 @@ func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { return nil } +// rewriteValueSliceContainer is part of the Rewrite implementation +func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { + for i, el := range node.ASTElements { + if cont := rewriteAST(node, el, func(newNode, parent AST) { + parent.(ValueSliceContainer).ASTElements[i] = newNode.(AST) + }, pre, post); !cont { + return false + } + } + for i, el := range node.ASTImplementationElements { + if cont := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }, pre, post); !cont { + return false + } + } +} + // EqualsSubIface does deep equals between the two objects. func EqualsSubIface(inA, inB SubIface) bool { if inA == nil && inB == nil { @@ -583,12 +779,31 @@ func VisitSubIface(in SubIface, f Visit) error { } } +// rewriteSubIface is part of the Rewrite implementation +func rewriteSubIface(parent AST, node SubIface, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + switch node := node.(type) { + case *SubImpl: + return rewriteRefOfSubImpl(parent, node, replacer, pre, post) + default: + // this should never happen + return true + } +} + // VisitBasicType will visit all parts of the AST func VisitBasicType(in BasicType, f Visit) error { _, err := f(in) return err } +// rewriteBasicType is part of the Rewrite implementation +func rewriteBasicType(parent AST, node BasicType, replacer replacerFunc, pre, post ApplyFunc) error { + // ptrToStructMethod +} + // EqualsRefOfInterfaceContainer does deep equals between the two objects. func EqualsRefOfInterfaceContainer(a, b *InterfaceContainer) bool { if a == b { @@ -621,6 +836,22 @@ func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { return nil } +// rewriteRefOfInterfaceContainer is part of the Rewrite implementation +func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + return post(&cur) +} + // EqualsSliceOfAST does deep equals between the two objects. func EqualsSliceOfAST(a, b []AST) bool { if len(a) != len(b) { @@ -746,6 +977,32 @@ func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { return nil } +// rewriteRefOfValueContainer is part of the Rewrite implementation +func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*ValueContainer).ASTType = newNode.(AST) + }, pre, post); !cont { + return false + } + if cont := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(*ValueContainer).ASTImplementationType = newNode.(*Leaf) + }, pre, post); !cont { + return false + } + return post(&cur) +} + // EqualsRefOfValueSliceContainer does deep equals between the two objects. func EqualsRefOfValueSliceContainer(a, b *ValueSliceContainer) bool { if a == b { @@ -791,3 +1048,33 @@ func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { } return nil } + +// rewriteRefOfValueSliceContainer is part of the Rewrite implementation +func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return true + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return true + } + for i, el := range node.ASTElements { + if cont := rewriteAST(node, el, func(newNode, parent AST) { + parent.(*ValueSliceContainer).ASTElements[i] = newNode.(AST) + }, pre, post); !cont { + return false + } + } + for i, el := range node.ASTImplementationElements { + if cont := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(*ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }, pre, post); !cont { + return false + } + } + return post(&cur) +} diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go new file mode 100644 index 00000000000..aff210ab268 --- /dev/null +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -0,0 +1,312 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package asthelpergen + +import ( + "fmt" + "go/types" + + "github.com/dave/jennifer/jen" +) + +const rewriteName = "rewrite" + +type rewriteGen struct{} + +var _ generator2 = (*rewriteGen)(nil) + +func (e rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + /* + func VisitAST(in AST) (bool, error) { + if in == nil { + return false, nil + } + switch a := inA.(type) { + case *SubImpl: + return VisitSubImpl(a, b) + default: + return false, nil + } + } + */ + stmts := []jen.Code{ + jen.If(jen.Id("node == nil").Block(returnTrue())), + } + + var cases []jen.Code + _ = spi.findImplementations(iface, func(t types.Type) error { + if _, ok := t.Underlying().(*types.Interface); ok { + return nil + } + typeString := types.TypeString(t, noQualifier) + funcName := rewriteName + printableTypeName(t) + spi.addType(t) + caseBlock := jen.Case(jen.Id(typeString)).Block( + jen.Return(jen.Id(funcName).Call(jen.Id("parent, node, replacer, pre, post"))), + ) + cases = append(cases, caseBlock) + return nil + }) + + cases = append(cases, + jen.Default().Block( + jen.Comment("this should never happen"), + returnTrue(), + )) + + stmts = append(stmts, jen.Switch(jen.Id("node := node.(type)").Block( + cases..., + ))) + + rewriteFunc(t, stmts, spi) + return nil +} + +func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := rewriteAllStructFields(t, strct, spi) + rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { + + /* + if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }, pre, post); !cont { + return false + } + if cont := rewriteAST(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(*RefContainer).ASTImplementationType = newNode.(AST) + }, pre, post); !cont { + return false + } + + return post(&cur) + } + + */ + + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + /* + if node == nil { return true } + */ + jen.If(jen.Id("node == nil").Block(returnTrue())), + + /* + cur := Cursor{ + parent: parent, + replacer: replacer, + node: node, + } + */ + jen.Id("cur := Cursor").Values( + jen.Dict{ + jen.Id("parent"): jen.Id("parent"), + jen.Id("replacer"): jen.Id("replacer"), + jen.Id("node"): jen.Id("node"), + }), + + /* + if !pre(&cur) { + return true + } + */ + jen.If(jen.Id("!pre(&cur)")).Block(returnTrue()), + } + + stmts = append(stmts, rewriteAllStructFields(t, strct, spi)...) + + stmts = append(stmts, jen.Return(jen.Id("post").Call(jen.Id("&cur")))) + rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + jen.Comment("ptrToStructMethod"), + } + rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + jen.Comment("ptrToStructMethod"), + } + rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) ptrToOtherMethod(t types.Type, _ *types.Pointer, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + jen.Comment("ptrToStructMethod"), + } + rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + jen.Comment("ptrToStructMethod"), + } + rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) visitNoChildren(t types.Type, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + jen.Comment("ptrToStructMethod"), + } + rewriteFunc(t, stmts, spi) + + return nil +} + +func rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { + + /* + func (a *application) rewriteNodeType(parent AST, node NodeType, replacer replacerFunc) { + */ + + typeString := types.TypeString(t, noQualifier) + funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t)) + code := jen.Func().Id(funcName).Params( + jen.Id(fmt.Sprintf("parent AST, node %s, replacer replacerFunc, pre, post ApplyFunc", typeString)), + ).Bool(). + Block(stmts...) + + spi.addFunc(funcName, rewrite, code) +} + +func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI) []jen.Code { + var output []jen.Code + for i := 0; i < strct.NumFields(); i++ { + field := strct.Field(i) + if types.Implements(field.Type(), spi.iface()) { + spi.addType(field.Type()) + output = append(output, rewriteChild(t, field.Type(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()))) + continue + } + slice, isSlice := field.Type().(*types.Slice) + if isSlice && types.Implements(slice.Elem(), spi.iface()) { + elem := slice.Elem() + spi.addType(elem) + output = append(output, + jen.For(jen.Id("i, el := range node."+field.Name())). + Block(rewriteChild(t, elem, jen.Id("el"), jen.Dot(field.Name()).Index(jen.Id("i"))))) + } + } + return output +} + +func rewriteChild(t, field types.Type, param jen.Code, replace jen.Code) jen.Code { + /* + if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }, pre, post); !cont { + return false + } + + if cont := rewriteAST(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) + }, pre, post); !cont { + return false + } + + */ + funcName := rewriteName + printableTypeName(field) + funcBlock := jen.Func().Call(jen.Id("newNode, parent AST")). + Block(jen.Id("parent"). + Assert(jen.Id(types.TypeString(t, noQualifier))). + Add(replace). + Op("="). + Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier)))) + + rewriteField := jen.If( + jen.Id("cont := ").Id(funcName).Call( + jen.Id("node"), + param, + funcBlock, + jen.Id("pre"), + jen.Id("post")), + jen.Id("!cont").Block(jen.Return(jen.False()))) + + return rewriteField +} + +func returnTrue() jen.Code { + return jen.Return(jen.True()) +} diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index ea25e67b1d6..4cebbf1dd3e 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -17,6 +17,7 @@ limitations under the License. package sqlparser import ( + "fmt" "reflect" "runtime" ) @@ -85,6 +86,8 @@ type abortT int var abort = abortT(0) // singleton, to signal termination of Apply +var abortE = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") + // A Cursor describes a node encountered during Apply. // Information about the node and its parent is available // from the Node and Parent methods. From 217158fd8936990e1da4c23bc2b3e4c01aaf202a Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 19 Mar 2021 17:16:36 +0530 Subject: [PATCH 02/15] rewriter method to return error than bool Signed-off-by: Harshit Gangal --- .../asthelpergen/integration/ast_helper.go | 188 +++++++++++------- go/tools/asthelpergen/integration/types.go | 2 + go/tools/asthelpergen/rewrite_gen.go | 80 ++++---- 3 files changed, 156 insertions(+), 114 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index fd4f1179861..dca5fb9efa4 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -17,6 +17,8 @@ limitations under the License. package integration +import "fmt" + // EqualsAST does deep equals between the two objects. func EqualsAST(inA, inB AST) bool { if inA == nil && inB == nil { @@ -179,7 +181,7 @@ func VisitAST(in AST, f Visit) error { // rewriteAST is part of the Rewrite implementation func rewriteAST(parent AST, node AST, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } switch node := node.(type) { case BasicType: @@ -208,7 +210,7 @@ func rewriteAST(parent AST, node AST, replacer replacerFunc, pre, post ApplyFunc return rewriteValueSliceContainer(parent, node, replacer, pre, post) default: // this should never happen - return true + return nil } } @@ -342,7 +344,7 @@ func VisitRefOfLeaf(in *Leaf, f Visit) error { // rewriteRefOfLeaf is part of the Rewrite implementation func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -350,9 +352,12 @@ func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post A replacer: replacer, } if !pre(&cur) { - return true + return nil } - return post(&cur) + if !post(&cur) { + return abortE + } + return nil } // EqualsLeafSlice does deep equals between the two objects. @@ -428,7 +433,7 @@ func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { // rewriteRefOfNoCloneType is part of the Rewrite implementation func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -436,9 +441,12 @@ func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFun replacer: replacer, } if !pre(&cur) { - return true + return nil + } + if !post(&cur) { + return abortE } - return post(&cur) + return nil } // EqualsRefOfRefContainer does deep equals between the two objects. @@ -485,7 +493,7 @@ func VisitRefOfRefContainer(in *RefContainer, f Visit) error { // rewriteRefOfRefContainer is part of the Rewrite implementation func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -493,19 +501,22 @@ func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerF replacer: replacer, } if !pre(&cur) { - return true + return nil } - if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*RefContainer).ASTType = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } - if cont := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return abortE } - return post(&cur) + return nil } // EqualsRefOfRefSliceContainer does deep equals between the two objects. @@ -557,7 +568,7 @@ func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { // rewriteRefOfRefSliceContainer is part of the Rewrite implementation func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -565,23 +576,26 @@ func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacer: replacer, } if !pre(&cur) { - return true + return nil } for i, el := range node.ASTElements { - if cont := rewriteAST(node, el, func(newNode, parent AST) { + if errF := rewriteAST(node, el, func(newNode, parent AST) { parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } } for i, el := range node.ASTImplementationElements { - if cont := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { parent.(*RefSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } } - return post(&cur) + if !post(&cur) { + return abortE + } + return nil } // EqualsRefOfSubImpl does deep equals between the two objects. @@ -624,7 +638,7 @@ func VisitRefOfSubImpl(in *SubImpl, f Visit) error { // rewriteRefOfSubImpl is part of the Rewrite implementation func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -632,14 +646,17 @@ func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, replacer: replacer, } if !pre(&cur) { - return true + return nil } - if cont := rewriteSubIface(node, node.inner, func(newNode, parent AST) { + if errF := rewriteSubIface(node, node.inner, func(newNode, parent AST) { parent.(*SubImpl).inner = newNode.(SubIface) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } - return post(&cur) + if !post(&cur) { + return abortE + } + return nil } // EqualsValueContainer does deep equals between the two objects. @@ -668,18 +685,40 @@ func VisitValueContainer(in ValueContainer, f Visit) error { return nil } -// rewriteValueContainer is part of the Rewrite implementation +// +// +//ueContainer is part of the Rewrite implementation func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { - if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { - parent.(ValueContainer).ASTType = newNode.(AST) - }, pre, post); !cont { - return false + var err error + + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if cont := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { - parent.(ValueContainer).ASTImplementationType = newNode.(*Leaf) - }, pre, post); !cont { - return false + + if !pre(&cur) { + return nil + } + + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + err = fmt.Errorf("oh noes") + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + err = fmt.Errorf("oh noes") + }, pre, post); errF != nil { + return errF + } + + if err != nil { + return err + } + if !post(&cur) { + return abortE } + return nil } // EqualsValueSliceContainer does deep equals between the two objects. @@ -715,17 +754,17 @@ func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { // rewriteValueSliceContainer is part of the Rewrite implementation func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { for i, el := range node.ASTElements { - if cont := rewriteAST(node, el, func(newNode, parent AST) { + if errF := rewriteAST(node, el, func(newNode, parent AST) { parent.(ValueSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } } for i, el := range node.ASTImplementationElements { - if cont := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { parent.(ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } } } @@ -782,14 +821,14 @@ func VisitSubIface(in SubIface, f Visit) error { // rewriteSubIface is part of the Rewrite implementation func rewriteSubIface(parent AST, node SubIface, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } switch node := node.(type) { case *SubImpl: return rewriteRefOfSubImpl(parent, node, replacer, pre, post) default: // this should never happen - return true + return nil } } @@ -839,7 +878,7 @@ func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { // rewriteRefOfInterfaceContainer is part of the Rewrite implementation func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -847,9 +886,12 @@ func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replac replacer: replacer, } if !pre(&cur) { - return true + return nil } - return post(&cur) + if !post(&cur) { + return abortE + } + return nil } // EqualsSliceOfAST does deep equals between the two objects. @@ -980,7 +1022,7 @@ func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { // rewriteRefOfValueContainer is part of the Rewrite implementation func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -988,19 +1030,22 @@ func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer repla replacer: replacer, } if !pre(&cur) { - return true + return nil } - if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*ValueContainer).ASTType = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } - if cont := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { parent.(*ValueContainer).ASTImplementationType = newNode.(*Leaf) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } - return post(&cur) + if !post(&cur) { + return abortE + } + return nil } // EqualsRefOfValueSliceContainer does deep equals between the two objects. @@ -1052,7 +1097,7 @@ func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { // rewriteRefOfValueSliceContainer is part of the Rewrite implementation func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { if node == nil { - return true + return nil } cur := Cursor{ node: node, @@ -1060,21 +1105,24 @@ func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, repl replacer: replacer, } if !pre(&cur) { - return true + return nil } for i, el := range node.ASTElements { - if cont := rewriteAST(node, el, func(newNode, parent AST) { + if errF := rewriteAST(node, el, func(newNode, parent AST) { parent.(*ValueSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } } for i, el := range node.ASTImplementationElements { - if cont := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { parent.(*ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } } - return post(&cur) + if !post(&cur) { + return abortE + } + return nil } diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index ae9fa38f01a..2233bb7b84e 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -173,3 +173,5 @@ func (r *NoCloneType) String() string { } type Visit func(node AST) (bool, error) + +var abortE = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index aff210ab268..a4d5454d87a 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -47,7 +47,7 @@ func (e rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi ge } */ stmts := []jen.Code{ - jen.If(jen.Id("node == nil").Block(returnTrue())), + jen.If(jen.Id("node == nil").Block(returnNil())), } var cases []jen.Code @@ -68,7 +68,7 @@ func (e rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi ge cases = append(cases, jen.Default().Block( jen.Comment("this should never happen"), - returnTrue(), + returnNil(), )) stmts = append(stmts, jen.Switch(jen.Id("node := node.(type)").Block( @@ -87,43 +87,22 @@ func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generato /* */ - stmts := rewriteAllStructFields(t, strct, spi) + stmts := rewriteAllStructFields(t, strct, spi, true) rewriteFunc(t, stmts, spi) return nil } func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { - - /* - if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { - parent.(*RefContainer).ASTType = newNode.(AST) - }, pre, post); !cont { - return false - } - if cont := rewriteAST(node, node.ASTImplementationType, func(newNode, parent AST) { - parent.(*RefContainer).ASTImplementationType = newNode.(AST) - }, pre, post); !cont { - return false - } - - return post(&cur) - } - - */ - if !shouldAdd(t, spi.iface()) { return nil } - /* - */ - stmts := []jen.Code{ /* - if node == nil { return true } + if node == nil { return nil } */ - jen.If(jen.Id("node == nil").Block(returnTrue())), + jen.If(jen.Id("node == nil").Block(returnNil())), /* cur := Cursor{ @@ -141,15 +120,18 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen /* if !pre(&cur) { - return true + return nil } */ - jen.If(jen.Id("!pre(&cur)")).Block(returnTrue()), + jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), } - stmts = append(stmts, rewriteAllStructFields(t, strct, spi)...) + stmts = append(stmts, rewriteAllStructFields(t, strct, spi, false)...) - stmts = append(stmts, jen.Return(jen.Id("post").Call(jen.Id("&cur")))) + stmts = append(stmts, + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + returnNil(), + ) rewriteFunc(t, stmts, spi) return nil @@ -245,13 +227,23 @@ func rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t)) code := jen.Func().Id(funcName).Params( jen.Id(fmt.Sprintf("parent AST, node %s, replacer replacerFunc, pre, post ApplyFunc", typeString)), - ).Bool(). + ).Error(). Block(stmts...) spi.addFunc(funcName, rewrite, code) } -func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI) []jen.Code { +func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code { + // _, ok := t.Underlying().(*types.Pointer) + + /* + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") + }, pre, post); errF != nil { + return errF + } + + */ var output []jen.Code for i := 0; i < strct.NumFields(); i++ { field := strct.Field(i) @@ -272,18 +264,22 @@ func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI) return output } +func failReplacer(t types.Type, f *types.Var) jen.Code { + +} + func rewriteChild(t, field types.Type, param jen.Code, replace jen.Code) jen.Code { /* - if cont := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*RefContainer).ASTType = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } - if cont := rewriteAST(node, el, func(newNode, parent AST) { + if errF := rewriteAST(node, el, func(newNode, parent AST) { parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); !cont { - return false + }, pre, post); errF != nil { + return errF } */ @@ -296,17 +292,13 @@ func rewriteChild(t, field types.Type, param jen.Code, replace jen.Code) jen.Cod Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier)))) rewriteField := jen.If( - jen.Id("cont := ").Id(funcName).Call( + jen.Id("errF := ").Id(funcName).Call( jen.Id("node"), param, funcBlock, jen.Id("pre"), jen.Id("post")), - jen.Id("!cont").Block(jen.Return(jen.False()))) + jen.Id("errF != nil").Block(jen.Return(jen.Id("errF")))) return rewriteField } - -func returnTrue() jen.Code { - return jen.Return(jen.True()) -} From d00256516441c3f28aada665f1169a2ac8bae777 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 19 Mar 2021 13:37:04 +0100 Subject: [PATCH 03/15] generate rewrite methods for value types Signed-off-by: Andres Taylor --- .../asthelpergen/integration/ast_helper.go | 57 ++++++++++++---- go/tools/asthelpergen/rewrite_gen.go | 67 +++++++++++++------ 2 files changed, 90 insertions(+), 34 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index dca5fb9efa4..3bcf893b274 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -17,7 +17,10 @@ limitations under the License. package integration -import "fmt" +import ( + vtrpc "vitess.io/vitess/go/vt/proto/vtrpc" + vterrors "vitess.io/vitess/go/vt/vterrors" +) // EqualsAST does deep equals between the two objects. func EqualsAST(inA, inB AST) bool { @@ -265,6 +268,22 @@ func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { // rewriteInterfaceContainer is part of the Rewrite implementation func rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc, pre, post ApplyFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if err != nil { + return err + } + if !post(&cur) { + return abortE + } + return nil } // EqualsInterfaceSlice does deep equals between the two objects. @@ -685,33 +704,27 @@ func VisitValueContainer(in ValueContainer, f Visit) error { return nil } -// -// -//ueContainer is part of the Rewrite implementation +// rewriteValueContainer is part of the Rewrite implementation func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { var err error - cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if !pre(&cur) { return nil } - if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { - err = fmt.Errorf("oh noes") + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTType' on 'ValueContainer'") }, pre, post); errF != nil { return errF } if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { - err = fmt.Errorf("oh noes") + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationType' on 'ValueContainer'") }, pre, post); errF != nil { return errF } - if err != nil { return err } @@ -753,20 +766,36 @@ func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { // rewriteValueSliceContainer is part of the Rewrite implementation func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { - for i, el := range node.ASTElements { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for _, el := range node.ASTElements { if errF := rewriteAST(node, el, func(newNode, parent AST) { - parent.(ValueSliceContainer).ASTElements[i] = newNode.(AST) + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTElements' on 'ValueSliceContainer'") }, pre, post); errF != nil { return errF } } - for i, el := range node.ASTImplementationElements { + for _, el := range node.ASTImplementationElements { if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { - parent.(ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationElements' on 'ValueSliceContainer'") }, pre, post); errF != nil { return errF } } + if err != nil { + return err + } + if !post(&cur) { + return abortE + } + return nil } // EqualsSubIface does deep equals between the two objects. diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index a4d5454d87a..7ef3600cc32 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -84,10 +84,17 @@ func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generato return nil } - /* - */ - - stmts := rewriteAllStructFields(t, strct, spi, true) + stmts := []jen.Code{ + jen.Var().Id("err").Error(), + createCursor(), + jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), + } + stmts = append(stmts, rewriteAllStructFields(t, strct, spi, true)...) + stmts = append(stmts, + jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err())), + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + returnNil(), + ) rewriteFunc(t, stmts, spi) return nil @@ -111,12 +118,7 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen node: node, } */ - jen.Id("cur := Cursor").Values( - jen.Dict{ - jen.Id("parent"): jen.Id("parent"), - jen.Id("replacer"): jen.Id("replacer"), - jen.Id("node"): jen.Id("node"), - }), + createCursor(), /* if !pre(&cur) { @@ -137,6 +139,15 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen return nil } +func createCursor() *jen.Statement { + return jen.Id("cur := Cursor").Values( + jen.Dict{ + jen.Id("parent"): jen.Id("parent"), + jen.Id("replacer"): jen.Id("replacer"), + jen.Id("node"): jen.Id("node"), + }) +} + func (e rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { if !shouldAdd(t, spi.iface()) { return nil @@ -249,26 +260,35 @@ func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, field := strct.Field(i) if types.Implements(field.Type(), spi.iface()) { spi.addType(field.Type()) - output = append(output, rewriteChild(t, field.Type(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()))) + output = append(output, rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail)) continue } slice, isSlice := field.Type().(*types.Slice) if isSlice && types.Implements(slice.Elem(), spi.iface()) { - elem := slice.Elem() - spi.addType(elem) + spi.addType(slice.Elem()) + id := jen.Id("i") + if fail { + id = jen.Id("_") + } output = append(output, - jen.For(jen.Id("i, el := range node."+field.Name())). - Block(rewriteChild(t, elem, jen.Id("el"), jen.Dot(field.Name()).Index(jen.Id("i"))))) + jen.For(jen.List(id, jen.Id("el")).Op(":=").Id("range node."+field.Name())). + Block(rewriteChild(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(id), fail))) } } return output } -func failReplacer(t types.Type, f *types.Var) jen.Code { +func failReplacer(t types.Type, f string) *jen.Statement { + //err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") + typeString := types.TypeString(t, noQualifier) + return jen.Err().Op("=").Qual("vitess.io/vitess/go/vt/vterrors", "New").Call( + jen.Qual("vitess.io/vitess/go/vt/proto/vtrpc", "Code_INTERNAL"), + jen.Lit(fmt.Sprintf("[BUG] tried to replace '%s' on '%s'", f, typeString)), + ) } -func rewriteChild(t, field types.Type, param jen.Code, replace jen.Code) jen.Code { +func rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { /* if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*RefContainer).ASTType = newNode.(AST) @@ -284,12 +304,19 @@ func rewriteChild(t, field types.Type, param jen.Code, replace jen.Code) jen.Cod */ funcName := rewriteName + printableTypeName(field) - funcBlock := jen.Func().Call(jen.Id("newNode, parent AST")). - Block(jen.Id("parent"). + var replaceOrFail *jen.Statement + if fail { + replaceOrFail = failReplacer(t, fieldName) + } else { + replaceOrFail = jen.Id("parent"). Assert(jen.Id(types.TypeString(t, noQualifier))). Add(replace). Op("="). - Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier)))) + Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier))) + + } + funcBlock := jen.Func().Call(jen.Id("newNode, parent AST")). + Block(replaceOrFail) rewriteField := jen.If( jen.Id("errF := ").Id(funcName).Call( From 527e4d99fc70fe2b3aa46132ab6366c607b03e1f Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 19 Mar 2021 18:21:05 +0530 Subject: [PATCH 04/15] generate rewriter method for basic types Signed-off-by: Harshit Gangal --- go/tools/asthelpergen/integration/ast_helper.go | 13 ++++++++++++- go/tools/asthelpergen/rewrite_gen.go | 10 +++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index 3bcf893b274..8c2450b2587 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -869,7 +869,18 @@ func VisitBasicType(in BasicType, f Visit) error { // rewriteBasicType is part of the Rewrite implementation func rewriteBasicType(parent AST, node BasicType, replacer replacerFunc, pre, post ApplyFunc) error { - // ptrToStructMethod + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return abortE + } + return nil } // EqualsRefOfInterfaceContainer does deep equals between the two objects. diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 7ef3600cc32..79c8e7f9b9e 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -201,14 +201,14 @@ func (e rewriteGen) basicMethod(t types.Type, basic *types.Basic, spi generatorS return nil } - /* - */ - stmts := []jen.Code{ - jen.Comment("ptrToStructMethod"), + createCursor(), + jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + returnNil(), } - rewriteFunc(t, stmts, spi) + rewriteFunc(t, stmts, spi) return nil } From f7b3178f606599b929f88fa024a71a4658eed11e Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 19 Mar 2021 18:56:54 +0530 Subject: [PATCH 05/15] generate rewriter method for slice types Signed-off-by: Harshit Gangal --- go/tools/asthelpergen/asthelpergen.go | 1 - go/tools/asthelpergen/equals_gen.go | 4 -- .../asthelpergen/integration/ast_helper.go | 62 ++++++++++++++++++- .../asthelpergen/integration/test_helpers.go | 28 ++++++--- go/tools/asthelpergen/rewrite_gen.go | 56 ++++++++++++----- go/tools/asthelpergen/visit_gen.go | 14 ----- 6 files changed, 119 insertions(+), 46 deletions(-) diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index 014675ea244..d75621b8eeb 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -63,7 +63,6 @@ type generator2 interface { structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error - ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error } diff --git a/go/tools/asthelpergen/equals_gen.go b/go/tools/asthelpergen/equals_gen.go index 9def354589f..24b89141b5c 100644 --- a/go/tools/asthelpergen/equals_gen.go +++ b/go/tools/asthelpergen/equals_gen.go @@ -235,7 +235,3 @@ func (e equalsGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSP func (e equalsGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { return nil } - -func (e equalsGen) ptrToOtherMethod(types.Type, *types.Pointer, generatorSPI) error { - return nil -} diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index 8c2450b2587..5fe136a2f2f 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -245,7 +245,21 @@ func VisitBytes(in Bytes, f Visit) error { // rewriteBytes is part of the Rewrite implementation func rewriteBytes(parent AST, node Bytes, replacer replacerFunc, pre, post ApplyFunc) error { - // ptrToStructMethod + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return abortE + } + return nil } // EqualsInterfaceContainer does deep equals between the two objects. @@ -326,7 +340,28 @@ func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { // rewriteInterfaceSlice is part of the Rewrite implementation func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc, pre, post ApplyFunc) error { - // ptrToStructMethod + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteAST(node, el, func(newNode, parent AST) { + parent.(InterfaceSlice)[i] = newNode.(AST) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return abortE + } + return nil } // EqualsRefOfLeaf does deep equals between the two objects. @@ -419,7 +454,28 @@ func VisitLeafSlice(in LeafSlice, f Visit) error { // rewriteLeafSlice is part of the Rewrite implementation func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, post ApplyFunc) error { - // ptrToStructMethod + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(LeafSlice)[i] = newNode.(*Leaf) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return abortE + } + return nil } // EqualsRefOfNoCloneType does deep equals between the two objects. diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index 3a2da19be80..13757f0e888 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -19,6 +19,8 @@ package integration import ( "reflect" "strings" + + "vitess.io/vitess/go/vt/log" ) // ast type helpers @@ -78,16 +80,28 @@ func isNilValue(i interface{}) bool { var abort = new(int) // singleton, to signal termination of Apply func Rewrite(node AST, pre, post ApplyFunc) (result AST) { - parent := &struct{ AST }{node} + outer := &struct{ AST }{node} + + if pre == nil { + pre = func(cursor *Cursor) bool { + return true + } + } - a := &application{ - pre: pre, - post: post, - cursor: Cursor{}, + if post == nil { + post = func(cursor *Cursor) bool { + return true + } } - a.apply(parent.AST, node, nil) - return parent.AST + err := rewriteAST(outer, node, func(newNode, parent AST) { + outer.AST = newNode + }, pre, post) + + if err != nil { + log.Fatal(err) + } + return outer.AST } func replacePanic(msg string) func(newNode, parent AST) { diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 79c8e7f9b9e..20526a067ee 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -157,7 +157,7 @@ func (e rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generator */ stmts := []jen.Code{ - jen.Comment("ptrToStructMethod"), + jen.Comment("ptrToBasicMethod"), } rewriteFunc(t, stmts, spi) @@ -170,29 +170,51 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS } /* - */ - + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + */ stmts := []jen.Code{ - jen.Comment("ptrToStructMethod"), + jen.If(jen.Id("node == nil").Block(returnNil())), + createCursor(), + jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), } - rewriteFunc(t, stmts, spi) - - return nil -} -func (e rewriteGen) ptrToOtherMethod(t types.Type, _ *types.Pointer, spi generatorSPI) error { - if !shouldAdd(t, spi.iface()) { - return nil + if shouldAdd(slice.Elem(), spi.iface()) { + /* + for i, el := range node { + if err := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(LeafSlice)[i] = newNode.(*Leaf) + }, pre, post); err != nil { + return err + } + } + */ + stmts = append(stmts, + jen.For(jen.Id("i, el").Op(":=").Id("range node")). + Block(rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) } - /* - */ + stmts = append(stmts, + /* + if !post(&cur) { + return abortE + } + return nil - stmts := []jen.Code{ - jen.Comment("ptrToStructMethod"), - } + */ + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + returnNil(), + ) rewriteFunc(t, stmts, spi) - return nil } diff --git a/go/tools/asthelpergen/visit_gen.go b/go/tools/asthelpergen/visit_gen.go index 325d71aceb2..04dddcae0ec 100644 --- a/go/tools/asthelpergen/visit_gen.go +++ b/go/tools/asthelpergen/visit_gen.go @@ -181,20 +181,6 @@ func (e visitGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI return nil } -func (e visitGen) ptrToOtherMethod(t types.Type, _ *types.Pointer, spi generatorSPI) error { - if !shouldAdd(t, spi.iface()) { - return nil - } - - stmts := []jen.Code{ - jen.Comment("ptrToOtherMethod "), - } - - visitFunc(t, stmts, spi) - - return nil -} - func (e visitGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { if !shouldAdd(t, spi.iface()) { return nil From 8cd1d645b35b64e91bb878481b2451cbe651c13d Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 19 Mar 2021 15:25:16 +0100 Subject: [PATCH 06/15] replace the sqlparser rewriter with the new one Signed-off-by: Andres Taylor --- go/tools/asthelpergen/asthelpergen.go | 93 +- .../asthelpergen/integration/ast_helper.go | 30 +- .../integration/integration_rewriter_test.go | 110 +- go/tools/asthelpergen/integration/rewriter.go | 102 - .../asthelpergen/integration/test_helpers.go | 29 +- go/tools/asthelpergen/integration/types.go | 2 +- go/tools/asthelpergen/rewrite_gen.go | 83 +- go/tools/asthelpergen/rewriter_gen.go | 209 - go/vt/sqlparser/ast_helper.go | 4815 ++++++++++++++++- go/vt/sqlparser/rewriter.go | 931 ---- go/vt/sqlparser/rewriter_api.go | 74 +- go/vt/sqlparser/rewriter_test.go | 13 - 12 files changed, 4911 insertions(+), 1580 deletions(-) delete mode 100644 go/tools/asthelpergen/integration/rewriter.go delete mode 100644 go/tools/asthelpergen/rewriter_gen.go delete mode 100644 go/vt/sqlparser/rewriter.go diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index d75621b8eeb..0fbcd19cb8d 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -43,13 +43,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.` -type generator interface { - visitStruct(t types.Type, stroct *types.Struct) error - visitInterface(t types.Type, iface *types.Interface) error - visitSlice(t types.Type, slice *types.Slice) error - createFile(pkgName string) (string, *jen.File) -} - type generatorSPI interface { addType(t types.Type) addFunc(name string, t methodType, code jen.Code) @@ -75,8 +68,7 @@ type astHelperGen struct { sizes types.Sizes namedIface *types.Named _iface *types.Interface - gens []generator - gens2 []generator2 + gens []generator2 methods []jen.Code _scope *types.Scope @@ -87,7 +79,7 @@ func (gen *astHelperGen) iface() *types.Interface { return gen._iface } -func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen { +func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator2) *astHelperGen { return &astHelperGen{ DebugTypes: true, mod: mod, @@ -149,73 +141,11 @@ func (gen *astHelperGen) findImplementations(iff *types.Interface, impl func(typ return nil } -func (gen *astHelperGen) visitStruct(t types.Type, stroct *types.Struct) error { - for _, g := range gen.gens { - err := g.visitStruct(t, stroct) - if err != nil { - return err - } - } - return nil -} - -func (gen *astHelperGen) visitSlice(t types.Type, slice *types.Slice) error { - for _, g := range gen.gens { - err := g.visitSlice(t, slice) - if err != nil { - return err - } - } - return nil -} - -func (gen *astHelperGen) visitInterface(t types.Type, iface *types.Interface) error { - for _, g := range gen.gens { - err := g.visitInterface(t, iface) - if err != nil { - return err - } - } - return nil -} - // GenerateCode is the main loop where we build up the code per file. func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) { pkg := gen.namedIface.Obj().Pkg() - iface, ok := gen._iface.Underlying().(*types.Interface) - if !ok { - return nil, fmt.Errorf("expected interface, but got %T", gen.iface) - } - - err := findImplementations(pkg.Scope(), iface, func(t types.Type) error { - switch n := t.Underlying().(type) { - case *types.Struct: - return gen.visitStruct(t, n) - case *types.Slice: - return gen.visitSlice(t, n) - case *types.Pointer: - strct, isStrct := n.Elem().Underlying().(*types.Struct) - if isStrct { - return gen.visitStruct(t, strct) - } - case *types.Interface: - return gen.visitInterface(t, n) - default: - // do nothing - } - return nil - }) - - if err != nil { - return nil, err - } result := map[string]*jen.File{} - for _, g := range gen.gens { - file, code := g.createFile(pkg.Name()) - fullPath := path.Join(gen.mod.Dir, strings.TrimPrefix(pkg.Path(), gen.mod.Path), file) - result[fullPath] = code - } gen._scope = pkg.Scope() gen.todo = append(gen.todo, gen.namedIface) @@ -299,17 +229,12 @@ func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType str nt := tt.Type().(*types.Named) - iface := nt.Underlying().(*types.Interface) - - interestingType := func(t types.Type) bool { - return types.Implements(t, iface) - } - rewriter := newRewriterGen(interestingType, nt.Obj().Name()) - generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, rewriter) - generator.gens2 = append(generator.gens2, &equalsGen{}) - generator.gens2 = append(generator.gens2, newCloneGen(exceptCloneType)) - generator.gens2 = append(generator.gens2, &visitGen{}) - generator.gens2 = append(generator.gens2, &rewriteGen{}) + generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, + &equalsGen{}, + newCloneGen(exceptCloneType), + &visitGen{}, + &rewriteGen{types.TypeString(nt, noQualifier)}, + ) it, err := generator.GenerateCode() if err != nil { @@ -417,7 +342,7 @@ func (gen *astHelperGen) createFile(pkgName string) (string, *jen.File) { } func (gen *astHelperGen) allGenerators(f func(g generator2) error) { - for _, g := range gen.gens2 { + for _, g := range gen.gens { err := f(g) if err != nil { diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index 5fe136a2f2f..ee7ad844879 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -257,7 +257,7 @@ func rewriteBytes(parent AST, node Bytes, replacer replacerFunc, pre, post Apply return nil } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -295,7 +295,7 @@ func rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer rep return err } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -359,7 +359,7 @@ func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFun } } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -409,7 +409,7 @@ func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post A return nil } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -473,7 +473,7 @@ func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, po } } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -519,7 +519,7 @@ func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFun return nil } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -589,7 +589,7 @@ func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerF return errF } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -668,7 +668,7 @@ func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer } } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -729,7 +729,7 @@ func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, return errF } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -785,7 +785,7 @@ func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFun return err } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -849,7 +849,7 @@ func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer r return err } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -934,7 +934,7 @@ func rewriteBasicType(parent AST, node BasicType, replacer replacerFunc, pre, po return nil } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -985,7 +985,7 @@ func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replac return nil } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -1139,7 +1139,7 @@ func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer repla return errF } if !post(&cur) { - return abortE + return errAbort } return nil } @@ -1218,7 +1218,7 @@ func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, repl } } if !post(&cur) { - return abortE + return errAbort } return nil } diff --git a/go/tools/asthelpergen/integration/integration_rewriter_test.go b/go/tools/asthelpergen/integration/integration_rewriter_test.go index 9648abfef44..df204dc37f2 100644 --- a/go/tools/asthelpergen/integration/integration_rewriter_test.go +++ b/go/tools/asthelpergen/integration/integration_rewriter_test.go @@ -21,6 +21,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) @@ -32,7 +34,8 @@ func TestRewriteVisitRefContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) expected := []step{ Pre{containerContainer}, @@ -55,7 +58,8 @@ func TestRewriteVisitValueContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) expected := []step{ Pre{containerContainer}, @@ -80,7 +84,8 @@ func TestRewriteVisitRefSliceContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) tv.assertEquals(t, []step{ Pre{containerContainer}, @@ -108,7 +113,8 @@ func TestRewriteVisitValueSliceContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) tv.assertEquals(t, []step{ Pre{containerContainer}, @@ -144,7 +150,8 @@ func TestRewriteVisitInterfaceSlice(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(ast, tv.pre, tv.post) + _, err := Rewrite(ast, tv.pre, tv.post) + require.NoError(t, err) tv.assertEquals(t, []step{ Pre{ast}, @@ -169,20 +176,22 @@ func TestRewriteVisitRefContainerReplace(t *testing.T) { } // rewrite field of type AST - Rewrite(ast, func(cursor *Cursor) bool { + _, err := Rewrite(ast, func(cursor *Cursor) bool { leaf, ok := cursor.node.(*RefContainer) if ok && leaf.NotASTType == 12 { cursor.Replace(&Leaf{99}) } return true }, nil) + require.NoError(t, err) assert.Equal(t, &RefContainer{ ASTType: &Leaf{99}, ASTImplementationType: &Leaf{2}, }, ast) - Rewrite(ast, rewriteLeaf(2, 55), nil) + _, err = Rewrite(ast, rewriteLeaf(2, 55), nil) + require.NoError(t, err) assert.Equal(t, &RefContainer{ ASTType: &Leaf{99}, @@ -196,13 +205,7 @@ func TestRewriteVisitValueContainerReplace(t *testing.T) { ASTImplementationType: &Leaf{2}, } - defer func() { - if r := recover(); r != nil { - assert.Contains(t, r, "ValueContainer ASTType") - } - }() - - Rewrite(ast, func(cursor *Cursor) bool { + _, err := Rewrite(ast, func(cursor *Cursor) bool { leaf, ok := cursor.node.(ValueContainer) if ok && leaf.NotASTType == 12 { cursor.Replace(&Leaf{99}) @@ -210,7 +213,7 @@ func TestRewriteVisitValueContainerReplace(t *testing.T) { return true }, nil) - t.Fatalf("should not get here") + require.Error(t, err) } func TestRewriteVisitValueContainerReplace2(t *testing.T) { @@ -219,15 +222,8 @@ func TestRewriteVisitValueContainerReplace2(t *testing.T) { ASTImplementationType: &Leaf{2}, } - defer func() { - if r := recover(); r != nil { - assert.Contains(t, r, "ValueContainer ASTImplementationType") - } - }() - - Rewrite(ast, rewriteLeaf(2, 10), nil) - - t.Fatalf("should not get here") + _, err := Rewrite(ast, rewriteLeaf(2, 10), nil) + require.Error(t, err) } func rewriteLeaf(from, to int) func(*Cursor) bool { @@ -246,14 +242,16 @@ func TestRefSliceContainerReplace(t *testing.T) { ASTImplementationElements: []*Leaf{{3}, {4}}, } - Rewrite(ast, rewriteLeaf(2, 42), nil) + _, err := Rewrite(ast, rewriteLeaf(2, 42), nil) + require.NoError(t, err) assert.Equal(t, &RefSliceContainer{ ASTElements: []AST{&Leaf{1}, &Leaf{42}}, ASTImplementationElements: []*Leaf{{3}, {4}}, }, ast) - Rewrite(ast, rewriteLeaf(3, 88), nil) + _, err = Rewrite(ast, rewriteLeaf(3, 88), nil) + require.NoError(t, err) assert.Equal(t, &RefSliceContainer{ ASTElements: []AST{&Leaf{1}, &Leaf{42}}, @@ -327,63 +325,3 @@ func (tv *rewriteTestVisitor) assertEquals(t *testing.T, expected []step) { } } - -// below follows two different ways of creating the replacement method for slices, and benchmark -// between them. Diff seems to be very small, so I'll use the most readable form -type replaceA int - -func (r *replaceA) replace(newNode, container AST) { - container.(InterfaceSlice)[int(*r)] = newNode.(AST) -} - -func (r *replaceA) inc() { - *r++ -} - -func replaceB(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(InterfaceSlice)[idx] = newNode.(AST) - } -} - -func BenchmarkSliceReplacerA(b *testing.B) { - islice := make(InterfaceSlice, 20) - for i := range islice { - islice[i] = &Leaf{i} - } - a := &application{ - pre: func(c *Cursor) bool { - return true - }, - post: nil, - cursor: Cursor{}, - } - - for i := 0; i < b.N; i++ { - replacer := replaceA(0) - for _, el := range islice { - a.apply(islice, el, replacer.replace) - replacer.inc() - } - } -} - -func BenchmarkSliceReplacerB(b *testing.B) { - islice := make(InterfaceSlice, 20) - for i := range islice { - islice[i] = &Leaf{i} - } - a := &application{ - pre: func(c *Cursor) bool { - return true - }, - post: nil, - cursor: Cursor{}, - } - - for i := 0; i < b.N; i++ { - for x, el := range islice { - a.apply(islice, el, replaceB(x)) - } - } -} diff --git a/go/tools/asthelpergen/integration/rewriter.go b/go/tools/asthelpergen/integration/rewriter.go deleted file mode 100644 index 300ccef16ea..00000000000 --- a/go/tools/asthelpergen/integration/rewriter.go +++ /dev/null @@ -1,102 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -// Code generated by ASTHelperGen. DO NOT EDIT. - -package integration - -func (a *application) apply(parent, node AST, replacer replacerFunc) { - if node == nil || isNilValue(node) { - return - } - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - if a.pre != nil && !a.pre(&a.cursor) { - a.cursor = saved - return - } - switch n := node.(type) { - case Bytes: - case InterfaceContainer: - case InterfaceSlice: - for x, el := range n { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(InterfaceSlice)[idx] = newNode.(AST) - } - }(x)) - } - case *Leaf: - case LeafSlice: - for x, el := range n { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(LeafSlice)[idx] = newNode.(*Leaf) - } - }(x)) - } - case *NoCloneType: - case *RefContainer: - a.apply(node, n.ASTType, func(newNode, parent AST) { - parent.(*RefContainer).ASTType = newNode.(AST) - }) - a.apply(node, n.ASTImplementationType, func(newNode, parent AST) { - parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) - }) - case *RefSliceContainer: - for x, el := range n.ASTElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(*RefSliceContainer).ASTElements[idx] = newNode.(AST) - } - }(x)) - } - for x, el := range n.ASTImplementationElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(*RefSliceContainer).ASTImplementationElements[idx] = newNode.(*Leaf) - } - }(x)) - } - case *SubImpl: - a.apply(node, n.inner, func(newNode, parent AST) { - parent.(*SubImpl).inner = newNode.(SubIface) - }) - case ValueContainer: - a.apply(node, n.ASTType, replacePanic("ValueContainer ASTType")) - a.apply(node, n.ASTImplementationType, replacePanic("ValueContainer ASTImplementationType")) - case ValueSliceContainer: - for x, el := range n.ASTElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(ValueSliceContainer).ASTElements[idx] = newNode.(AST) - } - }(x)) - } - for x, el := range n.ASTImplementationElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(ValueSliceContainer).ASTImplementationElements[idx] = newNode.(*Leaf) - } - }(x)) - } - } - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - a.cursor = saved -} diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index 13757f0e888..063a1f7a81d 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -17,10 +17,7 @@ limitations under the License. package integration import ( - "reflect" "strings" - - "vitess.io/vitess/go/vt/log" ) // ast type helpers @@ -42,11 +39,6 @@ func sliceStringLeaf(els ...*Leaf) string { // the methods below are what the generated code expected to be there in the package -type application struct { - pre, post ApplyFunc - cursor Cursor -} - type ApplyFunc func(*Cursor) bool type Cursor struct { @@ -70,16 +62,7 @@ func (c *Cursor) Replace(newNode AST) { type replacerFunc func(newNode, parent AST) -func isNilValue(i interface{}) bool { - valueOf := reflect.ValueOf(i) - kind := valueOf.Kind() - isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice - return isNullable && valueOf.IsNil() -} - -var abort = new(int) // singleton, to signal termination of Apply - -func Rewrite(node AST, pre, post ApplyFunc) (result AST) { +func Rewrite(node AST, pre, post ApplyFunc) (AST, error) { outer := &struct{ AST }{node} if pre == nil { @@ -99,13 +82,7 @@ func Rewrite(node AST, pre, post ApplyFunc) (result AST) { }, pre, post) if err != nil { - log.Fatal(err) - } - return outer.AST -} - -func replacePanic(msg string) func(newNode, parent AST) { - return func(newNode, parent AST) { - panic("Tried replacing a field of a value type. This is not supported. " + msg) + return nil, err } + return outer.AST, nil } diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index 2233bb7b84e..1e25c50ed75 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -174,4 +174,4 @@ func (r *NoCloneType) String() string { type Visit func(node AST) (bool, error) -var abortE = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") +var errAbort = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 20526a067ee..458dfd3d7a1 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -23,9 +23,14 @@ import ( "github.com/dave/jennifer/jen" ) -const rewriteName = "rewrite" +const ( + rewriteName = "rewrite" + abort = "errAbort" +) -type rewriteGen struct{} +type rewriteGen struct { + ifaceName string +} var _ generator2 = (*rewriteGen)(nil) @@ -75,7 +80,7 @@ func (e rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi ge cases..., ))) - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } @@ -89,13 +94,13 @@ func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generato createCursor(), jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), } - stmts = append(stmts, rewriteAllStructFields(t, strct, spi, true)...) + stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, true)...) stmts = append(stmts, jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err())), - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), returnNil(), ) - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } @@ -128,13 +133,13 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), } - stmts = append(stmts, rewriteAllStructFields(t, strct, spi, false)...) + stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, false)...) stmts = append(stmts, - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), returnNil(), ) - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } @@ -159,7 +164,7 @@ func (e rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generator stmts := []jen.Code{ jen.Comment("ptrToBasicMethod"), } - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } @@ -200,25 +205,25 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS */ stmts = append(stmts, jen.For(jen.Id("i, el").Op(":=").Id("range node")). - Block(rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) + Block(e.rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) } stmts = append(stmts, /* if !post(&cur) { - return abortE + return errAbort } return nil */ - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), returnNil(), ) - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } -func (e rewriteGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { +func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { if !shouldAdd(t, spi.iface()) { return nil } @@ -226,11 +231,11 @@ func (e rewriteGen) basicMethod(t types.Type, basic *types.Basic, spi generatorS stmts := []jen.Code{ createCursor(), jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id("abortE"))), + jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), returnNil(), } - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } @@ -245,12 +250,12 @@ func (e rewriteGen) visitNoChildren(t types.Type, spi generatorSPI) error { stmts := []jen.Code{ jen.Comment("ptrToStructMethod"), } - rewriteFunc(t, stmts, spi) + e.rewriteFunc(t, stmts, spi) return nil } -func rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { +func (e rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { /* func (a *application) rewriteNodeType(parent AST, node NodeType, replacer replacerFunc) { @@ -259,16 +264,14 @@ func rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { typeString := types.TypeString(t, noQualifier) funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t)) code := jen.Func().Id(funcName).Params( - jen.Id(fmt.Sprintf("parent AST, node %s, replacer replacerFunc, pre, post ApplyFunc", typeString)), + jen.Id(fmt.Sprintf("parent %s, node %s, replacer replacerFunc, pre, post ApplyFunc", e.ifaceName, typeString)), ).Error(). Block(stmts...) spi.addFunc(funcName, rewrite, code) } -func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code { - // _, ok := t.Underlying().(*types.Pointer) - +func (e rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code { /* if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") @@ -282,7 +285,7 @@ func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, field := strct.Field(i) if types.Implements(field.Type(), spi.iface()) { spi.addType(field.Type()) - output = append(output, rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail)) + output = append(output, e.rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail)) continue } slice, isSlice := field.Type().(*types.Slice) @@ -294,15 +297,13 @@ func rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, } output = append(output, jen.For(jen.List(id, jen.Id("el")).Op(":=").Id("range node."+field.Name())). - Block(rewriteChild(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(id), fail))) + Block(e.rewriteChild(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(id), fail))) } } return output } func failReplacer(t types.Type, f string) *jen.Statement { - //err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") - typeString := types.TypeString(t, noQualifier) return jen.Err().Op("=").Qual("vitess.io/vitess/go/vt/vterrors", "New").Call( jen.Qual("vitess.io/vitess/go/vt/proto/vtrpc", "Code_INTERNAL"), @@ -310,19 +311,19 @@ func failReplacer(t types.Type, f string) *jen.Statement { ) } -func rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { +func (e rewriteGen) rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { /* - if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { - parent.(*RefContainer).ASTType = newNode.(AST) - }, pre, post); errF != nil { - return errF - } + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }, pre, post); errF != nil { + return errF + } - if errF := rewriteAST(node, el, func(newNode, parent AST) { - parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); errF != nil { - return errF - } + if errF := rewriteAST(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) + }, pre, post); errF != nil { + return errF + } */ funcName := rewriteName + printableTypeName(field) @@ -337,7 +338,7 @@ func rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier))) } - funcBlock := jen.Func().Call(jen.Id("newNode, parent AST")). + funcBlock := jen.Func().Call(jen.Id("newNode, parent").Id(e.ifaceName)). Block(replaceOrFail) rewriteField := jen.If( @@ -351,3 +352,7 @@ func rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace return rewriteField } + +var noQualifier = func(p *types.Package) string { + return "" +} diff --git a/go/tools/asthelpergen/rewriter_gen.go b/go/tools/asthelpergen/rewriter_gen.go deleted file mode 100644 index b7bedbe5cc8..00000000000 --- a/go/tools/asthelpergen/rewriter_gen.go +++ /dev/null @@ -1,209 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package asthelpergen - -import ( - "go/types" - - "github.com/dave/jennifer/jen" -) - -type rewriterGen struct { - cases []jen.Code - interestingType func(types.Type) bool - ifaceName string -} - -func newRewriterGen(f func(types.Type) bool, name string) *rewriterGen { - return &rewriterGen{interestingType: f, ifaceName: name} -} - -var noQualifier = func(p *types.Package) string { - return "" -} - -func (r *rewriterGen) visitStruct(t types.Type, stroct *types.Struct) error { - typeString := types.TypeString(t, noQualifier) - typeName := printableTypeName(t) - var caseStmts []jen.Code - for i := 0; i < stroct.NumFields(); i++ { - field := stroct.Field(i) - if r.interestingType(field.Type()) { - if _, ok := t.(*types.Pointer); ok { - function := r.createReplaceMethod(typeString, field) - caseStmts = append(caseStmts, caseStmtFor(field, function)) - } else { - caseStmts = append(caseStmts, casePanicStmtFor(field, typeName+" "+field.Name())) - } - } - sliceT, ok := field.Type().(*types.Slice) - if ok && r.interestingType(sliceT.Elem()) { // we have a field containing a slice of interesting elements - function := r.createReplacementMethod(t, sliceT.Elem(), jen.Dot(field.Name())) - caseStmts = append(caseStmts, caseStmtForSliceField(field, function)) - } - } - r.cases = append(r.cases, jen.Case(jen.Id(typeString)).Block(caseStmts...)) - return nil -} - -func (r *rewriterGen) visitInterface(types.Type, *types.Interface) error { - return nil // rewriter doesn't deal with interfaces -} - -func (r *rewriterGen) visitSlice(t types.Type, slice *types.Slice) error { - typeString := types.TypeString(t, noQualifier) - - var stmts []jen.Code - if r.interestingType(slice.Elem()) { - function := r.createReplacementMethod(t, slice.Elem(), jen.Empty()) - stmts = append(stmts, caseStmtForSlice(function)) - } - r.cases = append(r.cases, jen.Case(jen.Id(typeString)).Block(stmts...)) - return nil -} - -func caseStmtFor(field *types.Var, expr jen.Code) *jen.Statement { - // a.apply(node, node.Field, replacerMethod) - return jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), expr) -} - -func casePanicStmtFor(field *types.Var, name string) *jen.Statement { - return jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), jen.Id("replacePanic").Call(jen.Lit(name))) -} - -func caseStmtForSlice(function *jen.Statement) jen.Code { - return jen.For(jen.List(jen.Op("x"), jen.Id("el"))).Op(":=").Range().Id("n").Block( - jen.Id("a").Dot("apply").Call( - jen.Id("node"), - jen.Id("el"), - function, - ), - ) -} - -func caseStmtForSliceField(field *types.Var, function *jen.Statement) jen.Code { - //for x, el := range n { - return jen.For(jen.List(jen.Op("x"), jen.Id("el"))).Op(":=").Range().Id("n").Dot(field.Name()).Block( - jen.Id("a").Dot("apply").Call( - // a.apply(node, el, replaceInterfaceSlice(x)) - jen.Id("node"), - jen.Id("el"), - function, - ), - ) -} - -func (r *rewriterGen) structCase(name string, stroct *types.Struct) (jen.Code, error) { - var stmts []jen.Code - for i := 0; i < stroct.NumFields(); i++ { - field := stroct.Field(i) - if r.interestingType(field.Type()) { - stmts = append(stmts, jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), jen.Nil())) - } - } - return jen.Case(jen.Op("*").Id(name)).Block(stmts...), nil -} - -func (r *rewriterGen) createReplaceMethod(structType string, field *types.Var) jen.Code { - return jen.Func().Params( - jen.Id("newNode"), - jen.Id("parent").Id(r.ifaceName), - ).Block( - jen.Id("parent").Assert(jen.Id(structType)).Dot(field.Name()).Op("=").Id("newNode").Assert(jen.Id(types.TypeString(field.Type(), noQualifier))), - ) -} - -func (r *rewriterGen) createReplacementMethod(container, elem types.Type, x jen.Code) *jen.Statement { - /* - func replacer(idx int) func(AST, AST) { - return func(newnode, container AST) { - container.(InterfaceSlice)[idx] = newnode.(AST) - } - }(x) - */ - return jen.Func().Params(jen.Id("idx").Int()).Func().Params(jen.List(jen.Id(r.ifaceName), jen.Id(r.ifaceName))).Block( - jen.Return(jen.Func().Params(jen.List(jen.Id("newNode"), jen.Id("container")).Id(r.ifaceName))).Block( - jen.Id("container").Assert(jen.Id(types.TypeString(container, noQualifier))).Add(x).Index(jen.Id("idx")).Op("="). - Id("newNode").Assert(jen.Id(types.TypeString(elem, noQualifier))), - ), - ).Call(jen.Id("x")) -} - -func (r *rewriterGen) createFile(pkgName string) (string, *jen.File) { - out := jen.NewFile(pkgName) - out.HeaderComment(licenseFileHeader) - out.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") - - out.Add( - // func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { - jen.Func().Params( - jen.Id("a").Op("*").Id("application"), - ).Id("apply").Params( - jen.Id("parent"), - jen.Id("node").Id(r.ifaceName), - jen.Id("replacer").Id("replacerFunc"), - ).Block( - /* - if node == nil || isNilValue(node) { - return - } - */ - jen.If( - jen.Id("node").Op("==").Nil().Op("||"). - Id("isNilValue").Call(jen.Id("node"))).Block( - jen.Return(), - ), - /* - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - */ - jen.Id("saved").Op(":=").Id("a").Dot("cursor"), - jen.Id("a").Dot("cursor").Dot("replacer").Op("=").Id("replacer"), - jen.Id("a").Dot("cursor").Dot("node").Op("=").Id("node"), - jen.Id("a").Dot("cursor").Dot("parent").Op("=").Id("parent"), - jen.If( - jen.Id("a").Dot("pre").Op("!=").Nil().Op("&&"). - Op("!").Id("a").Dot("pre").Call(jen.Op("&").Id("a").Dot("cursor"))).Block( - jen.Id("a").Dot("cursor").Op("=").Id("saved"), - jen.Return(), - ), - - // switch n := node.(type) { - jen.Switch(jen.Id("n").Op(":=").Id("node").Assert(jen.Id("type")).Block( - r.cases..., - )), - - /* - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - */ - jen.If( - jen.Id("a").Dot("post").Op("!=").Nil().Op("&&"). - Op("!").Id("a").Dot("post").Call(jen.Op("&").Id("a").Dot("cursor"))).Block( - jen.Id("panic").Call(jen.Id("abort")), - ), - - // a.cursor = saved - jen.Id("a").Dot("cursor").Op("=").Id("saved"), - ), - ) - - return "rewriter.go", out -} diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 954ecacc58c..417a132779e 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -17,6 +17,11 @@ limitations under the License. package sqlparser +import ( + vtrpc "vitess.io/vitess/go/vt/proto/vtrpc" + vterrors "vitess.io/vitess/go/vt/vterrors" +) + // EqualsSQLNode does deep equals between the two objects. func EqualsSQLNode(inA, inB SQLNode) bool { if inA == nil && inB == nil { @@ -1516,6 +1521,310 @@ func VisitSQLNode(in SQLNode, f Visit) error { } } +// rewriteSQLNode is part of the Rewrite implementation +func rewriteSQLNode(parent SQLNode, node SQLNode, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case AccessMode: + return rewriteAccessMode(parent, node, replacer, pre, post) + case *AddColumns: + return rewriteRefOfAddColumns(parent, node, replacer, pre, post) + case *AddConstraintDefinition: + return rewriteRefOfAddConstraintDefinition(parent, node, replacer, pre, post) + case *AddIndexDefinition: + return rewriteRefOfAddIndexDefinition(parent, node, replacer, pre, post) + case AlgorithmValue: + return rewriteAlgorithmValue(parent, node, replacer, pre, post) + case *AliasedExpr: + return rewriteRefOfAliasedExpr(parent, node, replacer, pre, post) + case *AliasedTableExpr: + return rewriteRefOfAliasedTableExpr(parent, node, replacer, pre, post) + case *AlterCharset: + return rewriteRefOfAlterCharset(parent, node, replacer, pre, post) + case *AlterColumn: + return rewriteRefOfAlterColumn(parent, node, replacer, pre, post) + case *AlterDatabase: + return rewriteRefOfAlterDatabase(parent, node, replacer, pre, post) + case *AlterMigration: + return rewriteRefOfAlterMigration(parent, node, replacer, pre, post) + case *AlterTable: + return rewriteRefOfAlterTable(parent, node, replacer, pre, post) + case *AlterView: + return rewriteRefOfAlterView(parent, node, replacer, pre, post) + case *AlterVschema: + return rewriteRefOfAlterVschema(parent, node, replacer, pre, post) + case *AndExpr: + return rewriteRefOfAndExpr(parent, node, replacer, pre, post) + case Argument: + return rewriteArgument(parent, node, replacer, pre, post) + case *AutoIncSpec: + return rewriteRefOfAutoIncSpec(parent, node, replacer, pre, post) + case *Begin: + return rewriteRefOfBegin(parent, node, replacer, pre, post) + case *BinaryExpr: + return rewriteRefOfBinaryExpr(parent, node, replacer, pre, post) + case BoolVal: + return rewriteBoolVal(parent, node, replacer, pre, post) + case *CallProc: + return rewriteRefOfCallProc(parent, node, replacer, pre, post) + case *CaseExpr: + return rewriteRefOfCaseExpr(parent, node, replacer, pre, post) + case *ChangeColumn: + return rewriteRefOfChangeColumn(parent, node, replacer, pre, post) + case *CheckConstraintDefinition: + return rewriteRefOfCheckConstraintDefinition(parent, node, replacer, pre, post) + case ColIdent: + return rewriteColIdent(parent, node, replacer, pre, post) + case *ColName: + return rewriteRefOfColName(parent, node, replacer, pre, post) + case *CollateExpr: + return rewriteRefOfCollateExpr(parent, node, replacer, pre, post) + case *ColumnDefinition: + return rewriteRefOfColumnDefinition(parent, node, replacer, pre, post) + case *ColumnType: + return rewriteRefOfColumnType(parent, node, replacer, pre, post) + case Columns: + return rewriteColumns(parent, node, replacer, pre, post) + case Comments: + return rewriteComments(parent, node, replacer, pre, post) + case *Commit: + return rewriteRefOfCommit(parent, node, replacer, pre, post) + case *ComparisonExpr: + return rewriteRefOfComparisonExpr(parent, node, replacer, pre, post) + case *ConstraintDefinition: + return rewriteRefOfConstraintDefinition(parent, node, replacer, pre, post) + case *ConvertExpr: + return rewriteRefOfConvertExpr(parent, node, replacer, pre, post) + case *ConvertType: + return rewriteRefOfConvertType(parent, node, replacer, pre, post) + case *ConvertUsingExpr: + return rewriteRefOfConvertUsingExpr(parent, node, replacer, pre, post) + case *CreateDatabase: + return rewriteRefOfCreateDatabase(parent, node, replacer, pre, post) + case *CreateTable: + return rewriteRefOfCreateTable(parent, node, replacer, pre, post) + case *CreateView: + return rewriteRefOfCreateView(parent, node, replacer, pre, post) + case *CurTimeFuncExpr: + return rewriteRefOfCurTimeFuncExpr(parent, node, replacer, pre, post) + case *Default: + return rewriteRefOfDefault(parent, node, replacer, pre, post) + case *Delete: + return rewriteRefOfDelete(parent, node, replacer, pre, post) + case *DerivedTable: + return rewriteRefOfDerivedTable(parent, node, replacer, pre, post) + case *DropColumn: + return rewriteRefOfDropColumn(parent, node, replacer, pre, post) + case *DropDatabase: + return rewriteRefOfDropDatabase(parent, node, replacer, pre, post) + case *DropKey: + return rewriteRefOfDropKey(parent, node, replacer, pre, post) + case *DropTable: + return rewriteRefOfDropTable(parent, node, replacer, pre, post) + case *DropView: + return rewriteRefOfDropView(parent, node, replacer, pre, post) + case *ExistsExpr: + return rewriteRefOfExistsExpr(parent, node, replacer, pre, post) + case *ExplainStmt: + return rewriteRefOfExplainStmt(parent, node, replacer, pre, post) + case *ExplainTab: + return rewriteRefOfExplainTab(parent, node, replacer, pre, post) + case Exprs: + return rewriteExprs(parent, node, replacer, pre, post) + case *Flush: + return rewriteRefOfFlush(parent, node, replacer, pre, post) + case *Force: + return rewriteRefOfForce(parent, node, replacer, pre, post) + case *ForeignKeyDefinition: + return rewriteRefOfForeignKeyDefinition(parent, node, replacer, pre, post) + case *FuncExpr: + return rewriteRefOfFuncExpr(parent, node, replacer, pre, post) + case GroupBy: + return rewriteGroupBy(parent, node, replacer, pre, post) + case *GroupConcatExpr: + return rewriteRefOfGroupConcatExpr(parent, node, replacer, pre, post) + case *IndexDefinition: + return rewriteRefOfIndexDefinition(parent, node, replacer, pre, post) + case *IndexHints: + return rewriteRefOfIndexHints(parent, node, replacer, pre, post) + case *IndexInfo: + return rewriteRefOfIndexInfo(parent, node, replacer, pre, post) + case *Insert: + return rewriteRefOfInsert(parent, node, replacer, pre, post) + case *IntervalExpr: + return rewriteRefOfIntervalExpr(parent, node, replacer, pre, post) + case *IsExpr: + return rewriteRefOfIsExpr(parent, node, replacer, pre, post) + case IsolationLevel: + return rewriteIsolationLevel(parent, node, replacer, pre, post) + case JoinCondition: + return rewriteJoinCondition(parent, node, replacer, pre, post) + case *JoinTableExpr: + return rewriteRefOfJoinTableExpr(parent, node, replacer, pre, post) + case *KeyState: + return rewriteRefOfKeyState(parent, node, replacer, pre, post) + case *Limit: + return rewriteRefOfLimit(parent, node, replacer, pre, post) + case ListArg: + return rewriteListArg(parent, node, replacer, pre, post) + case *Literal: + return rewriteRefOfLiteral(parent, node, replacer, pre, post) + case *Load: + return rewriteRefOfLoad(parent, node, replacer, pre, post) + case *LockOption: + return rewriteRefOfLockOption(parent, node, replacer, pre, post) + case *LockTables: + return rewriteRefOfLockTables(parent, node, replacer, pre, post) + case *MatchExpr: + return rewriteRefOfMatchExpr(parent, node, replacer, pre, post) + case *ModifyColumn: + return rewriteRefOfModifyColumn(parent, node, replacer, pre, post) + case *Nextval: + return rewriteRefOfNextval(parent, node, replacer, pre, post) + case *NotExpr: + return rewriteRefOfNotExpr(parent, node, replacer, pre, post) + case *NullVal: + return rewriteRefOfNullVal(parent, node, replacer, pre, post) + case OnDup: + return rewriteOnDup(parent, node, replacer, pre, post) + case *OptLike: + return rewriteRefOfOptLike(parent, node, replacer, pre, post) + case *OrExpr: + return rewriteRefOfOrExpr(parent, node, replacer, pre, post) + case *Order: + return rewriteRefOfOrder(parent, node, replacer, pre, post) + case OrderBy: + return rewriteOrderBy(parent, node, replacer, pre, post) + case *OrderByOption: + return rewriteRefOfOrderByOption(parent, node, replacer, pre, post) + case *OtherAdmin: + return rewriteRefOfOtherAdmin(parent, node, replacer, pre, post) + case *OtherRead: + return rewriteRefOfOtherRead(parent, node, replacer, pre, post) + case *ParenSelect: + return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + case *ParenTableExpr: + return rewriteRefOfParenTableExpr(parent, node, replacer, pre, post) + case *PartitionDefinition: + return rewriteRefOfPartitionDefinition(parent, node, replacer, pre, post) + case *PartitionSpec: + return rewriteRefOfPartitionSpec(parent, node, replacer, pre, post) + case Partitions: + return rewritePartitions(parent, node, replacer, pre, post) + case *RangeCond: + return rewriteRefOfRangeCond(parent, node, replacer, pre, post) + case ReferenceAction: + return rewriteReferenceAction(parent, node, replacer, pre, post) + case *Release: + return rewriteRefOfRelease(parent, node, replacer, pre, post) + case *RenameIndex: + return rewriteRefOfRenameIndex(parent, node, replacer, pre, post) + case *RenameTable: + return rewriteRefOfRenameTable(parent, node, replacer, pre, post) + case *RenameTableName: + return rewriteRefOfRenameTableName(parent, node, replacer, pre, post) + case *RevertMigration: + return rewriteRefOfRevertMigration(parent, node, replacer, pre, post) + case *Rollback: + return rewriteRefOfRollback(parent, node, replacer, pre, post) + case *SRollback: + return rewriteRefOfSRollback(parent, node, replacer, pre, post) + case *Savepoint: + return rewriteRefOfSavepoint(parent, node, replacer, pre, post) + case *Select: + return rewriteRefOfSelect(parent, node, replacer, pre, post) + case SelectExprs: + return rewriteSelectExprs(parent, node, replacer, pre, post) + case *SelectInto: + return rewriteRefOfSelectInto(parent, node, replacer, pre, post) + case *Set: + return rewriteRefOfSet(parent, node, replacer, pre, post) + case *SetExpr: + return rewriteRefOfSetExpr(parent, node, replacer, pre, post) + case SetExprs: + return rewriteSetExprs(parent, node, replacer, pre, post) + case *SetTransaction: + return rewriteRefOfSetTransaction(parent, node, replacer, pre, post) + case *Show: + return rewriteRefOfShow(parent, node, replacer, pre, post) + case *ShowBasic: + return rewriteRefOfShowBasic(parent, node, replacer, pre, post) + case *ShowCreate: + return rewriteRefOfShowCreate(parent, node, replacer, pre, post) + case *ShowFilter: + return rewriteRefOfShowFilter(parent, node, replacer, pre, post) + case *ShowLegacy: + return rewriteRefOfShowLegacy(parent, node, replacer, pre, post) + case *StarExpr: + return rewriteRefOfStarExpr(parent, node, replacer, pre, post) + case *Stream: + return rewriteRefOfStream(parent, node, replacer, pre, post) + case *Subquery: + return rewriteRefOfSubquery(parent, node, replacer, pre, post) + case *SubstrExpr: + return rewriteRefOfSubstrExpr(parent, node, replacer, pre, post) + case TableExprs: + return rewriteTableExprs(parent, node, replacer, pre, post) + case TableIdent: + return rewriteTableIdent(parent, node, replacer, pre, post) + case TableName: + return rewriteTableName(parent, node, replacer, pre, post) + case TableNames: + return rewriteTableNames(parent, node, replacer, pre, post) + case TableOptions: + return rewriteTableOptions(parent, node, replacer, pre, post) + case *TableSpec: + return rewriteRefOfTableSpec(parent, node, replacer, pre, post) + case *TablespaceOperation: + return rewriteRefOfTablespaceOperation(parent, node, replacer, pre, post) + case *TimestampFuncExpr: + return rewriteRefOfTimestampFuncExpr(parent, node, replacer, pre, post) + case *TruncateTable: + return rewriteRefOfTruncateTable(parent, node, replacer, pre, post) + case *UnaryExpr: + return rewriteRefOfUnaryExpr(parent, node, replacer, pre, post) + case *Union: + return rewriteRefOfUnion(parent, node, replacer, pre, post) + case *UnionSelect: + return rewriteRefOfUnionSelect(parent, node, replacer, pre, post) + case *UnlockTables: + return rewriteRefOfUnlockTables(parent, node, replacer, pre, post) + case *Update: + return rewriteRefOfUpdate(parent, node, replacer, pre, post) + case *UpdateExpr: + return rewriteRefOfUpdateExpr(parent, node, replacer, pre, post) + case UpdateExprs: + return rewriteUpdateExprs(parent, node, replacer, pre, post) + case *Use: + return rewriteRefOfUse(parent, node, replacer, pre, post) + case *VStream: + return rewriteRefOfVStream(parent, node, replacer, pre, post) + case ValTuple: + return rewriteValTuple(parent, node, replacer, pre, post) + case *Validation: + return rewriteRefOfValidation(parent, node, replacer, pre, post) + case Values: + return rewriteValues(parent, node, replacer, pre, post) + case *ValuesFuncExpr: + return rewriteRefOfValuesFuncExpr(parent, node, replacer, pre, post) + case VindexParam: + return rewriteVindexParam(parent, node, replacer, pre, post) + case *VindexSpec: + return rewriteRefOfVindexSpec(parent, node, replacer, pre, post) + case *When: + return rewriteRefOfWhen(parent, node, replacer, pre, post) + case *Where: + return rewriteRefOfWhere(parent, node, replacer, pre, post) + case *XorExpr: + return rewriteRefOfXorExpr(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsRefOfAddColumns does deep equals between the two objects. func EqualsRefOfAddColumns(a, b *AddColumns) bool { if a == b { @@ -1563,6 +1872,42 @@ func VisitRefOfAddColumns(in *AddColumns, f Visit) error { return nil } +// rewriteRefOfAddColumns is part of the Rewrite implementation +func rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node.Columns { + if errF := rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*AddColumns).Columns[i] = newNode.(*ColumnDefinition) + }, pre, post); errF != nil { + return errF + } + } + if errF := rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*AddColumns).First = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*AddColumns).After = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAddConstraintDefinition does deep equals between the two objects. func EqualsRefOfAddConstraintDefinition(a, b *AddConstraintDefinition) bool { if a == b { @@ -1598,6 +1943,30 @@ func VisitRefOfAddConstraintDefinition(in *AddConstraintDefinition, f Visit) err return nil } +// rewriteRefOfAddConstraintDefinition is part of the Rewrite implementation +func rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { + parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAddIndexDefinition does deep equals between the two objects. func EqualsRefOfAddIndexDefinition(a, b *AddIndexDefinition) bool { if a == b { @@ -1633,6 +2002,30 @@ func VisitRefOfAddIndexDefinition(in *AddIndexDefinition, f Visit) error { return nil } +// rewriteRefOfAddIndexDefinition is part of the Rewrite implementation +func rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { + parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAliasedExpr does deep equals between the two objects. func EqualsRefOfAliasedExpr(a, b *AliasedExpr) bool { if a == b { @@ -1673,6 +2066,35 @@ func VisitRefOfAliasedExpr(in *AliasedExpr, f Visit) error { return nil } +// rewriteRefOfAliasedExpr is part of the Rewrite implementation +func rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColIdent(node, node.As, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).As = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAliasedTableExpr does deep equals between the two objects. func EqualsRefOfAliasedTableExpr(a, b *AliasedTableExpr) bool { if a == b { @@ -1723,6 +2145,45 @@ func VisitRefOfAliasedTableExpr(in *AliasedTableExpr, f Visit) error { return nil } +// rewriteRefOfAliasedTableExpr is part of the Rewrite implementation +func rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) + }, pre, post); errF != nil { + return errF + } + if errF := rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableIdent(node, node.As, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).As = newNode.(TableIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfIndexHints(node, node.Hints, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAlterCharset does deep equals between the two objects. func EqualsRefOfAlterCharset(a, b *AlterCharset) bool { if a == b { @@ -1755,6 +2216,25 @@ func VisitRefOfAlterCharset(in *AlterCharset, f Visit) error { return nil } +// rewriteRefOfAlterCharset is part of the Rewrite implementation +func rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAlterColumn does deep equals between the two objects. func EqualsRefOfAlterColumn(a, b *AlterColumn) bool { if a == b { @@ -1796,6 +2276,35 @@ func VisitRefOfAlterColumn(in *AlterColumn, f Visit) error { return nil } +// rewriteRefOfAlterColumn is part of the Rewrite implementation +func rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { + parent.(*AlterColumn).Column = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.DefaultVal, func(newNode, parent SQLNode) { + parent.(*AlterColumn).DefaultVal = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAlterDatabase does deep equals between the two objects. func EqualsRefOfAlterDatabase(a, b *AlterDatabase) bool { if a == b { @@ -1831,6 +2340,25 @@ func VisitRefOfAlterDatabase(in *AlterDatabase, f Visit) error { return nil } +// rewriteRefOfAlterDatabase is part of the Rewrite implementation +func rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAlterMigration does deep equals between the two objects. func EqualsRefOfAlterMigration(a, b *AlterMigration) bool { if a == b { @@ -1863,6 +2391,25 @@ func VisitRefOfAlterMigration(in *AlterMigration, f Visit) error { return nil } +// rewriteRefOfAlterMigration is part of the Rewrite implementation +func rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAlterTable does deep equals between the two objects. func EqualsRefOfAlterTable(a, b *AlterTable) bool { if a == b { @@ -1911,12 +2458,48 @@ func VisitRefOfAlterTable(in *AlterTable, f Visit) error { return nil } -// EqualsRefOfAlterView does deep equals between the two objects. -func EqualsRefOfAlterView(a, b *AlterView) bool { - if a == b { - return true +// rewriteRefOfAlterTable is part of the Rewrite implementation +func rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*AlterTable).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.AlterOptions { + if errF := rewriteAlterOption(node, el, func(newNode, parent SQLNode) { + parent.(*AlterTable).AlterOptions[i] = newNode.(AlterOption) + }, pre, post); errF != nil { + return errF + } + } + if errF := rewriteRefOfPartitionSpec(node, node.PartitionSpec, func(newNode, parent SQLNode) { + parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + +// EqualsRefOfAlterView does deep equals between the two objects. +func EqualsRefOfAlterView(a, b *AlterView) bool { + if a == b { + return true + } + if a == nil || b == nil { return false } return a.Algorithm == b.Algorithm && @@ -1960,6 +2543,40 @@ func VisitRefOfAlterView(in *AlterView, f Visit) error { return nil } +// rewriteRefOfAlterView is part of the Rewrite implementation +func rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + parent.(*AlterView).ViewName = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*AlterView).Columns = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*AlterView).Select = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAlterVschema does deep equals between the two objects. func EqualsRefOfAlterVschema(a, b *AlterVschema) bool { if a == b { @@ -2013,6 +2630,47 @@ func VisitRefOfAlterVschema(in *AlterVschema, f Visit) error { return nil } +// rewriteRefOfAlterVschema is part of the Rewrite implementation +func rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*AlterVschema).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfVindexSpec(node, node.VindexSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.VindexCols { + if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexCols[i] = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + } + if errF := rewriteRefOfAutoIncSpec(node, node.AutoIncSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAndExpr does deep equals between the two objects. func EqualsRefOfAndExpr(a, b *AndExpr) bool { if a == b { @@ -2053,6 +2711,35 @@ func VisitRefOfAndExpr(in *AndExpr, f Visit) error { return nil } +// rewriteRefOfAndExpr is part of the Rewrite implementation +func rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*AndExpr).Left = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*AndExpr).Right = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfAutoIncSpec does deep equals between the two objects. func EqualsRefOfAutoIncSpec(a, b *AutoIncSpec) bool { if a == b { @@ -2093,6 +2780,35 @@ func VisitRefOfAutoIncSpec(in *AutoIncSpec, f Visit) error { return nil } +// rewriteRefOfAutoIncSpec is part of the Rewrite implementation +func rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Column = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.Sequence, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Sequence = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfBegin does deep equals between the two objects. func EqualsRefOfBegin(a, b *Begin) bool { if a == b { @@ -2124,6 +2840,25 @@ func VisitRefOfBegin(in *Begin, f Visit) error { return nil } +// rewriteRefOfBegin is part of the Rewrite implementation +func rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfBinaryExpr does deep equals between the two objects. func EqualsRefOfBinaryExpr(a, b *BinaryExpr) bool { if a == b { @@ -2165,6 +2900,35 @@ func VisitRefOfBinaryExpr(in *BinaryExpr, f Visit) error { return nil } +// rewriteRefOfBinaryExpr is part of the Rewrite implementation +func rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Left = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Right = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCallProc does deep equals between the two objects. func EqualsRefOfCallProc(a, b *CallProc) bool { if a == b { @@ -2205,6 +2969,35 @@ func VisitRefOfCallProc(in *CallProc, f Visit) error { return nil } +// rewriteRefOfCallProc is part of the Rewrite implementation +func rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*CallProc).Name = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExprs(node, node.Params, func(newNode, parent SQLNode) { + parent.(*CallProc).Params = newNode.(Exprs) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCaseExpr does deep equals between the two objects. func EqualsRefOfCaseExpr(a, b *CaseExpr) bool { if a == b { @@ -2252,6 +3045,42 @@ func VisitRefOfCaseExpr(in *CaseExpr, f Visit) error { return nil } +// rewriteRefOfCaseExpr is part of the Rewrite implementation +func rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.Whens { + if errF := rewriteRefOfWhen(node, el, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Whens[i] = newNode.(*When) + }, pre, post); errF != nil { + return errF + } + } + if errF := rewriteExpr(node, node.Else, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Else = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfChangeColumn does deep equals between the two objects. func EqualsRefOfChangeColumn(a, b *ChangeColumn) bool { if a == b { @@ -2302,6 +3131,45 @@ func VisitRefOfChangeColumn(in *ChangeColumn, f Visit) error { return nil } +// rewriteRefOfChangeColumn is part of the Rewrite implementation +func rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).OldColumn = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).First = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).After = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCheckConstraintDefinition does deep equals between the two objects. func EqualsRefOfCheckConstraintDefinition(a, b *CheckConstraintDefinition) bool { if a == b { @@ -2338,6 +3206,30 @@ func VisitRefOfCheckConstraintDefinition(in *CheckConstraintDefinition, f Visit) return nil } +// rewriteRefOfCheckConstraintDefinition is part of the Rewrite implementation +func rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraintDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsColIdent does deep equals between the two objects. func EqualsColIdent(a, b ColIdent) bool { return a.val == b.val && @@ -2358,6 +3250,26 @@ func VisitColIdent(in ColIdent, f Visit) error { return nil } +// rewriteColIdent is part of the Rewrite implementation +func rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc, pre, post ApplyFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if err != nil { + return err + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfColName does deep equals between the two objects. func EqualsRefOfColName(a, b *ColName) bool { if a == b { @@ -2392,6 +3304,35 @@ func VisitRefOfColName(in *ColName, f Visit) error { return nil } +// rewriteRefOfColName is part of the Rewrite implementation +func rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ColName).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*ColName).Qualifier = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCollateExpr does deep equals between the two objects. func EqualsRefOfCollateExpr(a, b *CollateExpr) bool { if a == b { @@ -2428,6 +3369,30 @@ func VisitRefOfCollateExpr(in *CollateExpr, f Visit) error { return nil } +// rewriteRefOfCollateExpr is part of the Rewrite implementation +func rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CollateExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfColumnDefinition does deep equals between the two objects. func EqualsRefOfColumnDefinition(a, b *ColumnDefinition) bool { if a == b { @@ -2465,6 +3430,30 @@ func VisitRefOfColumnDefinition(in *ColumnDefinition, f Visit) error { return nil } +// rewriteRefOfColumnDefinition is part of the Rewrite implementation +func rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ColumnDefinition).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfColumnType does deep equals between the two objects. func EqualsRefOfColumnType(a, b *ColumnType) bool { if a == b { @@ -2514,6 +3503,35 @@ func VisitRefOfColumnType(in *ColumnType, f Visit) error { return nil } +// rewriteRefOfColumnType is part of the Rewrite implementation +func rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + parent.(*ColumnType).Length = newNode.(*Literal) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + parent.(*ColumnType).Scale = newNode.(*Literal) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsColumns does deep equals between the two objects. func EqualsColumns(a, b Columns) bool { if len(a) != len(b) { @@ -2552,6 +3570,32 @@ func VisitColumns(in Columns, f Visit) error { return nil } +// rewriteColumns is part of the Rewrite implementation +func rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(Columns)[i] = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsComments does deep equals between the two objects. func EqualsComments(a, b Comments) bool { if len(a) != len(b) { @@ -2578,6 +3622,25 @@ func VisitComments(in Comments, f Visit) error { return err } +// rewriteComments is part of the Rewrite implementation +func rewriteComments(parent SQLNode, node Comments, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCommit does deep equals between the two objects. func EqualsRefOfCommit(a, b *Commit) bool { if a == b { @@ -2609,6 +3672,25 @@ func VisitRefOfCommit(in *Commit, f Visit) error { return nil } +// rewriteRefOfCommit is part of the Rewrite implementation +func rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfComparisonExpr does deep equals between the two objects. func EqualsRefOfComparisonExpr(a, b *ComparisonExpr) bool { if a == b { @@ -2655,6 +3737,40 @@ func VisitRefOfComparisonExpr(in *ComparisonExpr, f Visit) error { return nil } +// rewriteRefOfComparisonExpr is part of the Rewrite implementation +func rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Left = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Right = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Escape, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Escape = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfConstraintDefinition does deep equals between the two objects. func EqualsRefOfConstraintDefinition(a, b *ConstraintDefinition) bool { if a == b { @@ -2691,6 +3807,30 @@ func VisitRefOfConstraintDefinition(in *ConstraintDefinition, f Visit) error { return nil } +// rewriteRefOfConstraintDefinition is part of the Rewrite implementation +func rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { + parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfConvertExpr does deep equals between the two objects. func EqualsRefOfConvertExpr(a, b *ConvertExpr) bool { if a == b { @@ -2731,6 +3871,35 @@ func VisitRefOfConvertExpr(in *ConvertExpr, f Visit) error { return nil } +// rewriteRefOfConvertExpr is part of the Rewrite implementation +func rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfConvertType(node, node.Type, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Type = newNode.(*ConvertType) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfConvertType does deep equals between the two objects. func EqualsRefOfConvertType(a, b *ConvertType) bool { if a == b { @@ -2774,6 +3943,35 @@ func VisitRefOfConvertType(in *ConvertType, f Visit) error { return nil } +// rewriteRefOfConvertType is part of the Rewrite implementation +func rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + parent.(*ConvertType).Length = newNode.(*Literal) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + parent.(*ConvertType).Scale = newNode.(*Literal) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfConvertUsingExpr does deep equals between the two objects. func EqualsRefOfConvertUsingExpr(a, b *ConvertUsingExpr) bool { if a == b { @@ -2810,6 +4008,30 @@ func VisitRefOfConvertUsingExpr(in *ConvertUsingExpr, f Visit) error { return nil } +// rewriteRefOfConvertUsingExpr is part of the Rewrite implementation +func rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertUsingExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCreateDatabase does deep equals between the two objects. func EqualsRefOfCreateDatabase(a, b *CreateDatabase) bool { if a == b { @@ -2850,6 +4072,30 @@ func VisitRefOfCreateDatabase(in *CreateDatabase, f Visit) error { return nil } +// rewriteRefOfCreateDatabase is part of the Rewrite implementation +func rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*CreateDatabase).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCreateTable does deep equals between the two objects. func EqualsRefOfCreateTable(a, b *CreateTable) bool { if a == b { @@ -2898,6 +4144,40 @@ func VisitRefOfCreateTable(in *CreateTable, f Visit) error { return nil } +// rewriteRefOfCreateTable is part of the Rewrite implementation +func rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*CreateTable).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfTableSpec(node, node.TableSpec, func(newNode, parent SQLNode) { + parent.(*CreateTable).TableSpec = newNode.(*TableSpec) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfOptLike(node, node.OptLike, func(newNode, parent SQLNode) { + parent.(*CreateTable).OptLike = newNode.(*OptLike) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCreateView does deep equals between the two objects. func EqualsRefOfCreateView(a, b *CreateView) bool { if a == b { @@ -2948,6 +4228,40 @@ func VisitRefOfCreateView(in *CreateView, f Visit) error { return nil } +// rewriteRefOfCreateView is part of the Rewrite implementation +func rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + parent.(*CreateView).ViewName = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*CreateView).Columns = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*CreateView).Select = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfCurTimeFuncExpr does deep equals between the two objects. func EqualsRefOfCurTimeFuncExpr(a, b *CurTimeFuncExpr) bool { if a == b { @@ -2988,6 +4302,35 @@ func VisitRefOfCurTimeFuncExpr(in *CurTimeFuncExpr, f Visit) error { return nil } +// rewriteRefOfCurTimeFuncExpr is part of the Rewrite implementation +func rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Fsp, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDefault does deep equals between the two objects. func EqualsRefOfDefault(a, b *Default) bool { if a == b { @@ -3019,6 +4362,25 @@ func VisitRefOfDefault(in *Default, f Visit) error { return nil } +// rewriteRefOfDefault is part of the Rewrite implementation +func rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDelete does deep equals between the two objects. func EqualsRefOfDelete(a, b *Delete) bool { if a == b { @@ -3085,6 +4447,60 @@ func VisitRefOfDelete(in *Delete, f Visit) error { return nil } +// rewriteRefOfDelete is part of the Rewrite implementation +func rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Delete).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { + parent.(*Delete).Targets = newNode.(TableNames) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Delete).TableExprs = newNode.(TableExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*Delete).Partitions = newNode.(Partitions) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Delete).Where = newNode.(*Where) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Delete).OrderBy = newNode.(OrderBy) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Delete).Limit = newNode.(*Limit) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDerivedTable does deep equals between the two objects. func EqualsRefOfDerivedTable(a, b *DerivedTable) bool { if a == b { @@ -3120,6 +4536,30 @@ func VisitRefOfDerivedTable(in *DerivedTable, f Visit) error { return nil } +// rewriteRefOfDerivedTable is part of the Rewrite implementation +func rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*DerivedTable).Select = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDropColumn does deep equals between the two objects. func EqualsRefOfDropColumn(a, b *DropColumn) bool { if a == b { @@ -3155,6 +4595,30 @@ func VisitRefOfDropColumn(in *DropColumn, f Visit) error { return nil } +// rewriteRefOfDropColumn is part of the Rewrite implementation +func rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*DropColumn).Name = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDropDatabase does deep equals between the two objects. func EqualsRefOfDropDatabase(a, b *DropDatabase) bool { if a == b { @@ -3192,6 +4656,30 @@ func VisitRefOfDropDatabase(in *DropDatabase, f Visit) error { return nil } +// rewriteRefOfDropDatabase is part of the Rewrite implementation +func rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*DropDatabase).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDropKey does deep equals between the two objects. func EqualsRefOfDropKey(a, b *DropKey) bool { if a == b { @@ -3224,6 +4712,25 @@ func VisitRefOfDropKey(in *DropKey, f Visit) error { return nil } +// rewriteRefOfDropKey is part of the Rewrite implementation +func rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDropTable does deep equals between the two objects. func EqualsRefOfDropTable(a, b *DropTable) bool { if a == b { @@ -3261,6 +4768,30 @@ func VisitRefOfDropTable(in *DropTable, f Visit) error { return nil } +// rewriteRefOfDropTable is part of the Rewrite implementation +func rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + parent.(*DropTable).FromTables = newNode.(TableNames) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfDropView does deep equals between the two objects. func EqualsRefOfDropView(a, b *DropView) bool { if a == b { @@ -3297,6 +4828,30 @@ func VisitRefOfDropView(in *DropView, f Visit) error { return nil } +// rewriteRefOfDropView is part of the Rewrite implementation +func rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + parent.(*DropView).FromTables = newNode.(TableNames) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfExistsExpr does deep equals between the two objects. func EqualsRefOfExistsExpr(a, b *ExistsExpr) bool { if a == b { @@ -3332,6 +4887,30 @@ func VisitRefOfExistsExpr(in *ExistsExpr, f Visit) error { return nil } +// rewriteRefOfExistsExpr is part of the Rewrite implementation +func rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { + parent.(*ExistsExpr).Subquery = newNode.(*Subquery) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfExplainStmt does deep equals between the two objects. func EqualsRefOfExplainStmt(a, b *ExplainStmt) bool { if a == b { @@ -3368,6 +4947,30 @@ func VisitRefOfExplainStmt(in *ExplainStmt, f Visit) error { return nil } +// rewriteRefOfExplainStmt is part of the Rewrite implementation +func rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { + parent.(*ExplainStmt).Statement = newNode.(Statement) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfExplainTab does deep equals between the two objects. func EqualsRefOfExplainTab(a, b *ExplainTab) bool { if a == b { @@ -3404,6 +5007,30 @@ func VisitRefOfExplainTab(in *ExplainTab, f Visit) error { return nil } +// rewriteRefOfExplainTab is part of the Rewrite implementation +func rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*ExplainTab).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsExprs does deep equals between the two objects. func EqualsExprs(a, b Exprs) bool { if len(a) != len(b) { @@ -3442,6 +5069,32 @@ func VisitExprs(in Exprs, f Visit) error { return nil } +// rewriteExprs is part of the Rewrite implementation +func rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(Exprs)[i] = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfFlush does deep equals between the two objects. func EqualsRefOfFlush(a, b *Flush) bool { if a == b { @@ -3482,6 +5135,30 @@ func VisitRefOfFlush(in *Flush, f Visit) error { return nil } +// rewriteRefOfFlush is part of the Rewrite implementation +func rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { + parent.(*Flush).TableNames = newNode.(TableNames) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfForce does deep equals between the two objects. func EqualsRefOfForce(a, b *Force) bool { if a == b { @@ -3513,6 +5190,25 @@ func VisitRefOfForce(in *Force, f Visit) error { return nil } +// rewriteRefOfForce is part of the Rewrite implementation +func rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfForeignKeyDefinition does deep equals between the two objects. func EqualsRefOfForeignKeyDefinition(a, b *ForeignKeyDefinition) bool { if a == b { @@ -3566,6 +5262,50 @@ func VisitRefOfForeignKeyDefinition(in *ForeignKeyDefinition, f Visit) error { return nil } +// rewriteRefOfForeignKeyDefinition is part of the Rewrite implementation +func rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).Source = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.ReferencedTable, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColumns(node, node.ReferencedColumns, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteReferenceAction(node, node.OnDelete, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteReferenceAction(node, node.OnUpdate, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfFuncExpr does deep equals between the two objects. func EqualsRefOfFuncExpr(a, b *FuncExpr) bool { if a == b { @@ -3612,6 +5352,40 @@ func VisitRefOfFuncExpr(in *FuncExpr, f Visit) error { return nil } +// rewriteRefOfFuncExpr is part of the Rewrite implementation +func rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Qualifier = newNode.(TableIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Exprs = newNode.(SelectExprs) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsGroupBy does deep equals between the two objects. func EqualsGroupBy(a, b GroupBy) bool { if len(a) != len(b) { @@ -3650,6 +5424,32 @@ func VisitGroupBy(in GroupBy, f Visit) error { return nil } +// rewriteGroupBy is part of the Rewrite implementation +func rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(GroupBy)[i] = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfGroupConcatExpr does deep equals between the two objects. func EqualsRefOfGroupConcatExpr(a, b *GroupConcatExpr) bool { if a == b { @@ -3697,6 +5497,40 @@ func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { return nil } +// rewriteRefOfGroupConcatExpr is part of the Rewrite implementation +func rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Limit = newNode.(*Limit) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfIndexDefinition does deep equals between the two objects. func EqualsRefOfIndexDefinition(a, b *IndexDefinition) bool { if a == b { @@ -3736,15 +5570,39 @@ func VisitRefOfIndexDefinition(in *IndexDefinition, f Visit) error { return nil } -// EqualsRefOfIndexHints does deep equals between the two objects. -func EqualsRefOfIndexHints(a, b *IndexHints) bool { - if a == b { - return true +// rewriteRefOfIndexDefinition is part of the Rewrite implementation +func rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return a.Type == b.Type && + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { + parent.(*IndexDefinition).Info = newNode.(*IndexInfo) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + +// EqualsRefOfIndexHints does deep equals between the two objects. +func EqualsRefOfIndexHints(a, b *IndexHints) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && EqualsSliceOfColIdent(a.Indexes, b.Indexes) } @@ -3774,6 +5632,32 @@ func VisitRefOfIndexHints(in *IndexHints, f Visit) error { return nil } +// rewriteRefOfIndexHints is part of the Rewrite implementation +func rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node.Indexes { + if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(*IndexHints).Indexes[i] = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfIndexInfo does deep equals between the two objects. func EqualsRefOfIndexInfo(a, b *IndexInfo) bool { if a == b { @@ -3819,6 +5703,35 @@ func VisitRefOfIndexInfo(in *IndexInfo, f Visit) error { return nil } +// rewriteRefOfIndexInfo is part of the Rewrite implementation +func rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*IndexInfo).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColIdent(node, node.ConstraintName, func(newNode, parent SQLNode) { + parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfInsert does deep equals between the two objects. func EqualsRefOfInsert(a, b *Insert) bool { if a == b { @@ -3881,6 +5794,55 @@ func VisitRefOfInsert(in *Insert, f Visit) error { return nil } +// rewriteRefOfInsert is part of the Rewrite implementation +func rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Insert).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*Insert).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*Insert).Partitions = newNode.(Partitions) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*Insert).Columns = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteInsertRows(node, node.Rows, func(newNode, parent SQLNode) { + parent.(*Insert).Rows = newNode.(InsertRows) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfIntervalExpr does deep equals between the two objects. func EqualsRefOfIntervalExpr(a, b *IntervalExpr) bool { if a == b { @@ -3917,6 +5879,30 @@ func VisitRefOfIntervalExpr(in *IntervalExpr, f Visit) error { return nil } +// rewriteRefOfIntervalExpr is part of the Rewrite implementation +func rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*IntervalExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfIsExpr does deep equals between the two objects. func EqualsRefOfIsExpr(a, b *IsExpr) bool { if a == b { @@ -3953,6 +5939,30 @@ func VisitRefOfIsExpr(in *IsExpr, f Visit) error { return nil } +// rewriteRefOfIsExpr is part of the Rewrite implementation +func rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*IsExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsJoinCondition does deep equals between the two objects. func EqualsJoinCondition(a, b JoinCondition) bool { return EqualsExpr(a.On, b.On) && @@ -3978,6 +5988,36 @@ func VisitJoinCondition(in JoinCondition, f Visit) error { return nil } +// rewriteJoinCondition is part of the Rewrite implementation +func rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc, pre, post ApplyFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'On' on 'JoinCondition'") + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Using' on 'JoinCondition'") + }, pre, post); errF != nil { + return errF + } + if err != nil { + return err + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfJoinTableExpr does deep equals between the two objects. func EqualsRefOfJoinTableExpr(a, b *JoinTableExpr) bool { if a == b { @@ -4024,6 +6064,40 @@ func VisitRefOfJoinTableExpr(in *JoinTableExpr, f Visit) error { return nil } +// rewriteRefOfJoinTableExpr is part of the Rewrite implementation +func rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableExpr(node, node.RightExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteJoinCondition(node, node.Condition, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfKeyState does deep equals between the two objects. func EqualsRefOfKeyState(a, b *KeyState) bool { if a == b { @@ -4055,6 +6129,25 @@ func VisitRefOfKeyState(in *KeyState, f Visit) error { return nil } +// rewriteRefOfKeyState is part of the Rewrite implementation +func rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfLimit does deep equals between the two objects. func EqualsRefOfLimit(a, b *Limit) bool { if a == b { @@ -4095,6 +6188,35 @@ func VisitRefOfLimit(in *Limit, f Visit) error { return nil } +// rewriteRefOfLimit is part of the Rewrite implementation +func rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { + parent.(*Limit).Offset = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Rowcount, func(newNode, parent SQLNode) { + parent.(*Limit).Rowcount = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsListArg does deep equals between the two objects. func EqualsListArg(a, b ListArg) bool { if len(a) != len(b) { @@ -4121,6 +6243,25 @@ func VisitListArg(in ListArg, f Visit) error { return err } +// rewriteListArg is part of the Rewrite implementation +func rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfLiteral does deep equals between the two objects. func EqualsRefOfLiteral(a, b *Literal) bool { if a == b { @@ -4153,6 +6294,25 @@ func VisitRefOfLiteral(in *Literal, f Visit) error { return nil } +// rewriteRefOfLiteral is part of the Rewrite implementation +func rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfLoad does deep equals between the two objects. func EqualsRefOfLoad(a, b *Load) bool { if a == b { @@ -4184,6 +6344,25 @@ func VisitRefOfLoad(in *Load, f Visit) error { return nil } +// rewriteRefOfLoad is part of the Rewrite implementation +func rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfLockOption does deep equals between the two objects. func EqualsRefOfLockOption(a, b *LockOption) bool { if a == b { @@ -4215,6 +6394,25 @@ func VisitRefOfLockOption(in *LockOption, f Visit) error { return nil } +// rewriteRefOfLockOption is part of the Rewrite implementation +func rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfLockTables does deep equals between the two objects. func EqualsRefOfLockTables(a, b *LockTables) bool { if a == b { @@ -4247,6 +6445,25 @@ func VisitRefOfLockTables(in *LockTables, f Visit) error { return nil } +// rewriteRefOfLockTables is part of the Rewrite implementation +func rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfMatchExpr does deep equals between the two objects. func EqualsRefOfMatchExpr(a, b *MatchExpr) bool { if a == b { @@ -4288,6 +6505,35 @@ func VisitRefOfMatchExpr(in *MatchExpr, f Visit) error { return nil } +// rewriteRefOfMatchExpr is part of the Rewrite implementation +func rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Columns = newNode.(SelectExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfModifyColumn does deep equals between the two objects. func EqualsRefOfModifyColumn(a, b *ModifyColumn) bool { if a == b { @@ -4333,6 +6579,40 @@ func VisitRefOfModifyColumn(in *ModifyColumn, f Visit) error { return nil } +// rewriteRefOfModifyColumn is part of the Rewrite implementation +func rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).First = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).After = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfNextval does deep equals between the two objects. func EqualsRefOfNextval(a, b *Nextval) bool { if a == b { @@ -4368,6 +6648,30 @@ func VisitRefOfNextval(in *Nextval, f Visit) error { return nil } +// rewriteRefOfNextval is part of the Rewrite implementation +func rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Nextval).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfNotExpr does deep equals between the two objects. func EqualsRefOfNotExpr(a, b *NotExpr) bool { if a == b { @@ -4403,6 +6707,30 @@ func VisitRefOfNotExpr(in *NotExpr, f Visit) error { return nil } +// rewriteRefOfNotExpr is part of the Rewrite implementation +func rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*NotExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfNullVal does deep equals between the two objects. func EqualsRefOfNullVal(a, b *NullVal) bool { if a == b { @@ -4434,6 +6762,25 @@ func VisitRefOfNullVal(in *NullVal, f Visit) error { return nil } +// rewriteRefOfNullVal is part of the Rewrite implementation +func rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsOnDup does deep equals between the two objects. func EqualsOnDup(a, b OnDup) bool { if len(a) != len(b) { @@ -4472,6 +6819,32 @@ func VisitOnDup(in OnDup, f Visit) error { return nil } +// rewriteOnDup is part of the Rewrite implementation +func rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + parent.(OnDup)[i] = newNode.(*UpdateExpr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfOptLike does deep equals between the two objects. func EqualsRefOfOptLike(a, b *OptLike) bool { if a == b { @@ -4507,6 +6880,30 @@ func VisitRefOfOptLike(in *OptLike, f Visit) error { return nil } +// rewriteRefOfOptLike is part of the Rewrite implementation +func rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { + parent.(*OptLike).LikeTable = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfOrExpr does deep equals between the two objects. func EqualsRefOfOrExpr(a, b *OrExpr) bool { if a == b { @@ -4547,6 +6944,35 @@ func VisitRefOfOrExpr(in *OrExpr, f Visit) error { return nil } +// rewriteRefOfOrExpr is part of the Rewrite implementation +func rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*OrExpr).Left = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*OrExpr).Right = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfOrder does deep equals between the two objects. func EqualsRefOfOrder(a, b *Order) bool { if a == b { @@ -4583,6 +7009,30 @@ func VisitRefOfOrder(in *Order, f Visit) error { return nil } +// rewriteRefOfOrder is part of the Rewrite implementation +func rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Order).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsOrderBy does deep equals between the two objects. func EqualsOrderBy(a, b OrderBy) bool { if len(a) != len(b) { @@ -4621,6 +7071,32 @@ func VisitOrderBy(in OrderBy, f Visit) error { return nil } +// rewriteOrderBy is part of the Rewrite implementation +func rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { + parent.(OrderBy)[i] = newNode.(*Order) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfOrderByOption does deep equals between the two objects. func EqualsRefOfOrderByOption(a, b *OrderByOption) bool { if a == b { @@ -4656,6 +7132,30 @@ func VisitRefOfOrderByOption(in *OrderByOption, f Visit) error { return nil } +// rewriteRefOfOrderByOption is part of the Rewrite implementation +func rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { + parent.(*OrderByOption).Cols = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfOtherAdmin does deep equals between the two objects. func EqualsRefOfOtherAdmin(a, b *OtherAdmin) bool { if a == b { @@ -4687,6 +7187,25 @@ func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { return nil } +// rewriteRefOfOtherAdmin is part of the Rewrite implementation +func rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfOtherRead does deep equals between the two objects. func EqualsRefOfOtherRead(a, b *OtherRead) bool { if a == b { @@ -4718,6 +7237,25 @@ func VisitRefOfOtherRead(in *OtherRead, f Visit) error { return nil } +// rewriteRefOfOtherRead is part of the Rewrite implementation +func rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfParenSelect does deep equals between the two objects. func EqualsRefOfParenSelect(a, b *ParenSelect) bool { if a == b { @@ -4753,6 +7291,30 @@ func VisitRefOfParenSelect(in *ParenSelect, f Visit) error { return nil } +// rewriteRefOfParenSelect is part of the Rewrite implementation +func rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*ParenSelect).Select = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfParenTableExpr does deep equals between the two objects. func EqualsRefOfParenTableExpr(a, b *ParenTableExpr) bool { if a == b { @@ -4788,6 +7350,30 @@ func VisitRefOfParenTableExpr(in *ParenTableExpr, f Visit) error { return nil } +// rewriteRefOfParenTableExpr is part of the Rewrite implementation +func rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfPartitionDefinition does deep equals between the two objects. func EqualsRefOfPartitionDefinition(a, b *PartitionDefinition) bool { if a == b { @@ -4829,6 +7415,35 @@ func VisitRefOfPartitionDefinition(in *PartitionDefinition, f Visit) error { return nil } +// rewriteRefOfPartitionDefinition is part of the Rewrite implementation +func rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Limit = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfPartitionSpec does deep equals between the two objects. func EqualsRefOfPartitionSpec(a, b *PartitionSpec) bool { if a == b { @@ -4884,6 +7499,47 @@ func VisitRefOfPartitionSpec(in *PartitionSpec, f Visit) error { return nil } +// rewriteRefOfPartitionSpec is part of the Rewrite implementation +func rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Names = newNode.(Partitions) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLiteral(node, node.Number, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Number = newNode.(*Literal) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).TableName = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.Definitions { + if errF := rewriteRefOfPartitionDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Definitions[i] = newNode.(*PartitionDefinition) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsPartitions does deep equals between the two objects. func EqualsPartitions(a, b Partitions) bool { if len(a) != len(b) { @@ -4922,6 +7578,32 @@ func VisitPartitions(in Partitions, f Visit) error { return nil } +// rewritePartitions is part of the Rewrite implementation +func rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(Partitions)[i] = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRangeCond does deep equals between the two objects. func EqualsRefOfRangeCond(a, b *RangeCond) bool { if a == b { @@ -4968,6 +7650,40 @@ func VisitRefOfRangeCond(in *RangeCond, f Visit) error { return nil } +// rewriteRefOfRangeCond is part of the Rewrite implementation +func rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*RangeCond).Left = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + parent.(*RangeCond).From = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + parent.(*RangeCond).To = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRelease does deep equals between the two objects. func EqualsRefOfRelease(a, b *Release) bool { if a == b { @@ -5003,6 +7719,30 @@ func VisitRefOfRelease(in *Release, f Visit) error { return nil } +// rewriteRefOfRelease is part of the Rewrite implementation +func rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*Release).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRenameIndex does deep equals between the two objects. func EqualsRefOfRenameIndex(a, b *RenameIndex) bool { if a == b { @@ -5035,6 +7775,25 @@ func VisitRefOfRenameIndex(in *RenameIndex, f Visit) error { return nil } +// rewriteRefOfRenameIndex is part of the Rewrite implementation +func rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRenameTable does deep equals between the two objects. func EqualsRefOfRenameTable(a, b *RenameTable) bool { if a == b { @@ -5067,6 +7826,25 @@ func VisitRefOfRenameTable(in *RenameTable, f Visit) error { return nil } +// rewriteRefOfRenameTable is part of the Rewrite implementation +func rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRenameTableName does deep equals between the two objects. func EqualsRefOfRenameTableName(a, b *RenameTableName) bool { if a == b { @@ -5102,6 +7880,30 @@ func VisitRefOfRenameTableName(in *RenameTableName, f Visit) error { return nil } +// rewriteRefOfRenameTableName is part of the Rewrite implementation +func rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*RenameTableName).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRevertMigration does deep equals between the two objects. func EqualsRefOfRevertMigration(a, b *RevertMigration) bool { if a == b { @@ -5133,6 +7935,25 @@ func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { return nil } +// rewriteRefOfRevertMigration is part of the Rewrite implementation +func rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfRollback does deep equals between the two objects. func EqualsRefOfRollback(a, b *Rollback) bool { if a == b { @@ -5164,6 +7985,25 @@ func VisitRefOfRollback(in *Rollback, f Visit) error { return nil } +// rewriteRefOfRollback is part of the Rewrite implementation +func rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSRollback does deep equals between the two objects. func EqualsRefOfSRollback(a, b *SRollback) bool { if a == b { @@ -5199,6 +8039,30 @@ func VisitRefOfSRollback(in *SRollback, f Visit) error { return nil } +// rewriteRefOfSRollback is part of the Rewrite implementation +func rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SRollback).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSavepoint does deep equals between the two objects. func EqualsRefOfSavepoint(a, b *Savepoint) bool { if a == b { @@ -5234,6 +8098,30 @@ func VisitRefOfSavepoint(in *Savepoint, f Visit) error { return nil } +// rewriteRefOfSavepoint is part of the Rewrite implementation +func rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*Savepoint).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSelect does deep equals between the two objects. func EqualsRefOfSelect(a, b *Select) bool { if a == b { @@ -5315,6 +8203,70 @@ func VisitRefOfSelect(in *Select, f Visit) error { return nil } +// rewriteRefOfSelect is part of the Rewrite implementation +func rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Select).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSelectExprs(node, node.SelectExprs, func(newNode, parent SQLNode) { + parent.(*Select).SelectExprs = newNode.(SelectExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableExprs(node, node.From, func(newNode, parent SQLNode) { + parent.(*Select).From = newNode.(TableExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Select).Where = newNode.(*Where) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteGroupBy(node, node.GroupBy, func(newNode, parent SQLNode) { + parent.(*Select).GroupBy = newNode.(GroupBy) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfWhere(node, node.Having, func(newNode, parent SQLNode) { + parent.(*Select).Having = newNode.(*Where) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Select).OrderBy = newNode.(OrderBy) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Select).Limit = newNode.(*Limit) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfSelectInto(node, node.Into, func(newNode, parent SQLNode) { + parent.(*Select).Into = newNode.(*SelectInto) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsSelectExprs does deep equals between the two objects. func EqualsSelectExprs(a, b SelectExprs) bool { if len(a) != len(b) { @@ -5353,6 +8305,32 @@ func VisitSelectExprs(in SelectExprs, f Visit) error { return nil } +// rewriteSelectExprs is part of the Rewrite implementation +func rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { + parent.(SelectExprs)[i] = newNode.(SelectExpr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSelectInto does deep equals between the two objects. func EqualsRefOfSelectInto(a, b *SelectInto) bool { if a == b { @@ -5390,6 +8368,25 @@ func VisitRefOfSelectInto(in *SelectInto, f Visit) error { return nil } +// rewriteRefOfSelectInto is part of the Rewrite implementation +func rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSet does deep equals between the two objects. func EqualsRefOfSet(a, b *Set) bool { if a == b { @@ -5430,6 +8427,35 @@ func VisitRefOfSet(in *Set, f Visit) error { return nil } +// rewriteRefOfSet is part of the Rewrite implementation +func rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Set).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSetExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*Set).Exprs = newNode.(SetExprs) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSetExpr does deep equals between the two objects. func EqualsRefOfSetExpr(a, b *SetExpr) bool { if a == b { @@ -5471,6 +8497,35 @@ func VisitRefOfSetExpr(in *SetExpr, f Visit) error { return nil } +// rewriteRefOfSetExpr is part of the Rewrite implementation +func rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SetExpr).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*SetExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsSetExprs does deep equals between the two objects. func EqualsSetExprs(a, b SetExprs) bool { if len(a) != len(b) { @@ -5509,6 +8564,32 @@ func VisitSetExprs(in SetExprs, f Visit) error { return nil } +// rewriteSetExprs is part of the Rewrite implementation +func rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { + parent.(SetExprs)[i] = newNode.(*SetExpr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSetTransaction does deep equals between the two objects. func EqualsRefOfSetTransaction(a, b *SetTransaction) bool { if a == b { @@ -5557,6 +8638,42 @@ func VisitRefOfSetTransaction(in *SetTransaction, f Visit) error { return nil } +// rewriteRefOfSetTransaction is part of the Rewrite implementation +func rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { + parent.(*SetTransaction).SQLNode = newNode.(SQLNode) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.Characteristics { + if errF := rewriteCharacteristic(node, el, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Characteristics[i] = newNode.(Characteristic) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfShow does deep equals between the two objects. func EqualsRefOfShow(a, b *Show) bool { if a == b { @@ -5592,6 +8709,30 @@ func VisitRefOfShow(in *Show, f Visit) error { return nil } +// rewriteRefOfShow is part of the Rewrite implementation +func rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { + parent.(*Show).Internal = newNode.(ShowInternal) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfShowBasic does deep equals between the two objects. func EqualsRefOfShowBasic(a, b *ShowBasic) bool { if a == b { @@ -5635,6 +8776,35 @@ func VisitRefOfShowBasic(in *ShowBasic, f Visit) error { return nil } +// rewriteRefOfShowBasic is part of the Rewrite implementation +func rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Tbl = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfShowFilter(node, node.Filter, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Filter = newNode.(*ShowFilter) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfShowCreate does deep equals between the two objects. func EqualsRefOfShowCreate(a, b *ShowCreate) bool { if a == b { @@ -5671,6 +8841,30 @@ func VisitRefOfShowCreate(in *ShowCreate, f Visit) error { return nil } +// rewriteRefOfShowCreate is part of the Rewrite implementation +func rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { + parent.(*ShowCreate).Op = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfShowFilter does deep equals between the two objects. func EqualsRefOfShowFilter(a, b *ShowFilter) bool { if a == b { @@ -5707,6 +8901,30 @@ func VisitRefOfShowFilter(in *ShowFilter, f Visit) error { return nil } +// rewriteRefOfShowFilter is part of the Rewrite implementation +func rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { + parent.(*ShowFilter).Filter = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfShowLegacy does deep equals between the two objects. func EqualsRefOfShowLegacy(a, b *ShowLegacy) bool { if a == b { @@ -5757,6 +8975,40 @@ func VisitRefOfShowLegacy(in *ShowLegacy, f Visit) error { return nil } +// rewriteRefOfShowLegacy is part of the Rewrite implementation +func rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).OnTable = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.ShowCollationFilterOpt, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfStarExpr does deep equals between the two objects. func EqualsRefOfStarExpr(a, b *StarExpr) bool { if a == b { @@ -5792,6 +9044,30 @@ func VisitRefOfStarExpr(in *StarExpr, f Visit) error { return nil } +// rewriteRefOfStarExpr is part of the Rewrite implementation +func rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + parent.(*StarExpr).TableName = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfStream does deep equals between the two objects. func EqualsRefOfStream(a, b *Stream) bool { if a == b { @@ -5837,6 +9113,40 @@ func VisitRefOfStream(in *Stream, f Visit) error { return nil } +// rewriteRefOfStream is part of the Rewrite implementation +func rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Stream).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + parent.(*Stream).SelectExpr = newNode.(SelectExpr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*Stream).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSubquery does deep equals between the two objects. func EqualsRefOfSubquery(a, b *Subquery) bool { if a == b { @@ -5872,6 +9182,30 @@ func VisitRefOfSubquery(in *Subquery, f Visit) error { return nil } +// rewriteRefOfSubquery is part of the Rewrite implementation +func rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*Subquery).Select = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfSubstrExpr does deep equals between the two objects. func EqualsRefOfSubstrExpr(a, b *SubstrExpr) bool { if a == b { @@ -5922,6 +9256,45 @@ func VisitRefOfSubstrExpr(in *SubstrExpr, f Visit) error { return nil } +// rewriteRefOfSubstrExpr is part of the Rewrite implementation +func rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).Name = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLiteral(node, node.StrVal, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).StrVal = newNode.(*Literal) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).From = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).To = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsTableExprs does deep equals between the two objects. func EqualsTableExprs(a, b TableExprs) bool { if len(a) != len(b) { @@ -5960,6 +9333,32 @@ func VisitTableExprs(in TableExprs, f Visit) error { return nil } +// rewriteTableExprs is part of the Rewrite implementation +func rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteTableExpr(node, el, func(newNode, parent SQLNode) { + parent.(TableExprs)[i] = newNode.(TableExpr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsTableIdent does deep equals between the two objects. func EqualsTableIdent(a, b TableIdent) bool { return a.v == b.v @@ -5978,6 +9377,26 @@ func VisitTableIdent(in TableIdent, f Visit) error { return nil } +// rewriteTableIdent is part of the Rewrite implementation +func rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc, pre, post ApplyFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if err != nil { + return err + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsTableName does deep equals between the two objects. func EqualsTableName(a, b TableName) bool { return EqualsTableIdent(a.Name, b.Name) && @@ -6003,6 +9422,36 @@ func VisitTableName(in TableName, f Visit) error { return nil } +// rewriteTableName is part of the Rewrite implementation +func rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc, pre, post ApplyFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Name' on 'TableName'") + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Qualifier' on 'TableName'") + }, pre, post); errF != nil { + return errF + } + if err != nil { + return err + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsTableNames does deep equals between the two objects. func EqualsTableNames(a, b TableNames) bool { if len(a) != len(b) { @@ -6041,6 +9490,32 @@ func VisitTableNames(in TableNames, f Visit) error { return nil } +// rewriteTableNames is part of the Rewrite implementation +func rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteTableName(node, el, func(newNode, parent SQLNode) { + parent.(TableNames)[i] = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsTableOptions does deep equals between the two objects. func EqualsTableOptions(a, b TableOptions) bool { if len(a) != len(b) { @@ -6069,6 +9544,25 @@ func VisitTableOptions(in TableOptions, f Visit) error { return err } +// rewriteTableOptions is part of the Rewrite implementation +func rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfTableSpec does deep equals between the two objects. func EqualsRefOfTableSpec(a, b *TableSpec) bool { if a == b { @@ -6125,6 +9619,51 @@ func VisitRefOfTableSpec(in *TableSpec, f Visit) error { return nil } +// rewriteRefOfTableSpec is part of the Rewrite implementation +func rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node.Columns { + if errF := rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Columns[i] = newNode.(*ColumnDefinition) + }, pre, post); errF != nil { + return errF + } + } + for i, el := range node.Indexes { + if errF := rewriteRefOfIndexDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Indexes[i] = newNode.(*IndexDefinition) + }, pre, post); errF != nil { + return errF + } + } + for i, el := range node.Constraints { + if errF := rewriteRefOfConstraintDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Constraints[i] = newNode.(*ConstraintDefinition) + }, pre, post); errF != nil { + return errF + } + } + if errF := rewriteTableOptions(node, node.Options, func(newNode, parent SQLNode) { + parent.(*TableSpec).Options = newNode.(TableOptions) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfTablespaceOperation does deep equals between the two objects. func EqualsRefOfTablespaceOperation(a, b *TablespaceOperation) bool { if a == b { @@ -6156,6 +9695,25 @@ func VisitRefOfTablespaceOperation(in *TablespaceOperation, f Visit) error { return nil } +// rewriteRefOfTablespaceOperation is part of the Rewrite implementation +func rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfTimestampFuncExpr does deep equals between the two objects. func EqualsRefOfTimestampFuncExpr(a, b *TimestampFuncExpr) bool { if a == b { @@ -6198,6 +9756,35 @@ func VisitRefOfTimestampFuncExpr(in *TimestampFuncExpr, f Visit) error { return nil } +// rewriteRefOfTimestampFuncExpr is part of the Rewrite implementation +func rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Expr2, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfTruncateTable does deep equals between the two objects. func EqualsRefOfTruncateTable(a, b *TruncateTable) bool { if a == b { @@ -6233,6 +9820,30 @@ func VisitRefOfTruncateTable(in *TruncateTable, f Visit) error { return nil } +// rewriteRefOfTruncateTable is part of the Rewrite implementation +func rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*TruncateTable).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfUnaryExpr does deep equals between the two objects. func EqualsRefOfUnaryExpr(a, b *UnaryExpr) bool { if a == b { @@ -6269,6 +9880,30 @@ func VisitRefOfUnaryExpr(in *UnaryExpr, f Visit) error { return nil } +// rewriteRefOfUnaryExpr is part of the Rewrite implementation +func rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*UnaryExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfUnion does deep equals between the two objects. func EqualsRefOfUnion(a, b *Union) bool { if a == b { @@ -6322,6 +9957,47 @@ func VisitRefOfUnion(in *Union, f Visit) error { return nil } +// rewriteRefOfUnion is part of the Rewrite implementation +func rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { + parent.(*Union).FirstStatement = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.UnionSelects { + if errF := rewriteRefOfUnionSelect(node, el, func(newNode, parent SQLNode) { + parent.(*Union).UnionSelects[i] = newNode.(*UnionSelect) + }, pre, post); errF != nil { + return errF + } + } + if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Union).OrderBy = newNode.(OrderBy) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Union).Limit = newNode.(*Limit) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfUnionSelect does deep equals between the two objects. func EqualsRefOfUnionSelect(a, b *UnionSelect) bool { if a == b { @@ -6349,11 +10025,35 @@ func VisitRefOfUnionSelect(in *UnionSelect, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitSelectStatement(in.Statement, f); err != nil { + return err + } + return nil +} + +// rewriteRefOfUnionSelect is part of the Rewrite implementation +func rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { + parent.(*UnionSelect).Statement = newNode.(SelectStatement) + }, pre, post); errF != nil { + return errF } - if err := VisitSelectStatement(in.Statement, f); err != nil { - return err + if !post(&cur) { + return errAbort } return nil } @@ -6389,6 +10089,25 @@ func VisitRefOfUnlockTables(in *UnlockTables, f Visit) error { return nil } +// rewriteRefOfUnlockTables is part of the Rewrite implementation +func rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfUpdate does deep equals between the two objects. func EqualsRefOfUpdate(a, b *Update) bool { if a == b { @@ -6450,6 +10169,55 @@ func VisitRefOfUpdate(in *Update, f Visit) error { return nil } +// rewriteRefOfUpdate is part of the Rewrite implementation +func rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Update).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Update).TableExprs = newNode.(TableExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteUpdateExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*Update).Exprs = newNode.(UpdateExprs) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Update).Where = newNode.(*Where) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Update).OrderBy = newNode.(OrderBy) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Update).Limit = newNode.(*Limit) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfUpdateExpr does deep equals between the two objects. func EqualsRefOfUpdateExpr(a, b *UpdateExpr) bool { if a == b { @@ -6490,6 +10258,35 @@ func VisitRefOfUpdateExpr(in *UpdateExpr, f Visit) error { return nil } +// rewriteRefOfUpdateExpr is part of the Rewrite implementation +func rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Name = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsUpdateExprs does deep equals between the two objects. func EqualsUpdateExprs(a, b UpdateExprs) bool { if len(a) != len(b) { @@ -6528,6 +10325,32 @@ func VisitUpdateExprs(in UpdateExprs, f Visit) error { return nil } +// rewriteUpdateExprs is part of the Rewrite implementation +func rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + parent.(UpdateExprs)[i] = newNode.(*UpdateExpr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfUse does deep equals between the two objects. func EqualsRefOfUse(a, b *Use) bool { if a == b { @@ -6563,6 +10386,30 @@ func VisitRefOfUse(in *Use, f Visit) error { return nil } +// rewriteRefOfUse is part of the Rewrite implementation +func rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { + parent.(*Use).DBName = newNode.(TableIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfVStream does deep equals between the two objects. func EqualsRefOfVStream(a, b *VStream) bool { if a == b { @@ -6618,6 +10465,50 @@ func VisitRefOfVStream(in *VStream, f Visit) error { return nil } +// rewriteRefOfVStream is part of the Rewrite implementation +func rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*VStream).Comments = newNode.(Comments) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + parent.(*VStream).SelectExpr = newNode.(SelectExpr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*VStream).Table = newNode.(TableName) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*VStream).Where = newNode.(*Where) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*VStream).Limit = newNode.(*Limit) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsValTuple does deep equals between the two objects. func EqualsValTuple(a, b ValTuple) bool { if len(a) != len(b) { @@ -6656,6 +10547,32 @@ func VisitValTuple(in ValTuple, f Visit) error { return nil } +// rewriteValTuple is part of the Rewrite implementation +func rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(ValTuple)[i] = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfValidation does deep equals between the two objects. func EqualsRefOfValidation(a, b *Validation) bool { if a == b { @@ -6687,6 +10604,25 @@ func VisitRefOfValidation(in *Validation, f Visit) error { return nil } +// rewriteRefOfValidation is part of the Rewrite implementation +func rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsValues does deep equals between the two objects. func EqualsValues(a, b Values) bool { if len(a) != len(b) { @@ -6725,6 +10661,32 @@ func VisitValues(in Values, f Visit) error { return nil } +// rewriteValues is part of the Rewrite implementation +func rewriteValues(parent SQLNode, node Values, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + for i, el := range node { + if errF := rewriteValTuple(node, el, func(newNode, parent SQLNode) { + parent.(Values)[i] = newNode.(ValTuple) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfValuesFuncExpr does deep equals between the two objects. func EqualsRefOfValuesFuncExpr(a, b *ValuesFuncExpr) bool { if a == b { @@ -6760,6 +10722,30 @@ func VisitRefOfValuesFuncExpr(in *ValuesFuncExpr, f Visit) error { return nil } +// rewriteRefOfValuesFuncExpr is part of the Rewrite implementation +func rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ValuesFuncExpr).Name = newNode.(*ColName) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsVindexParam does deep equals between the two objects. func EqualsVindexParam(a, b VindexParam) bool { return a.Val == b.Val && @@ -6782,6 +10768,31 @@ func VisitVindexParam(in VindexParam, f Visit) error { return nil } +// rewriteVindexParam is part of the Rewrite implementation +func rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc, pre, post ApplyFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Key' on 'VindexParam'") + }, pre, post); errF != nil { + return errF + } + if err != nil { + return err + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfVindexSpec does deep equals between the two objects. func EqualsRefOfVindexSpec(a, b *VindexSpec) bool { if a == b { @@ -6829,6 +10840,42 @@ func VisitRefOfVindexSpec(in *VindexSpec, f Visit) error { return nil } +// rewriteRefOfVindexSpec is part of the Rewrite implementation +func rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Name = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColIdent(node, node.Type, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Type = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + for i, el := range node.Params { + if errF := rewriteVindexParam(node, el, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Params[i] = newNode.(VindexParam) + }, pre, post); errF != nil { + return errF + } + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfWhen does deep equals between the two objects. func EqualsRefOfWhen(a, b *When) bool { if a == b { @@ -6869,6 +10916,35 @@ func VisitRefOfWhen(in *When, f Visit) error { return nil } +// rewriteRefOfWhen is part of the Rewrite implementation +func rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { + parent.(*When).Cond = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Val, func(newNode, parent SQLNode) { + parent.(*When).Val = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfWhere does deep equals between the two objects. func EqualsRefOfWhere(a, b *Where) bool { if a == b { @@ -6905,6 +10981,30 @@ func VisitRefOfWhere(in *Where, f Visit) error { return nil } +// rewriteRefOfWhere is part of the Rewrite implementation +func rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Where).Expr = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfXorExpr does deep equals between the two objects. func EqualsRefOfXorExpr(a, b *XorExpr) bool { if a == b { @@ -6945,6 +11045,35 @@ func VisitRefOfXorExpr(in *XorExpr, f Visit) error { return nil } +// rewriteRefOfXorExpr is part of the Rewrite implementation +func rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*XorExpr).Left = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*XorExpr).Right = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsAlterOption does deep equals between the two objects. func EqualsAlterOption(inA, inB AlterOption) bool { if inA == nil && inB == nil { @@ -7174,6 +11303,56 @@ func VisitAlterOption(in AlterOption, f Visit) error { } } +// rewriteAlterOption is part of the Rewrite implementation +func rewriteAlterOption(parent SQLNode, node AlterOption, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AddColumns: + return rewriteRefOfAddColumns(parent, node, replacer, pre, post) + case *AddConstraintDefinition: + return rewriteRefOfAddConstraintDefinition(parent, node, replacer, pre, post) + case *AddIndexDefinition: + return rewriteRefOfAddIndexDefinition(parent, node, replacer, pre, post) + case AlgorithmValue: + return rewriteAlgorithmValue(parent, node, replacer, pre, post) + case *AlterCharset: + return rewriteRefOfAlterCharset(parent, node, replacer, pre, post) + case *AlterColumn: + return rewriteRefOfAlterColumn(parent, node, replacer, pre, post) + case *ChangeColumn: + return rewriteRefOfChangeColumn(parent, node, replacer, pre, post) + case *DropColumn: + return rewriteRefOfDropColumn(parent, node, replacer, pre, post) + case *DropKey: + return rewriteRefOfDropKey(parent, node, replacer, pre, post) + case *Force: + return rewriteRefOfForce(parent, node, replacer, pre, post) + case *KeyState: + return rewriteRefOfKeyState(parent, node, replacer, pre, post) + case *LockOption: + return rewriteRefOfLockOption(parent, node, replacer, pre, post) + case *ModifyColumn: + return rewriteRefOfModifyColumn(parent, node, replacer, pre, post) + case *OrderByOption: + return rewriteRefOfOrderByOption(parent, node, replacer, pre, post) + case *RenameIndex: + return rewriteRefOfRenameIndex(parent, node, replacer, pre, post) + case *RenameTableName: + return rewriteRefOfRenameTableName(parent, node, replacer, pre, post) + case TableOptions: + return rewriteTableOptions(parent, node, replacer, pre, post) + case *TablespaceOperation: + return rewriteRefOfTablespaceOperation(parent, node, replacer, pre, post) + case *Validation: + return rewriteRefOfValidation(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsCharacteristic does deep equals between the two objects. func EqualsCharacteristic(inA, inB Characteristic) bool { if inA == nil && inB == nil { @@ -7233,6 +11412,22 @@ func VisitCharacteristic(in Characteristic, f Visit) error { } } +// rewriteCharacteristic is part of the Rewrite implementation +func rewriteCharacteristic(parent SQLNode, node Characteristic, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case AccessMode: + return rewriteAccessMode(parent, node, replacer, pre, post) + case IsolationLevel: + return rewriteIsolationLevel(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsColTuple does deep equals between the two objects. func EqualsColTuple(inA, inB ColTuple) bool { if inA == nil && inB == nil { @@ -7302,6 +11497,24 @@ func VisitColTuple(in ColTuple, f Visit) error { } } +// rewriteColTuple is part of the Rewrite implementation +func rewriteColTuple(parent SQLNode, node ColTuple, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case ListArg: + return rewriteListArg(parent, node, replacer, pre, post) + case *Subquery: + return rewriteRefOfSubquery(parent, node, replacer, pre, post) + case ValTuple: + return rewriteValTuple(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsConstraintInfo does deep equals between the two objects. func EqualsConstraintInfo(inA, inB ConstraintInfo) bool { if inA == nil && inB == nil { @@ -7361,6 +11574,22 @@ func VisitConstraintInfo(in ConstraintInfo, f Visit) error { } } +// rewriteConstraintInfo is part of the Rewrite implementation +func rewriteConstraintInfo(parent SQLNode, node ConstraintInfo, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *CheckConstraintDefinition: + return rewriteRefOfCheckConstraintDefinition(parent, node, replacer, pre, post) + case *ForeignKeyDefinition: + return rewriteRefOfForeignKeyDefinition(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsDBDDLStatement does deep equals between the two objects. func EqualsDBDDLStatement(inA, inB DBDDLStatement) bool { if inA == nil && inB == nil { @@ -7430,6 +11659,24 @@ func VisitDBDDLStatement(in DBDDLStatement, f Visit) error { } } +// rewriteDBDDLStatement is part of the Rewrite implementation +func rewriteDBDDLStatement(parent SQLNode, node DBDDLStatement, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterDatabase: + return rewriteRefOfAlterDatabase(parent, node, replacer, pre, post) + case *CreateDatabase: + return rewriteRefOfCreateDatabase(parent, node, replacer, pre, post) + case *DropDatabase: + return rewriteRefOfDropDatabase(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsDDLStatement does deep equals between the two objects. func EqualsDDLStatement(inA, inB DDLStatement) bool { if inA == nil && inB == nil { @@ -7549,6 +11796,34 @@ func VisitDDLStatement(in DDLStatement, f Visit) error { } } +// rewriteDDLStatement is part of the Rewrite implementation +func rewriteDDLStatement(parent SQLNode, node DDLStatement, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterTable: + return rewriteRefOfAlterTable(parent, node, replacer, pre, post) + case *AlterView: + return rewriteRefOfAlterView(parent, node, replacer, pre, post) + case *CreateTable: + return rewriteRefOfCreateTable(parent, node, replacer, pre, post) + case *CreateView: + return rewriteRefOfCreateView(parent, node, replacer, pre, post) + case *DropTable: + return rewriteRefOfDropTable(parent, node, replacer, pre, post) + case *DropView: + return rewriteRefOfDropView(parent, node, replacer, pre, post) + case *RenameTable: + return rewriteRefOfRenameTable(parent, node, replacer, pre, post) + case *TruncateTable: + return rewriteRefOfTruncateTable(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsExplain does deep equals between the two objects. func EqualsExplain(inA, inB Explain) bool { if inA == nil && inB == nil { @@ -7608,6 +11883,22 @@ func VisitExplain(in Explain, f Visit) error { } } +// rewriteExplain is part of the Rewrite implementation +func rewriteExplain(parent SQLNode, node Explain, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *ExplainStmt: + return rewriteRefOfExplainStmt(parent, node, replacer, pre, post) + case *ExplainTab: + return rewriteRefOfExplainTab(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsExpr does deep equals between the two objects. func EqualsExpr(inA, inB Expr) bool { if inA == nil && inB == nil { @@ -7957,6 +12248,80 @@ func VisitExpr(in Expr, f Visit) error { } } +// rewriteExpr is part of the Rewrite implementation +func rewriteExpr(parent SQLNode, node Expr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AndExpr: + return rewriteRefOfAndExpr(parent, node, replacer, pre, post) + case Argument: + return rewriteArgument(parent, node, replacer, pre, post) + case *BinaryExpr: + return rewriteRefOfBinaryExpr(parent, node, replacer, pre, post) + case BoolVal: + return rewriteBoolVal(parent, node, replacer, pre, post) + case *CaseExpr: + return rewriteRefOfCaseExpr(parent, node, replacer, pre, post) + case *ColName: + return rewriteRefOfColName(parent, node, replacer, pre, post) + case *CollateExpr: + return rewriteRefOfCollateExpr(parent, node, replacer, pre, post) + case *ComparisonExpr: + return rewriteRefOfComparisonExpr(parent, node, replacer, pre, post) + case *ConvertExpr: + return rewriteRefOfConvertExpr(parent, node, replacer, pre, post) + case *ConvertUsingExpr: + return rewriteRefOfConvertUsingExpr(parent, node, replacer, pre, post) + case *CurTimeFuncExpr: + return rewriteRefOfCurTimeFuncExpr(parent, node, replacer, pre, post) + case *Default: + return rewriteRefOfDefault(parent, node, replacer, pre, post) + case *ExistsExpr: + return rewriteRefOfExistsExpr(parent, node, replacer, pre, post) + case *FuncExpr: + return rewriteRefOfFuncExpr(parent, node, replacer, pre, post) + case *GroupConcatExpr: + return rewriteRefOfGroupConcatExpr(parent, node, replacer, pre, post) + case *IntervalExpr: + return rewriteRefOfIntervalExpr(parent, node, replacer, pre, post) + case *IsExpr: + return rewriteRefOfIsExpr(parent, node, replacer, pre, post) + case ListArg: + return rewriteListArg(parent, node, replacer, pre, post) + case *Literal: + return rewriteRefOfLiteral(parent, node, replacer, pre, post) + case *MatchExpr: + return rewriteRefOfMatchExpr(parent, node, replacer, pre, post) + case *NotExpr: + return rewriteRefOfNotExpr(parent, node, replacer, pre, post) + case *NullVal: + return rewriteRefOfNullVal(parent, node, replacer, pre, post) + case *OrExpr: + return rewriteRefOfOrExpr(parent, node, replacer, pre, post) + case *RangeCond: + return rewriteRefOfRangeCond(parent, node, replacer, pre, post) + case *Subquery: + return rewriteRefOfSubquery(parent, node, replacer, pre, post) + case *SubstrExpr: + return rewriteRefOfSubstrExpr(parent, node, replacer, pre, post) + case *TimestampFuncExpr: + return rewriteRefOfTimestampFuncExpr(parent, node, replacer, pre, post) + case *UnaryExpr: + return rewriteRefOfUnaryExpr(parent, node, replacer, pre, post) + case ValTuple: + return rewriteValTuple(parent, node, replacer, pre, post) + case *ValuesFuncExpr: + return rewriteRefOfValuesFuncExpr(parent, node, replacer, pre, post) + case *XorExpr: + return rewriteRefOfXorExpr(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsInsertRows does deep equals between the two objects. func EqualsInsertRows(inA, inB InsertRows) bool { if inA == nil && inB == nil { @@ -8036,6 +12401,26 @@ func VisitInsertRows(in InsertRows, f Visit) error { } } +// rewriteInsertRows is part of the Rewrite implementation +func rewriteInsertRows(parent SQLNode, node InsertRows, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *ParenSelect: + return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + case *Select: + return rewriteRefOfSelect(parent, node, replacer, pre, post) + case *Union: + return rewriteRefOfUnion(parent, node, replacer, pre, post) + case Values: + return rewriteValues(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsSelectExpr does deep equals between the two objects. func EqualsSelectExpr(inA, inB SelectExpr) bool { if inA == nil && inB == nil { @@ -8105,6 +12490,24 @@ func VisitSelectExpr(in SelectExpr, f Visit) error { } } +// rewriteSelectExpr is part of the Rewrite implementation +func rewriteSelectExpr(parent SQLNode, node SelectExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AliasedExpr: + return rewriteRefOfAliasedExpr(parent, node, replacer, pre, post) + case *Nextval: + return rewriteRefOfNextval(parent, node, replacer, pre, post) + case *StarExpr: + return rewriteRefOfStarExpr(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsSelectStatement does deep equals between the two objects. func EqualsSelectStatement(inA, inB SelectStatement) bool { if inA == nil && inB == nil { @@ -8174,6 +12577,24 @@ func VisitSelectStatement(in SelectStatement, f Visit) error { } } +// rewriteSelectStatement is part of the Rewrite implementation +func rewriteSelectStatement(parent SQLNode, node SelectStatement, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *ParenSelect: + return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + case *Select: + return rewriteRefOfSelect(parent, node, replacer, pre, post) + case *Union: + return rewriteRefOfUnion(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsShowInternal does deep equals between the two objects. func EqualsShowInternal(inA, inB ShowInternal) bool { if inA == nil && inB == nil { @@ -8243,6 +12664,24 @@ func VisitShowInternal(in ShowInternal, f Visit) error { } } +// rewriteShowInternal is part of the Rewrite implementation +func rewriteShowInternal(parent SQLNode, node ShowInternal, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *ShowBasic: + return rewriteRefOfShowBasic(parent, node, replacer, pre, post) + case *ShowCreate: + return rewriteRefOfShowCreate(parent, node, replacer, pre, post) + case *ShowLegacy: + return rewriteRefOfShowLegacy(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsSimpleTableExpr does deep equals between the two objects. func EqualsSimpleTableExpr(inA, inB SimpleTableExpr) bool { if inA == nil && inB == nil { @@ -8302,6 +12741,22 @@ func VisitSimpleTableExpr(in SimpleTableExpr, f Visit) error { } } +// rewriteSimpleTableExpr is part of the Rewrite implementation +func rewriteSimpleTableExpr(parent SQLNode, node SimpleTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *DerivedTable: + return rewriteRefOfDerivedTable(parent, node, replacer, pre, post) + case TableName: + return rewriteTableName(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsStatement does deep equals between the two objects. func EqualsStatement(inA, inB Statement) bool { if inA == nil && inB == nil { @@ -8751,6 +13206,100 @@ func VisitStatement(in Statement, f Visit) error { } } +// rewriteStatement is part of the Rewrite implementation +func rewriteStatement(parent SQLNode, node Statement, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterDatabase: + return rewriteRefOfAlterDatabase(parent, node, replacer, pre, post) + case *AlterMigration: + return rewriteRefOfAlterMigration(parent, node, replacer, pre, post) + case *AlterTable: + return rewriteRefOfAlterTable(parent, node, replacer, pre, post) + case *AlterView: + return rewriteRefOfAlterView(parent, node, replacer, pre, post) + case *AlterVschema: + return rewriteRefOfAlterVschema(parent, node, replacer, pre, post) + case *Begin: + return rewriteRefOfBegin(parent, node, replacer, pre, post) + case *CallProc: + return rewriteRefOfCallProc(parent, node, replacer, pre, post) + case *Commit: + return rewriteRefOfCommit(parent, node, replacer, pre, post) + case *CreateDatabase: + return rewriteRefOfCreateDatabase(parent, node, replacer, pre, post) + case *CreateTable: + return rewriteRefOfCreateTable(parent, node, replacer, pre, post) + case *CreateView: + return rewriteRefOfCreateView(parent, node, replacer, pre, post) + case *Delete: + return rewriteRefOfDelete(parent, node, replacer, pre, post) + case *DropDatabase: + return rewriteRefOfDropDatabase(parent, node, replacer, pre, post) + case *DropTable: + return rewriteRefOfDropTable(parent, node, replacer, pre, post) + case *DropView: + return rewriteRefOfDropView(parent, node, replacer, pre, post) + case *ExplainStmt: + return rewriteRefOfExplainStmt(parent, node, replacer, pre, post) + case *ExplainTab: + return rewriteRefOfExplainTab(parent, node, replacer, pre, post) + case *Flush: + return rewriteRefOfFlush(parent, node, replacer, pre, post) + case *Insert: + return rewriteRefOfInsert(parent, node, replacer, pre, post) + case *Load: + return rewriteRefOfLoad(parent, node, replacer, pre, post) + case *LockTables: + return rewriteRefOfLockTables(parent, node, replacer, pre, post) + case *OtherAdmin: + return rewriteRefOfOtherAdmin(parent, node, replacer, pre, post) + case *OtherRead: + return rewriteRefOfOtherRead(parent, node, replacer, pre, post) + case *ParenSelect: + return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + case *Release: + return rewriteRefOfRelease(parent, node, replacer, pre, post) + case *RenameTable: + return rewriteRefOfRenameTable(parent, node, replacer, pre, post) + case *RevertMigration: + return rewriteRefOfRevertMigration(parent, node, replacer, pre, post) + case *Rollback: + return rewriteRefOfRollback(parent, node, replacer, pre, post) + case *SRollback: + return rewriteRefOfSRollback(parent, node, replacer, pre, post) + case *Savepoint: + return rewriteRefOfSavepoint(parent, node, replacer, pre, post) + case *Select: + return rewriteRefOfSelect(parent, node, replacer, pre, post) + case *Set: + return rewriteRefOfSet(parent, node, replacer, pre, post) + case *SetTransaction: + return rewriteRefOfSetTransaction(parent, node, replacer, pre, post) + case *Show: + return rewriteRefOfShow(parent, node, replacer, pre, post) + case *Stream: + return rewriteRefOfStream(parent, node, replacer, pre, post) + case *TruncateTable: + return rewriteRefOfTruncateTable(parent, node, replacer, pre, post) + case *Union: + return rewriteRefOfUnion(parent, node, replacer, pre, post) + case *UnlockTables: + return rewriteRefOfUnlockTables(parent, node, replacer, pre, post) + case *Update: + return rewriteRefOfUpdate(parent, node, replacer, pre, post) + case *Use: + return rewriteRefOfUse(parent, node, replacer, pre, post) + case *VStream: + return rewriteRefOfVStream(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // EqualsTableExpr does deep equals between the two objects. func EqualsTableExpr(inA, inB TableExpr) bool { if inA == nil && inB == nil { @@ -8820,42 +13369,156 @@ func VisitTableExpr(in TableExpr, f Visit) error { } } +// rewriteTableExpr is part of the Rewrite implementation +func rewriteTableExpr(parent SQLNode, node TableExpr, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AliasedTableExpr: + return rewriteRefOfAliasedTableExpr(parent, node, replacer, pre, post) + case *JoinTableExpr: + return rewriteRefOfJoinTableExpr(parent, node, replacer, pre, post) + case *ParenTableExpr: + return rewriteRefOfParenTableExpr(parent, node, replacer, pre, post) + default: + // this should never happen + return nil + } +} + // VisitAccessMode will visit all parts of the AST func VisitAccessMode(in AccessMode, f Visit) error { _, err := f(in) return err } +// rewriteAccessMode is part of the Rewrite implementation +func rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc, pre, post ApplyFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // VisitAlgorithmValue will visit all parts of the AST func VisitAlgorithmValue(in AlgorithmValue, f Visit) error { _, err := f(in) return err } +// rewriteAlgorithmValue is part of the Rewrite implementation +func rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc, pre, post ApplyFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // VisitArgument will visit all parts of the AST func VisitArgument(in Argument, f Visit) error { _, err := f(in) return err } +// rewriteArgument is part of the Rewrite implementation +func rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc, pre, post ApplyFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // VisitBoolVal will visit all parts of the AST func VisitBoolVal(in BoolVal, f Visit) error { _, err := f(in) return err } +// rewriteBoolVal is part of the Rewrite implementation +func rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc, pre, post ApplyFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // VisitIsolationLevel will visit all parts of the AST func VisitIsolationLevel(in IsolationLevel, f Visit) error { _, err := f(in) return err } +// rewriteIsolationLevel is part of the Rewrite implementation +func rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc, pre, post ApplyFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // VisitReferenceAction will visit all parts of the AST func VisitReferenceAction(in ReferenceAction, f Visit) error { _, err := f(in) return err } +// rewriteReferenceAction is part of the Rewrite implementation +func rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc, pre, post ApplyFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsSliceOfRefOfColumnDefinition does deep equals between the two objects. func EqualsSliceOfRefOfColumnDefinition(a, b []*ColumnDefinition) bool { if len(a) != len(b) { @@ -8999,6 +13662,25 @@ func VisitRefOfColIdent(in *ColIdent, f Visit) error { return nil } +// rewriteRefOfColIdent is part of the Rewrite implementation +func rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsColumnType does deep equals between the two objects. func EqualsColumnType(a, b ColumnType) bool { return a.Type == b.Type && @@ -9149,6 +13831,35 @@ func VisitRefOfJoinCondition(in *JoinCondition, f Visit) error { return nil } +// rewriteRefOfJoinCondition is part of the Rewrite implementation +func rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + parent.(*JoinCondition).On = newNode.(Expr) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + parent.(*JoinCondition).Using = newNode.(Columns) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsTableAndLockTypes does deep equals between the two objects. func EqualsTableAndLockTypes(a, b TableAndLockTypes) bool { if len(a) != len(b) { @@ -9311,6 +14022,25 @@ func VisitRefOfTableIdent(in *TableIdent, f Visit) error { return nil } +// rewriteRefOfTableIdent is part of the Rewrite implementation +func rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfTableName does deep equals between the two objects. func EqualsRefOfTableName(a, b *TableName) bool { if a == b { @@ -9351,6 +14081,35 @@ func VisitRefOfTableName(in *TableName, f Visit) error { return nil } +// rewriteRefOfTableName is part of the Rewrite implementation +func rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*TableName).Name = newNode.(TableIdent) + }, pre, post); errF != nil { + return errF + } + if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*TableName).Qualifier = newNode.(TableIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsRefOfTableOption does deep equals between the two objects. func EqualsRefOfTableOption(a, b *TableOption) bool { if a == b { @@ -9478,6 +14237,30 @@ func VisitRefOfVindexParam(in *VindexParam, f Visit) error { return nil } +// rewriteRefOfVindexParam is part of the Rewrite implementation +func rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replacerFunc, pre, post ApplyFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + if errF := rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + parent.(*VindexParam).Key = newNode.(ColIdent) + }, pre, post); errF != nil { + return errF + } + if !post(&cur) { + return errAbort + } + return nil +} + // EqualsSliceOfVindexParam does deep equals between the two objects. func EqualsSliceOfVindexParam(a, b []VindexParam) bool { if len(a) != len(b) { diff --git a/go/vt/sqlparser/rewriter.go b/go/vt/sqlparser/rewriter.go deleted file mode 100644 index ccffe468745..00000000000 --- a/go/vt/sqlparser/rewriter.go +++ /dev/null @@ -1,931 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -// Code generated by ASTHelperGen. DO NOT EDIT. - -package sqlparser - -func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { - if node == nil || isNilValue(node) { - return - } - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - if a.pre != nil && !a.pre(&a.cursor) { - a.cursor = saved - return - } - switch n := node.(type) { - case *AddColumns: - for x, el := range n.Columns { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*AddColumns).Columns[idx] = newNode.(*ColumnDefinition) - } - }(x)) - } - a.apply(node, n.First, func(newNode, parent SQLNode) { - parent.(*AddColumns).First = newNode.(*ColName) - }) - a.apply(node, n.After, func(newNode, parent SQLNode) { - parent.(*AddColumns).After = newNode.(*ColName) - }) - case *AddConstraintDefinition: - a.apply(node, n.ConstraintDefinition, func(newNode, parent SQLNode) { - parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) - }) - case *AddIndexDefinition: - a.apply(node, n.IndexDefinition, func(newNode, parent SQLNode) { - parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) - }) - case *AliasedExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*AliasedExpr).Expr = newNode.(Expr) - }) - a.apply(node, n.As, func(newNode, parent SQLNode) { - parent.(*AliasedExpr).As = newNode.(ColIdent) - }) - case *AliasedTableExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) - }) - a.apply(node, n.Partitions, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) - }) - a.apply(node, n.As, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).As = newNode.(TableIdent) - }) - a.apply(node, n.Hints, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) - }) - case *AlterCharset: - case *AlterColumn: - a.apply(node, n.Column, func(newNode, parent SQLNode) { - parent.(*AlterColumn).Column = newNode.(*ColName) - }) - a.apply(node, n.DefaultVal, func(newNode, parent SQLNode) { - parent.(*AlterColumn).DefaultVal = newNode.(Expr) - }) - case *AlterDatabase: - case *AlterMigration: - case *AlterTable: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*AlterTable).Table = newNode.(TableName) - }) - for x, el := range n.AlterOptions { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*AlterTable).AlterOptions[idx] = newNode.(AlterOption) - } - }(x)) - } - a.apply(node, n.PartitionSpec, func(newNode, parent SQLNode) { - parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) - }) - case *AlterView: - a.apply(node, n.ViewName, func(newNode, parent SQLNode) { - parent.(*AlterView).ViewName = newNode.(TableName) - }) - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*AlterView).Columns = newNode.(Columns) - }) - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*AlterView).Select = newNode.(SelectStatement) - }) - case *AlterVschema: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*AlterVschema).Table = newNode.(TableName) - }) - a.apply(node, n.VindexSpec, func(newNode, parent SQLNode) { - parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) - }) - for x, el := range n.VindexCols { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*AlterVschema).VindexCols[idx] = newNode.(ColIdent) - } - }(x)) - } - a.apply(node, n.AutoIncSpec, func(newNode, parent SQLNode) { - parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) - }) - case *AndExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*AndExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*AndExpr).Right = newNode.(Expr) - }) - case *AutoIncSpec: - a.apply(node, n.Column, func(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Column = newNode.(ColIdent) - }) - a.apply(node, n.Sequence, func(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Sequence = newNode.(TableName) - }) - case *Begin: - case *BinaryExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*BinaryExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*BinaryExpr).Right = newNode.(Expr) - }) - case *CallProc: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*CallProc).Name = newNode.(TableName) - }) - a.apply(node, n.Params, func(newNode, parent SQLNode) { - parent.(*CallProc).Params = newNode.(Exprs) - }) - case *CaseExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Expr = newNode.(Expr) - }) - for x, el := range n.Whens { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*CaseExpr).Whens[idx] = newNode.(*When) - } - }(x)) - } - a.apply(node, n.Else, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Else = newNode.(Expr) - }) - case *ChangeColumn: - a.apply(node, n.OldColumn, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).OldColumn = newNode.(*ColName) - }) - a.apply(node, n.NewColDefinition, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) - }) - a.apply(node, n.First, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).First = newNode.(*ColName) - }) - a.apply(node, n.After, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).After = newNode.(*ColName) - }) - case *CheckConstraintDefinition: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) - }) - case ColIdent: - case *ColName: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*ColName).Name = newNode.(ColIdent) - }) - a.apply(node, n.Qualifier, func(newNode, parent SQLNode) { - parent.(*ColName).Qualifier = newNode.(TableName) - }) - case *CollateExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*CollateExpr).Expr = newNode.(Expr) - }) - case *ColumnDefinition: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*ColumnDefinition).Name = newNode.(ColIdent) - }) - case *ColumnType: - a.apply(node, n.Length, func(newNode, parent SQLNode) { - parent.(*ColumnType).Length = newNode.(*Literal) - }) - a.apply(node, n.Scale, func(newNode, parent SQLNode) { - parent.(*ColumnType).Scale = newNode.(*Literal) - }) - case Columns: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Columns)[idx] = newNode.(ColIdent) - } - }(x)) - } - case Comments: - case *Commit: - case *ComparisonExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Right = newNode.(Expr) - }) - a.apply(node, n.Escape, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Escape = newNode.(Expr) - }) - case *ConstraintDefinition: - a.apply(node, n.Details, func(newNode, parent SQLNode) { - parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) - }) - case *ConvertExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*ConvertExpr).Expr = newNode.(Expr) - }) - a.apply(node, n.Type, func(newNode, parent SQLNode) { - parent.(*ConvertExpr).Type = newNode.(*ConvertType) - }) - case *ConvertType: - a.apply(node, n.Length, func(newNode, parent SQLNode) { - parent.(*ConvertType).Length = newNode.(*Literal) - }) - a.apply(node, n.Scale, func(newNode, parent SQLNode) { - parent.(*ConvertType).Scale = newNode.(*Literal) - }) - case *ConvertUsingExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*ConvertUsingExpr).Expr = newNode.(Expr) - }) - case *CreateDatabase: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*CreateDatabase).Comments = newNode.(Comments) - }) - case *CreateTable: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*CreateTable).Table = newNode.(TableName) - }) - a.apply(node, n.TableSpec, func(newNode, parent SQLNode) { - parent.(*CreateTable).TableSpec = newNode.(*TableSpec) - }) - a.apply(node, n.OptLike, func(newNode, parent SQLNode) { - parent.(*CreateTable).OptLike = newNode.(*OptLike) - }) - case *CreateView: - a.apply(node, n.ViewName, func(newNode, parent SQLNode) { - parent.(*CreateView).ViewName = newNode.(TableName) - }) - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*CreateView).Columns = newNode.(Columns) - }) - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*CreateView).Select = newNode.(SelectStatement) - }) - case *CurTimeFuncExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) - }) - a.apply(node, n.Fsp, func(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) - }) - case *Default: - case *Delete: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Delete).Comments = newNode.(Comments) - }) - a.apply(node, n.Targets, func(newNode, parent SQLNode) { - parent.(*Delete).Targets = newNode.(TableNames) - }) - a.apply(node, n.TableExprs, func(newNode, parent SQLNode) { - parent.(*Delete).TableExprs = newNode.(TableExprs) - }) - a.apply(node, n.Partitions, func(newNode, parent SQLNode) { - parent.(*Delete).Partitions = newNode.(Partitions) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*Delete).Where = newNode.(*Where) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Delete).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Delete).Limit = newNode.(*Limit) - }) - case *DerivedTable: - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*DerivedTable).Select = newNode.(SelectStatement) - }) - case *DropColumn: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*DropColumn).Name = newNode.(*ColName) - }) - case *DropDatabase: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*DropDatabase).Comments = newNode.(Comments) - }) - case *DropKey: - case *DropTable: - a.apply(node, n.FromTables, func(newNode, parent SQLNode) { - parent.(*DropTable).FromTables = newNode.(TableNames) - }) - case *DropView: - a.apply(node, n.FromTables, func(newNode, parent SQLNode) { - parent.(*DropView).FromTables = newNode.(TableNames) - }) - case *ExistsExpr: - a.apply(node, n.Subquery, func(newNode, parent SQLNode) { - parent.(*ExistsExpr).Subquery = newNode.(*Subquery) - }) - case *ExplainStmt: - a.apply(node, n.Statement, func(newNode, parent SQLNode) { - parent.(*ExplainStmt).Statement = newNode.(Statement) - }) - case *ExplainTab: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*ExplainTab).Table = newNode.(TableName) - }) - case Exprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Exprs)[idx] = newNode.(Expr) - } - }(x)) - } - case *Flush: - a.apply(node, n.TableNames, func(newNode, parent SQLNode) { - parent.(*Flush).TableNames = newNode.(TableNames) - }) - case *Force: - case *ForeignKeyDefinition: - a.apply(node, n.Source, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).Source = newNode.(Columns) - }) - a.apply(node, n.ReferencedTable, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) - }) - a.apply(node, n.ReferencedColumns, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) - }) - a.apply(node, n.OnDelete, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) - }) - a.apply(node, n.OnUpdate, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) - }) - case *FuncExpr: - a.apply(node, n.Qualifier, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Qualifier = newNode.(TableIdent) - }) - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Name = newNode.(ColIdent) - }) - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Exprs = newNode.(SelectExprs) - }) - case GroupBy: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(GroupBy)[idx] = newNode.(Expr) - } - }(x)) - } - case *GroupConcatExpr: - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Limit = newNode.(*Limit) - }) - case *IndexDefinition: - a.apply(node, n.Info, func(newNode, parent SQLNode) { - parent.(*IndexDefinition).Info = newNode.(*IndexInfo) - }) - case *IndexHints: - for x, el := range n.Indexes { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*IndexHints).Indexes[idx] = newNode.(ColIdent) - } - }(x)) - } - case *IndexInfo: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*IndexInfo).Name = newNode.(ColIdent) - }) - a.apply(node, n.ConstraintName, func(newNode, parent SQLNode) { - parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) - }) - case *Insert: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Insert).Comments = newNode.(Comments) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*Insert).Table = newNode.(TableName) - }) - a.apply(node, n.Partitions, func(newNode, parent SQLNode) { - parent.(*Insert).Partitions = newNode.(Partitions) - }) - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*Insert).Columns = newNode.(Columns) - }) - a.apply(node, n.Rows, func(newNode, parent SQLNode) { - parent.(*Insert).Rows = newNode.(InsertRows) - }) - a.apply(node, n.OnDup, func(newNode, parent SQLNode) { - parent.(*Insert).OnDup = newNode.(OnDup) - }) - case *IntervalExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*IntervalExpr).Expr = newNode.(Expr) - }) - case *IsExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*IsExpr).Expr = newNode.(Expr) - }) - case JoinCondition: - a.apply(node, n.On, replacePanic("JoinCondition On")) - a.apply(node, n.Using, replacePanic("JoinCondition Using")) - case *JoinTableExpr: - a.apply(node, n.LeftExpr, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) - }) - a.apply(node, n.RightExpr, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) - }) - a.apply(node, n.Condition, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) - }) - case *KeyState: - case *Limit: - a.apply(node, n.Offset, func(newNode, parent SQLNode) { - parent.(*Limit).Offset = newNode.(Expr) - }) - a.apply(node, n.Rowcount, func(newNode, parent SQLNode) { - parent.(*Limit).Rowcount = newNode.(Expr) - }) - case ListArg: - case *Literal: - case *Load: - case *LockOption: - case *LockTables: - case *MatchExpr: - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*MatchExpr).Columns = newNode.(SelectExprs) - }) - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*MatchExpr).Expr = newNode.(Expr) - }) - case *ModifyColumn: - a.apply(node, n.NewColDefinition, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) - }) - a.apply(node, n.First, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).First = newNode.(*ColName) - }) - a.apply(node, n.After, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).After = newNode.(*ColName) - }) - case *Nextval: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*Nextval).Expr = newNode.(Expr) - }) - case *NotExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*NotExpr).Expr = newNode.(Expr) - }) - case *NullVal: - case OnDup: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(OnDup)[idx] = newNode.(*UpdateExpr) - } - }(x)) - } - case *OptLike: - a.apply(node, n.LikeTable, func(newNode, parent SQLNode) { - parent.(*OptLike).LikeTable = newNode.(TableName) - }) - case *OrExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*OrExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*OrExpr).Right = newNode.(Expr) - }) - case *Order: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*Order).Expr = newNode.(Expr) - }) - case OrderBy: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(OrderBy)[idx] = newNode.(*Order) - } - }(x)) - } - case *OrderByOption: - a.apply(node, n.Cols, func(newNode, parent SQLNode) { - parent.(*OrderByOption).Cols = newNode.(Columns) - }) - case *OtherAdmin: - case *OtherRead: - case *ParenSelect: - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*ParenSelect).Select = newNode.(SelectStatement) - }) - case *ParenTableExpr: - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) - }) - case *PartitionDefinition: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Name = newNode.(ColIdent) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Limit = newNode.(Expr) - }) - case *PartitionSpec: - a.apply(node, n.Names, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Names = newNode.(Partitions) - }) - a.apply(node, n.Number, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Number = newNode.(*Literal) - }) - a.apply(node, n.TableName, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).TableName = newNode.(TableName) - }) - for x, el := range n.Definitions { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*PartitionSpec).Definitions[idx] = newNode.(*PartitionDefinition) - } - }(x)) - } - case Partitions: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Partitions)[idx] = newNode.(ColIdent) - } - }(x)) - } - case *RangeCond: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*RangeCond).Left = newNode.(Expr) - }) - a.apply(node, n.From, func(newNode, parent SQLNode) { - parent.(*RangeCond).From = newNode.(Expr) - }) - a.apply(node, n.To, func(newNode, parent SQLNode) { - parent.(*RangeCond).To = newNode.(Expr) - }) - case *Release: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*Release).Name = newNode.(ColIdent) - }) - case *RenameIndex: - case *RenameTable: - case *RenameTableName: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*RenameTableName).Table = newNode.(TableName) - }) - case *RevertMigration: - case *Rollback: - case *SRollback: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*SRollback).Name = newNode.(ColIdent) - }) - case *Savepoint: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*Savepoint).Name = newNode.(ColIdent) - }) - case *Select: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Select).Comments = newNode.(Comments) - }) - a.apply(node, n.SelectExprs, func(newNode, parent SQLNode) { - parent.(*Select).SelectExprs = newNode.(SelectExprs) - }) - a.apply(node, n.From, func(newNode, parent SQLNode) { - parent.(*Select).From = newNode.(TableExprs) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*Select).Where = newNode.(*Where) - }) - a.apply(node, n.GroupBy, func(newNode, parent SQLNode) { - parent.(*Select).GroupBy = newNode.(GroupBy) - }) - a.apply(node, n.Having, func(newNode, parent SQLNode) { - parent.(*Select).Having = newNode.(*Where) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Select).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Select).Limit = newNode.(*Limit) - }) - a.apply(node, n.Into, func(newNode, parent SQLNode) { - parent.(*Select).Into = newNode.(*SelectInto) - }) - case SelectExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(SelectExprs)[idx] = newNode.(SelectExpr) - } - }(x)) - } - case *SelectInto: - case *Set: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Set).Comments = newNode.(Comments) - }) - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*Set).Exprs = newNode.(SetExprs) - }) - case *SetExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*SetExpr).Name = newNode.(ColIdent) - }) - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*SetExpr).Expr = newNode.(Expr) - }) - case SetExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(SetExprs)[idx] = newNode.(*SetExpr) - } - }(x)) - } - case *SetTransaction: - a.apply(node, n.SQLNode, func(newNode, parent SQLNode) { - parent.(*SetTransaction).SQLNode = newNode.(SQLNode) - }) - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*SetTransaction).Comments = newNode.(Comments) - }) - for x, el := range n.Characteristics { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*SetTransaction).Characteristics[idx] = newNode.(Characteristic) - } - }(x)) - } - case *Show: - a.apply(node, n.Internal, func(newNode, parent SQLNode) { - parent.(*Show).Internal = newNode.(ShowInternal) - }) - case *ShowBasic: - a.apply(node, n.Tbl, func(newNode, parent SQLNode) { - parent.(*ShowBasic).Tbl = newNode.(TableName) - }) - a.apply(node, n.Filter, func(newNode, parent SQLNode) { - parent.(*ShowBasic).Filter = newNode.(*ShowFilter) - }) - case *ShowCreate: - a.apply(node, n.Op, func(newNode, parent SQLNode) { - parent.(*ShowCreate).Op = newNode.(TableName) - }) - case *ShowFilter: - a.apply(node, n.Filter, func(newNode, parent SQLNode) { - parent.(*ShowFilter).Filter = newNode.(Expr) - }) - case *ShowLegacy: - a.apply(node, n.OnTable, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).OnTable = newNode.(TableName) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).Table = newNode.(TableName) - }) - a.apply(node, n.ShowCollationFilterOpt, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) - }) - case *StarExpr: - a.apply(node, n.TableName, func(newNode, parent SQLNode) { - parent.(*StarExpr).TableName = newNode.(TableName) - }) - case *Stream: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Stream).Comments = newNode.(Comments) - }) - a.apply(node, n.SelectExpr, func(newNode, parent SQLNode) { - parent.(*Stream).SelectExpr = newNode.(SelectExpr) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*Stream).Table = newNode.(TableName) - }) - case *Subquery: - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*Subquery).Select = newNode.(SelectStatement) - }) - case *SubstrExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).Name = newNode.(*ColName) - }) - a.apply(node, n.StrVal, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).StrVal = newNode.(*Literal) - }) - a.apply(node, n.From, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).From = newNode.(Expr) - }) - a.apply(node, n.To, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).To = newNode.(Expr) - }) - case TableExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(TableExprs)[idx] = newNode.(TableExpr) - } - }(x)) - } - case TableIdent: - case TableName: - a.apply(node, n.Name, replacePanic("TableName Name")) - a.apply(node, n.Qualifier, replacePanic("TableName Qualifier")) - case TableNames: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(TableNames)[idx] = newNode.(TableName) - } - }(x)) - } - case TableOptions: - case *TableSpec: - for x, el := range n.Columns { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*TableSpec).Columns[idx] = newNode.(*ColumnDefinition) - } - }(x)) - } - for x, el := range n.Indexes { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*TableSpec).Indexes[idx] = newNode.(*IndexDefinition) - } - }(x)) - } - for x, el := range n.Constraints { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*TableSpec).Constraints[idx] = newNode.(*ConstraintDefinition) - } - }(x)) - } - a.apply(node, n.Options, func(newNode, parent SQLNode) { - parent.(*TableSpec).Options = newNode.(TableOptions) - }) - case *TablespaceOperation: - case *TimestampFuncExpr: - a.apply(node, n.Expr1, func(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) - }) - a.apply(node, n.Expr2, func(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) - }) - case *TruncateTable: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*TruncateTable).Table = newNode.(TableName) - }) - case *UnaryExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*UnaryExpr).Expr = newNode.(Expr) - }) - case *Union: - a.apply(node, n.FirstStatement, func(newNode, parent SQLNode) { - parent.(*Union).FirstStatement = newNode.(SelectStatement) - }) - for x, el := range n.UnionSelects { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*Union).UnionSelects[idx] = newNode.(*UnionSelect) - } - }(x)) - } - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Union).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Union).Limit = newNode.(*Limit) - }) - case *UnionSelect: - a.apply(node, n.Statement, func(newNode, parent SQLNode) { - parent.(*UnionSelect).Statement = newNode.(SelectStatement) - }) - case *UnlockTables: - case *Update: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Update).Comments = newNode.(Comments) - }) - a.apply(node, n.TableExprs, func(newNode, parent SQLNode) { - parent.(*Update).TableExprs = newNode.(TableExprs) - }) - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*Update).Exprs = newNode.(UpdateExprs) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*Update).Where = newNode.(*Where) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Update).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Update).Limit = newNode.(*Limit) - }) - case *UpdateExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*UpdateExpr).Name = newNode.(*ColName) - }) - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*UpdateExpr).Expr = newNode.(Expr) - }) - case UpdateExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(UpdateExprs)[idx] = newNode.(*UpdateExpr) - } - }(x)) - } - case *Use: - a.apply(node, n.DBName, func(newNode, parent SQLNode) { - parent.(*Use).DBName = newNode.(TableIdent) - }) - case *VStream: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*VStream).Comments = newNode.(Comments) - }) - a.apply(node, n.SelectExpr, func(newNode, parent SQLNode) { - parent.(*VStream).SelectExpr = newNode.(SelectExpr) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*VStream).Table = newNode.(TableName) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*VStream).Where = newNode.(*Where) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*VStream).Limit = newNode.(*Limit) - }) - case ValTuple: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(ValTuple)[idx] = newNode.(Expr) - } - }(x)) - } - case *Validation: - case Values: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Values)[idx] = newNode.(ValTuple) - } - }(x)) - } - case *ValuesFuncExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*ValuesFuncExpr).Name = newNode.(*ColName) - }) - case VindexParam: - a.apply(node, n.Key, replacePanic("VindexParam Key")) - case *VindexSpec: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Name = newNode.(ColIdent) - }) - a.apply(node, n.Type, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Type = newNode.(ColIdent) - }) - for x, el := range n.Params { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*VindexSpec).Params[idx] = newNode.(VindexParam) - } - }(x)) - } - case *When: - a.apply(node, n.Cond, func(newNode, parent SQLNode) { - parent.(*When).Cond = newNode.(Expr) - }) - a.apply(node, n.Val, func(newNode, parent SQLNode) { - parent.(*When).Val = newNode.(Expr) - }) - case *Where: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*Where).Expr = newNode.(Expr) - }) - case *XorExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*XorExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*XorExpr).Right = newNode.(Expr) - }) - } - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - a.cursor = saved -} diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 4cebbf1dd3e..cac0042f81b 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -18,8 +18,6 @@ package sqlparser import ( "fmt" - "reflect" - "runtime" ) // The rewriter was heavily inspired by https://github.com/golang/tools/blob/master/go/ast/astutil/rewrite.go @@ -42,34 +40,27 @@ import ( // func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { parent := &struct{ SQLNode }{node} - defer func() { - if r := recover(); r != nil { - switch r := r.(type) { - case abortT: // nothing to do - - case *runtime.TypeAssertionError: - err = r - case *valueTypeFieldCantChangeErr: - err = r - default: - panic(r) - } - } - result = parent.SQLNode - }() - - a := &application{ - pre: pre, - post: post, - cursor: Cursor{}, - } // this is the root-replacer, used when the user replaces the root of the ast replacer := func(newNode SQLNode, _ SQLNode) { parent.SQLNode = newNode } - a.apply(parent, node, replacer) + if pre == nil { + pre = func(*Cursor) bool { + return true + } + } + if post == nil { + post = func(*Cursor) bool { + return true + } + } + + err = rewriteSQLNode(parent, node, replacer, pre, post) + if err != nil && err != errAbort { + return nil, err + } return parent.SQLNode, nil } @@ -82,11 +73,7 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { // See Rewrite for details. type ApplyFunc func(*Cursor) bool -type abortT int - -var abort = abortT(0) // singleton, to signal termination of Apply - -var abortE = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") +var errAbort = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") // A Cursor describes a node encountered during Apply. // Information about the node and its parent is available @@ -111,32 +98,3 @@ func (c *Cursor) Replace(newNode SQLNode) { } type replacerFunc func(newNode, parent SQLNode) - -// application carries all the shared data so we can pass it around cheaply. -type application struct { - pre, post ApplyFunc - cursor Cursor -} - -func isNilValue(i interface{}) bool { - valueOf := reflect.ValueOf(i) - kind := valueOf.Kind() - isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice - return isNullable && valueOf.IsNil() -} - -// this type is here so we can catch it in the Rewrite method above -type valueTypeFieldCantChangeErr struct { - msg string -} - -// Error implements the error interface -func (e *valueTypeFieldCantChangeErr) Error() string { - return "Tried replacing a field of a value type. This is not supported. " + e.msg -} - -func replacePanic(msg string) func(newNode, parent SQLNode) { - return func(newNode, parent SQLNode) { - panic(&valueTypeFieldCantChangeErr{msg: msg}) - } -} diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go index 6131c6c5588..4a037aeead5 100644 --- a/go/vt/sqlparser/rewriter_test.go +++ b/go/vt/sqlparser/rewriter_test.go @@ -39,19 +39,6 @@ func BenchmarkVisitLargeExpression(b *testing.B) { } } -func TestBadTypeReturnsErrorAndNotPanic(t *testing.T) { - parse, err := Parse("select 42 from dual") - require.NoError(t, err) - _, err = Rewrite(parse, func(cursor *Cursor) bool { - _, ok := cursor.Node().(*Literal) - if ok { - cursor.Replace(&AliasedTableExpr{}) // this is not a valid replacement because of types - } - return true - }, nil) - require.Error(t, err) -} - func TestChangeValueTypeGivesError(t *testing.T) { parse, err := Parse("select * from a join b on a.id = b.id") require.NoError(t, err) From b65f3c9af6eb6de93919b8629ed3b0fb0ba7a131 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 19 Mar 2021 15:51:24 +0100 Subject: [PATCH 07/15] add benchmark Signed-off-by: Andres Taylor --- go/vt/sqlparser/walker_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index c30741029be..386e8736eac 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -38,3 +38,22 @@ func BenchmarkWalkLargeExpression(b *testing.B) { }) } } + +func BenchmarkRewriteLargeExpression(b *testing.B) { + for i := 0; i < 10; i++ { + b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { + exp := newGenerator(int64(i*100), 5).expression() + count := 0 + for i := 0; i < b.N; i++ { + _, err := Rewrite(exp, func(_ *Cursor) bool { + count++ + return true + }, func(_ *Cursor) bool { + count-- + return true + }) + require.NoError(b, err) + } + }) + } +} From 1f1cb2f4a0686f367cbb3b3ceae218a9f277111a Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 19 Mar 2021 16:27:22 +0100 Subject: [PATCH 08/15] check pre and post before executing Signed-off-by: Andres Taylor --- .../asthelpergen/integration/ast_helper.go | 60 +- .../asthelpergen/integration/test_helpers.go | 12 - go/tools/asthelpergen/rewrite_gen.go | 40 +- go/vt/sqlparser/ast_helper.go | 604 +++++++++--------- go/vt/sqlparser/rewriter_api.go | 11 - 5 files changed, 348 insertions(+), 379 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index ee7ad844879..8b775fb12da 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -253,10 +253,10 @@ func rewriteBytes(parent AST, node Bytes, replacer replacerFunc, pre, post Apply parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -288,13 +288,13 @@ func rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -348,7 +348,7 @@ func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -358,7 +358,7 @@ func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFun return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -405,10 +405,10 @@ func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post A parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -462,7 +462,7 @@ func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -472,7 +472,7 @@ func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, po return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -515,10 +515,10 @@ func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -575,7 +575,7 @@ func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { @@ -588,7 +588,7 @@ func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -650,7 +650,7 @@ func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node.ASTElements { @@ -667,7 +667,7 @@ func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -720,7 +720,7 @@ func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSubIface(node, node.inner, func(newNode, parent AST) { @@ -728,7 +728,7 @@ func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -768,7 +768,7 @@ func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { @@ -784,7 +784,7 @@ func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFun if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -828,7 +828,7 @@ func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for _, el := range node.ASTElements { @@ -848,7 +848,7 @@ func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer r if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -930,10 +930,10 @@ func rewriteBasicType(parent AST, node BasicType, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -981,10 +981,10 @@ func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replac parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -1125,7 +1125,7 @@ func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { @@ -1138,7 +1138,7 @@ func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -1200,7 +1200,7 @@ func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, repl parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node.ASTElements { @@ -1217,7 +1217,7 @@ func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, repl return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index 063a1f7a81d..cb1d62be847 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -65,18 +65,6 @@ type replacerFunc func(newNode, parent AST) func Rewrite(node AST, pre, post ApplyFunc) (AST, error) { outer := &struct{ AST }{node} - if pre == nil { - pre = func(cursor *Cursor) bool { - return true - } - } - - if post == nil { - post = func(cursor *Cursor) bool { - return true - } - } - err := rewriteAST(outer, node, func(newNode, parent AST) { outer.AST = newNode }, pre, post) diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 458dfd3d7a1..63a0ae8497a 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -92,12 +92,12 @@ func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generato stmts := []jen.Code{ jen.Var().Id("err").Error(), createCursor(), - jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), + executePre(), } stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, true)...) stmts = append(stmts, jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err())), - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), + executePost(), returnNil(), ) e.rewriteFunc(t, stmts, spi) @@ -130,13 +130,13 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen return nil } */ - jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), + executePre(), } stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, false)...) stmts = append(stmts, - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), + executePost(), returnNil(), ) e.rewriteFunc(t, stmts, spi) @@ -190,7 +190,7 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS stmts := []jen.Code{ jen.If(jen.Id("node == nil").Block(returnNil())), createCursor(), - jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), + executePre(), } if shouldAdd(slice.Elem(), spi.iface()) { @@ -216,13 +216,21 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS return nil */ - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), + executePost(), returnNil(), ) e.rewriteFunc(t, stmts, spi) return nil } +func executePre() *jen.Statement { + return jen.If(jen.Id("pre!= nil && !pre(&cur)")).Block(returnNil()) +} + +func executePost() *jen.Statement { + return jen.If(jen.Id("post != nil && !post(&cur)")).Block(jen.Return(jen.Id(abort))) +} + func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { if !shouldAdd(t, spi.iface()) { return nil @@ -230,8 +238,8 @@ func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) stmts := []jen.Code{ createCursor(), - jen.If(jen.Id("!pre(&cur)")).Block(returnNil()), - jen.If(jen.Id("!post").Call(jen.Id("&cur"))).Block(jen.Return(jen.Id(abort))), + executePre(), + executePost(), returnNil(), } @@ -239,22 +247,6 @@ func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) return nil } -func (e rewriteGen) visitNoChildren(t types.Type, spi generatorSPI) error { - if !shouldAdd(t, spi.iface()) { - return nil - } - - /* - */ - - stmts := []jen.Code{ - jen.Comment("ptrToStructMethod"), - } - e.rewriteFunc(t, stmts, spi) - - return nil -} - func (e rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { /* diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 417a132779e..c0ee1633720 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -1882,7 +1882,7 @@ func rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node.Columns { @@ -1902,7 +1902,7 @@ func rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -1953,7 +1953,7 @@ func rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefi parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { @@ -1961,7 +1961,7 @@ func rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefi }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2012,7 +2012,7 @@ func rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, re parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { @@ -2020,7 +2020,7 @@ func rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, re }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2076,7 +2076,7 @@ func rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -2089,7 +2089,7 @@ func rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2155,7 +2155,7 @@ func rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replac parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -2178,7 +2178,7 @@ func rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replac }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2226,10 +2226,10 @@ func rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2286,7 +2286,7 @@ func rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { @@ -2299,7 +2299,7 @@ func rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2350,10 +2350,10 @@ func rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2401,10 +2401,10 @@ func rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2468,7 +2468,7 @@ func rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -2488,7 +2488,7 @@ func rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2553,7 +2553,7 @@ func rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { @@ -2571,7 +2571,7 @@ func rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2640,7 +2640,7 @@ func rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -2665,7 +2665,7 @@ func rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2721,7 +2721,7 @@ func rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -2734,7 +2734,7 @@ func rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2790,7 +2790,7 @@ func rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { @@ -2803,7 +2803,7 @@ func rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2850,10 +2850,10 @@ func rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2910,7 +2910,7 @@ func rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -2923,7 +2923,7 @@ func rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -2979,7 +2979,7 @@ func rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { @@ -2992,7 +2992,7 @@ func rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3055,7 +3055,7 @@ func rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3075,7 +3075,7 @@ func rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3141,7 +3141,7 @@ func rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { @@ -3164,7 +3164,7 @@ func rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3216,7 +3216,7 @@ func rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraint parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3224,7 +3224,7 @@ func rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraint }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3258,13 +3258,13 @@ func rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3314,7 +3314,7 @@ func rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -3327,7 +3327,7 @@ func rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3379,7 +3379,7 @@ func rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3387,7 +3387,7 @@ func rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3440,7 +3440,7 @@ func rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replac parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -3448,7 +3448,7 @@ func rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replac }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3513,7 +3513,7 @@ func rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { @@ -3526,7 +3526,7 @@ func rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3580,7 +3580,7 @@ func rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -3590,7 +3590,7 @@ func rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc, pre, po return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3632,10 +3632,10 @@ func rewriteComments(parent SQLNode, node Comments, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3682,10 +3682,10 @@ func rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3747,7 +3747,7 @@ func rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -3765,7 +3765,7 @@ func rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer r }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3817,7 +3817,7 @@ func rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { @@ -3825,7 +3825,7 @@ func rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3881,7 +3881,7 @@ func rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3894,7 +3894,7 @@ func rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -3953,7 +3953,7 @@ func rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { @@ -3966,7 +3966,7 @@ func rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4018,7 +4018,7 @@ func rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replac parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -4026,7 +4026,7 @@ func rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replac }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4082,7 +4082,7 @@ func rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -4090,7 +4090,7 @@ func rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer r }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4154,7 +4154,7 @@ func rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -4172,7 +4172,7 @@ func rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4238,7 +4238,7 @@ func rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { @@ -4256,7 +4256,7 @@ func rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4312,7 +4312,7 @@ func rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -4325,7 +4325,7 @@ func rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4372,10 +4372,10 @@ func rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4457,7 +4457,7 @@ func rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -4495,7 +4495,7 @@ func rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4546,7 +4546,7 @@ func rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -4554,7 +4554,7 @@ func rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4605,7 +4605,7 @@ func rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -4613,7 +4613,7 @@ func rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4666,7 +4666,7 @@ func rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -4674,7 +4674,7 @@ func rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4722,10 +4722,10 @@ func rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4778,7 +4778,7 @@ func rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { @@ -4786,7 +4786,7 @@ func rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4838,7 +4838,7 @@ func rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { @@ -4846,7 +4846,7 @@ func rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4897,7 +4897,7 @@ func rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { @@ -4905,7 +4905,7 @@ func rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -4957,7 +4957,7 @@ func rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { @@ -4965,7 +4965,7 @@ func rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5017,7 +5017,7 @@ func rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -5025,7 +5025,7 @@ func rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5079,7 +5079,7 @@ func rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc, pre, post A parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -5089,7 +5089,7 @@ func rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc, pre, post A return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5145,7 +5145,7 @@ func rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { @@ -5153,7 +5153,7 @@ func rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc, pre, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5200,10 +5200,10 @@ func rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5272,7 +5272,7 @@ func rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { @@ -5300,7 +5300,7 @@ func rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5362,7 +5362,7 @@ func rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { @@ -5380,7 +5380,7 @@ func rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5434,7 +5434,7 @@ func rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -5444,7 +5444,7 @@ func rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc, pre, po return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5507,7 +5507,7 @@ func rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { @@ -5525,7 +5525,7 @@ func rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5580,7 +5580,7 @@ func rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { @@ -5588,7 +5588,7 @@ func rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5642,7 +5642,7 @@ func rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node.Indexes { @@ -5652,7 +5652,7 @@ func rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerF return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5713,7 +5713,7 @@ func rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -5726,7 +5726,7 @@ func rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5804,7 +5804,7 @@ func rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -5837,7 +5837,7 @@ func rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5889,7 +5889,7 @@ func rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -5897,7 +5897,7 @@ func rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5949,7 +5949,7 @@ func rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -5957,7 +5957,7 @@ func rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -5996,7 +5996,7 @@ func rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.On, func(newNode, parent SQLNode) { @@ -6012,7 +6012,7 @@ func rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerF if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6074,7 +6074,7 @@ func rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { @@ -6092,7 +6092,7 @@ func rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer rep }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6139,10 +6139,10 @@ func rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6198,7 +6198,7 @@ func rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { @@ -6211,7 +6211,7 @@ func rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc, pre, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6253,10 +6253,10 @@ func rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6304,10 +6304,10 @@ func rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6354,10 +6354,10 @@ func rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6404,10 +6404,10 @@ func rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6455,10 +6455,10 @@ func rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6515,7 +6515,7 @@ func rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { @@ -6528,7 +6528,7 @@ func rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6589,7 +6589,7 @@ func rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { @@ -6607,7 +6607,7 @@ func rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer repla }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6658,7 +6658,7 @@ func rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -6666,7 +6666,7 @@ func rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6717,7 +6717,7 @@ func rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -6725,7 +6725,7 @@ func rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6772,10 +6772,10 @@ func rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6829,7 +6829,7 @@ func rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc, pre, post A parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -6839,7 +6839,7 @@ func rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc, pre, post A return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6890,7 +6890,7 @@ func rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { @@ -6898,7 +6898,7 @@ func rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -6954,7 +6954,7 @@ func rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -6967,7 +6967,7 @@ func rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7019,7 +7019,7 @@ func rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -7027,7 +7027,7 @@ func rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc, pre, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7081,7 +7081,7 @@ func rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -7091,7 +7091,7 @@ func rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc, pre, po return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7142,7 +7142,7 @@ func rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { @@ -7150,7 +7150,7 @@ func rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer rep }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7197,10 +7197,10 @@ func rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7247,10 +7247,10 @@ func rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7301,7 +7301,7 @@ func rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -7309,7 +7309,7 @@ func rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7360,7 +7360,7 @@ func rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { @@ -7368,7 +7368,7 @@ func rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer r }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7425,7 +7425,7 @@ func rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -7438,7 +7438,7 @@ func rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7509,7 +7509,7 @@ func rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { @@ -7534,7 +7534,7 @@ func rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer rep return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7588,7 +7588,7 @@ func rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -7598,7 +7598,7 @@ func rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc, p return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7660,7 +7660,7 @@ func rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -7678,7 +7678,7 @@ func rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7729,7 +7729,7 @@ func rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -7737,7 +7737,7 @@ func rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7785,10 +7785,10 @@ func rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7836,10 +7836,10 @@ func rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7890,7 +7890,7 @@ func rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -7898,7 +7898,7 @@ func rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7945,10 +7945,10 @@ func rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -7995,10 +7995,10 @@ func rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8049,7 +8049,7 @@ func rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -8057,7 +8057,7 @@ func rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8108,7 +8108,7 @@ func rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -8116,7 +8116,7 @@ func rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8213,7 +8213,7 @@ func rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -8261,7 +8261,7 @@ func rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8315,7 +8315,7 @@ func rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -8325,7 +8325,7 @@ func rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc, return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8378,10 +8378,10 @@ func rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8437,7 +8437,7 @@ func rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc, pre, post parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -8450,7 +8450,7 @@ func rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc, pre, post }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8507,7 +8507,7 @@ func rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -8520,7 +8520,7 @@ func rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8574,7 +8574,7 @@ func rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -8584,7 +8584,7 @@ func rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc, pre, return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8648,7 +8648,7 @@ func rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { @@ -8668,7 +8668,7 @@ func rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer r return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8719,7 +8719,7 @@ func rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { @@ -8727,7 +8727,7 @@ func rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc, pre, po }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8786,7 +8786,7 @@ func rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { @@ -8799,7 +8799,7 @@ func rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8851,7 +8851,7 @@ func rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { @@ -8859,7 +8859,7 @@ func rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8911,7 +8911,7 @@ func rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { @@ -8919,7 +8919,7 @@ func rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -8985,7 +8985,7 @@ func rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { @@ -9003,7 +9003,7 @@ func rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9054,7 +9054,7 @@ func rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { @@ -9062,7 +9062,7 @@ func rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9123,7 +9123,7 @@ func rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -9141,7 +9141,7 @@ func rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9192,7 +9192,7 @@ func rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -9200,7 +9200,7 @@ func rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9266,7 +9266,7 @@ func rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -9289,7 +9289,7 @@ func rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9343,7 +9343,7 @@ func rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -9353,7 +9353,7 @@ func rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc, p return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9385,13 +9385,13 @@ func rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9430,7 +9430,7 @@ func rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -9446,7 +9446,7 @@ func rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc, pre if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9500,7 +9500,7 @@ func rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -9510,7 +9510,7 @@ func rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc, p return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9554,10 +9554,10 @@ func rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9629,7 +9629,7 @@ func rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node.Columns { @@ -9658,7 +9658,7 @@ func rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9705,10 +9705,10 @@ func rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9766,7 +9766,7 @@ func rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, repl parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { @@ -9779,7 +9779,7 @@ func rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, repl }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9830,7 +9830,7 @@ func rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -9838,7 +9838,7 @@ func rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer rep }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9890,7 +9890,7 @@ func rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -9898,7 +9898,7 @@ func rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -9967,7 +9967,7 @@ func rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { @@ -9992,7 +9992,7 @@ func rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc, pre, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10044,7 +10044,7 @@ func rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { @@ -10052,7 +10052,7 @@ func rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10099,10 +10099,10 @@ func rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10179,7 +10179,7 @@ func rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -10212,7 +10212,7 @@ func rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc, pre }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10268,7 +10268,7 @@ func rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -10281,7 +10281,7 @@ func rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerF }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10335,7 +10335,7 @@ func rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -10345,7 +10345,7 @@ func rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc, return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10396,7 +10396,7 @@ func rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc, pre, post parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { @@ -10404,7 +10404,7 @@ func rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc, pre, post }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10475,7 +10475,7 @@ func rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -10503,7 +10503,7 @@ func rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10557,7 +10557,7 @@ func rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -10567,7 +10567,7 @@ func rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc, pre, return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10614,10 +10614,10 @@ func rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10671,7 +10671,7 @@ func rewriteValues(parent SQLNode, node Values, replacer replacerFunc, pre, post parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } for i, el := range node { @@ -10681,7 +10681,7 @@ func rewriteValues(parent SQLNode, node Values, replacer replacerFunc, pre, post return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10732,7 +10732,7 @@ func rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer r parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -10740,7 +10740,7 @@ func rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer r }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10776,7 +10776,7 @@ func rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { @@ -10787,7 +10787,7 @@ func rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc, if err != nil { return err } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10850,7 +10850,7 @@ func rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -10870,7 +10870,7 @@ func rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerF return errF } } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10926,7 +10926,7 @@ func rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { @@ -10939,7 +10939,7 @@ func rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc, pre, po }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -10991,7 +10991,7 @@ func rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -10999,7 +10999,7 @@ func rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc, pre, }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -11055,7 +11055,7 @@ func rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -11068,7 +11068,7 @@ func rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc, p }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13400,10 +13400,10 @@ func rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc, p parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13422,10 +13422,10 @@ func rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13444,10 +13444,10 @@ func rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13466,10 +13466,10 @@ func rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13488,10 +13488,10 @@ func rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13510,10 +13510,10 @@ func rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer repla parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13672,10 +13672,10 @@ func rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc, parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -13841,7 +13841,7 @@ func rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer rep parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteExpr(node, node.On, func(newNode, parent SQLNode) { @@ -13854,7 +13854,7 @@ func rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer rep }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -14032,10 +14032,10 @@ func rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerF parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -14091,7 +14091,7 @@ func rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFun parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -14104,7 +14104,7 @@ func rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFun }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil @@ -14247,7 +14247,7 @@ func rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replace parent: parent, replacer: replacer, } - if !pre(&cur) { + if pre != nil && !pre(&cur) { return nil } if errF := rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { @@ -14255,7 +14255,7 @@ func rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replace }, pre, post); errF != nil { return errF } - if !post(&cur) { + if post != nil && !post(&cur) { return errAbort } return nil diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index cac0042f81b..53371224a96 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -46,17 +46,6 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { parent.SQLNode = newNode } - if pre == nil { - pre = func(*Cursor) bool { - return true - } - } - if post == nil { - post = func(*Cursor) bool { - return true - } - } - err = rewriteSQLNode(parent, node, replacer, pre, post) if err != nil && err != errAbort { return nil, err From 7dad1d12dd6a7384bb3c176ae13149c1f3782b28 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 20 Mar 2021 08:39:50 +0100 Subject: [PATCH 09/15] send pre/post with the application receiver Signed-off-by: Andres Taylor --- .../asthelpergen/integration/ast_helper.go | 180 +- .../asthelpergen/integration/test_helpers.go | 9 +- go/tools/asthelpergen/integration/types.go | 4 + go/tools/asthelpergen/rewrite_gen.go | 21 +- go/vt/sqlparser/ast_helper.go | 2392 ++++++++--------- go/vt/sqlparser/rewriter_api.go | 13 +- 6 files changed, 1319 insertions(+), 1300 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index 8b775fb12da..e5d904b83a5 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -182,35 +182,35 @@ func VisitAST(in AST, f Visit) error { } // rewriteAST is part of the Rewrite implementation -func rewriteAST(parent AST, node AST, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteAST(parent AST, node AST, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case BasicType: - return rewriteBasicType(parent, node, replacer, pre, post) + return a.rewriteBasicType(parent, node, replacer) case Bytes: - return rewriteBytes(parent, node, replacer, pre, post) + return a.rewriteBytes(parent, node, replacer) case InterfaceContainer: - return rewriteInterfaceContainer(parent, node, replacer, pre, post) + return a.rewriteInterfaceContainer(parent, node, replacer) case InterfaceSlice: - return rewriteInterfaceSlice(parent, node, replacer, pre, post) + return a.rewriteInterfaceSlice(parent, node, replacer) case *Leaf: - return rewriteRefOfLeaf(parent, node, replacer, pre, post) + return a.rewriteRefOfLeaf(parent, node, replacer) case LeafSlice: - return rewriteLeafSlice(parent, node, replacer, pre, post) + return a.rewriteLeafSlice(parent, node, replacer) case *NoCloneType: - return rewriteRefOfNoCloneType(parent, node, replacer, pre, post) + return a.rewriteRefOfNoCloneType(parent, node, replacer) case *RefContainer: - return rewriteRefOfRefContainer(parent, node, replacer, pre, post) + return a.rewriteRefOfRefContainer(parent, node, replacer) case *RefSliceContainer: - return rewriteRefOfRefSliceContainer(parent, node, replacer, pre, post) + return a.rewriteRefOfRefSliceContainer(parent, node, replacer) case *SubImpl: - return rewriteRefOfSubImpl(parent, node, replacer, pre, post) + return a.rewriteRefOfSubImpl(parent, node, replacer) case ValueContainer: - return rewriteValueContainer(parent, node, replacer, pre, post) + return a.rewriteValueContainer(parent, node, replacer) case ValueSliceContainer: - return rewriteValueSliceContainer(parent, node, replacer, pre, post) + return a.rewriteValueSliceContainer(parent, node, replacer) default: // this should never happen return nil @@ -244,7 +244,7 @@ func VisitBytes(in Bytes, f Visit) error { } // rewriteBytes is part of the Rewrite implementation -func rewriteBytes(parent AST, node Bytes, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc) error { if node == nil { return nil } @@ -253,10 +253,10 @@ func rewriteBytes(parent AST, node Bytes, replacer replacerFunc, pre, post Apply parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -281,20 +281,20 @@ func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { } // rewriteInterfaceContainer is part of the Rewrite implementation -func rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -339,7 +339,7 @@ func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { } // rewriteInterfaceSlice is part of the Rewrite implementation -func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc) error { if node == nil { return nil } @@ -348,17 +348,17 @@ func rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteAST(node, el, func(newNode, parent AST) { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { parent.(InterfaceSlice)[i] = newNode.(AST) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -396,7 +396,7 @@ func VisitRefOfLeaf(in *Leaf, f Visit) error { } // rewriteRefOfLeaf is part of the Rewrite implementation -func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc) error { if node == nil { return nil } @@ -405,10 +405,10 @@ func rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc, pre, post A parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -453,7 +453,7 @@ func VisitLeafSlice(in LeafSlice, f Visit) error { } // rewriteLeafSlice is part of the Rewrite implementation -func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc) error { if node == nil { return nil } @@ -462,17 +462,17 @@ func rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { parent.(LeafSlice)[i] = newNode.(*Leaf) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -506,7 +506,7 @@ func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { } // rewriteRefOfNoCloneType is part of the Rewrite implementation -func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc) error { if node == nil { return nil } @@ -515,10 +515,10 @@ func rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -566,7 +566,7 @@ func VisitRefOfRefContainer(in *RefContainer, f Visit) error { } // rewriteRefOfRefContainer is part of the Rewrite implementation -func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc) error { if node == nil { return nil } @@ -575,20 +575,20 @@ func rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*RefContainer).ASTType = newNode.(AST) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -641,7 +641,7 @@ func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { } // rewriteRefOfRefSliceContainer is part of the Rewrite implementation -func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc) error { if node == nil { return nil } @@ -650,24 +650,24 @@ func rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node.ASTElements { - if errF := rewriteAST(node, el, func(newNode, parent AST) { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); errF != nil { + }); errF != nil { return errF } } for i, el := range node.ASTImplementationElements { - if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { parent.(*RefSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -711,7 +711,7 @@ func VisitRefOfSubImpl(in *SubImpl, f Visit) error { } // rewriteRefOfSubImpl is part of the Rewrite implementation -func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc) error { if node == nil { return nil } @@ -720,15 +720,15 @@ func rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSubIface(node, node.inner, func(newNode, parent AST) { + if errF := a.rewriteSubIface(node, node.inner, func(newNode, parent AST) { parent.(*SubImpl).inner = newNode.(SubIface) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -761,30 +761,30 @@ func VisitValueContainer(in ValueContainer, f Visit) error { } // rewriteValueContainer is part of the Rewrite implementation -func rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTType' on 'ValueContainer'") - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationType' on 'ValueContainer'") - }, pre, post); errF != nil { + }); errF != nil { return errF } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -821,34 +821,34 @@ func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { } // rewriteValueSliceContainer is part of the Rewrite implementation -func rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for _, el := range node.ASTElements { - if errF := rewriteAST(node, el, func(newNode, parent AST) { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTElements' on 'ValueSliceContainer'") - }, pre, post); errF != nil { + }); errF != nil { return errF } } for _, el := range node.ASTImplementationElements { - if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationElements' on 'ValueSliceContainer'") - }, pre, post); errF != nil { + }); errF != nil { return errF } } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -904,13 +904,13 @@ func VisitSubIface(in SubIface, f Visit) error { } // rewriteSubIface is part of the Rewrite implementation -func rewriteSubIface(parent AST, node SubIface, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSubIface(parent AST, node SubIface, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *SubImpl: - return rewriteRefOfSubImpl(parent, node, replacer, pre, post) + return a.rewriteRefOfSubImpl(parent, node, replacer) default: // this should never happen return nil @@ -924,16 +924,16 @@ func VisitBasicType(in BasicType, f Visit) error { } // rewriteBasicType is part of the Rewrite implementation -func rewriteBasicType(parent AST, node BasicType, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteBasicType(parent AST, node BasicType, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -972,7 +972,7 @@ func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { } // rewriteRefOfInterfaceContainer is part of the Rewrite implementation -func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc) error { if node == nil { return nil } @@ -981,10 +981,10 @@ func rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replac parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -1116,7 +1116,7 @@ func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { } // rewriteRefOfValueContainer is part of the Rewrite implementation -func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer replacerFunc) error { if node == nil { return nil } @@ -1125,20 +1125,20 @@ func rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*ValueContainer).ASTType = newNode.(AST) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { parent.(*ValueContainer).ASTImplementationType = newNode.(*Leaf) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -1191,7 +1191,7 @@ func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { } // rewriteRefOfValueSliceContainer is part of the Rewrite implementation -func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc) error { if node == nil { return nil } @@ -1200,24 +1200,24 @@ func rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, repl parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node.ASTElements { - if errF := rewriteAST(node, el, func(newNode, parent AST) { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { parent.(*ValueSliceContainer).ASTElements[i] = newNode.(AST) - }, pre, post); errF != nil { + }); errF != nil { return errF } } for i, el := range node.ASTImplementationElements { - if errF := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { parent.(*ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index cb1d62be847..6ca3df82bde 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -65,9 +65,14 @@ type replacerFunc func(newNode, parent AST) func Rewrite(node AST, pre, post ApplyFunc) (AST, error) { outer := &struct{ AST }{node} - err := rewriteAST(outer, node, func(newNode, parent AST) { + a := &application{ + pre: pre, + post: post, + } + + err := a.rewriteAST(outer, node, func(newNode, parent AST) { outer.AST = newNode - }, pre, post) + }) if err != nil { return nil, err diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index 1e25c50ed75..b45b6950839 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -175,3 +175,7 @@ func (r *NoCloneType) String() string { type Visit func(node AST) (bool, error) var errAbort = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") + +type application struct { + pre, post ApplyFunc +} diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 63a0ae8497a..56cf2dcffe1 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -64,7 +64,7 @@ func (e rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi ge funcName := rewriteName + printableTypeName(t) spi.addType(t) caseBlock := jen.Case(jen.Id(typeString)).Block( - jen.Return(jen.Id(funcName).Call(jen.Id("parent, node, replacer, pre, post"))), + jen.Return(jen.Id("a").Dot(funcName).Call(jen.Id("parent, node, replacer"))), ) cases = append(cases, caseBlock) return nil @@ -224,11 +224,11 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS } func executePre() *jen.Statement { - return jen.If(jen.Id("pre!= nil && !pre(&cur)")).Block(returnNil()) + return jen.If(jen.Id("a.pre!= nil && !a.pre(&cur)")).Block(returnNil()) } func executePost() *jen.Statement { - return jen.If(jen.Id("post != nil && !post(&cur)")).Block(jen.Return(jen.Id(abort))) + return jen.If(jen.Id("a.post != nil && !a.post(&cur)")).Block(jen.Return(jen.Id(abort))) } func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { @@ -255,10 +255,11 @@ func (e rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI typeString := types.TypeString(t, noQualifier) funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t)) - code := jen.Func().Id(funcName).Params( - jen.Id(fmt.Sprintf("parent %s, node %s, replacer replacerFunc, pre, post ApplyFunc", e.ifaceName, typeString)), - ).Error(). - Block(stmts...) + code := jen.Func().Params( + jen.Id("a").Op("*").Id("application"), + ).Id(funcName).Params( + jen.Id(fmt.Sprintf("parent %s, node %s, replacer replacerFunc", e.ifaceName, typeString)), + ).Error().Block(stmts...) spi.addFunc(funcName, rewrite, code) } @@ -334,12 +335,10 @@ func (e rewriteGen) rewriteChild(t, field types.Type, fieldName string, param je Block(replaceOrFail) rewriteField := jen.If( - jen.Id("errF := ").Id(funcName).Call( + jen.Id("errF := ").Id("a").Dot(funcName).Call( jen.Id("node"), param, - funcBlock, - jen.Id("pre"), - jen.Id("post")), + funcBlock), jen.Id("errF != nil").Block(jen.Return(jen.Id("errF")))) return rewriteField diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index c0ee1633720..593a4570490 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -1522,303 +1522,303 @@ func VisitSQLNode(in SQLNode, f Visit) error { } // rewriteSQLNode is part of the Rewrite implementation -func rewriteSQLNode(parent SQLNode, node SQLNode, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSQLNode(parent SQLNode, node SQLNode, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case AccessMode: - return rewriteAccessMode(parent, node, replacer, pre, post) + return a.rewriteAccessMode(parent, node, replacer) case *AddColumns: - return rewriteRefOfAddColumns(parent, node, replacer, pre, post) + return a.rewriteRefOfAddColumns(parent, node, replacer) case *AddConstraintDefinition: - return rewriteRefOfAddConstraintDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) case *AddIndexDefinition: - return rewriteRefOfAddIndexDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) case AlgorithmValue: - return rewriteAlgorithmValue(parent, node, replacer, pre, post) + return a.rewriteAlgorithmValue(parent, node, replacer) case *AliasedExpr: - return rewriteRefOfAliasedExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfAliasedExpr(parent, node, replacer) case *AliasedTableExpr: - return rewriteRefOfAliasedTableExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) case *AlterCharset: - return rewriteRefOfAlterCharset(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterCharset(parent, node, replacer) case *AlterColumn: - return rewriteRefOfAlterColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterColumn(parent, node, replacer) case *AlterDatabase: - return rewriteRefOfAlterDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterDatabase(parent, node, replacer) case *AlterMigration: - return rewriteRefOfAlterMigration(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterMigration(parent, node, replacer) case *AlterTable: - return rewriteRefOfAlterTable(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterTable(parent, node, replacer) case *AlterView: - return rewriteRefOfAlterView(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterView(parent, node, replacer) case *AlterVschema: - return rewriteRefOfAlterVschema(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterVschema(parent, node, replacer) case *AndExpr: - return rewriteRefOfAndExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfAndExpr(parent, node, replacer) case Argument: - return rewriteArgument(parent, node, replacer, pre, post) + return a.rewriteArgument(parent, node, replacer) case *AutoIncSpec: - return rewriteRefOfAutoIncSpec(parent, node, replacer, pre, post) + return a.rewriteRefOfAutoIncSpec(parent, node, replacer) case *Begin: - return rewriteRefOfBegin(parent, node, replacer, pre, post) + return a.rewriteRefOfBegin(parent, node, replacer) case *BinaryExpr: - return rewriteRefOfBinaryExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfBinaryExpr(parent, node, replacer) case BoolVal: - return rewriteBoolVal(parent, node, replacer, pre, post) + return a.rewriteBoolVal(parent, node, replacer) case *CallProc: - return rewriteRefOfCallProc(parent, node, replacer, pre, post) + return a.rewriteRefOfCallProc(parent, node, replacer) case *CaseExpr: - return rewriteRefOfCaseExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfCaseExpr(parent, node, replacer) case *ChangeColumn: - return rewriteRefOfChangeColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfChangeColumn(parent, node, replacer) case *CheckConstraintDefinition: - return rewriteRefOfCheckConstraintDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) case ColIdent: - return rewriteColIdent(parent, node, replacer, pre, post) + return a.rewriteColIdent(parent, node, replacer) case *ColName: - return rewriteRefOfColName(parent, node, replacer, pre, post) + return a.rewriteRefOfColName(parent, node, replacer) case *CollateExpr: - return rewriteRefOfCollateExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfCollateExpr(parent, node, replacer) case *ColumnDefinition: - return rewriteRefOfColumnDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfColumnDefinition(parent, node, replacer) case *ColumnType: - return rewriteRefOfColumnType(parent, node, replacer, pre, post) + return a.rewriteRefOfColumnType(parent, node, replacer) case Columns: - return rewriteColumns(parent, node, replacer, pre, post) + return a.rewriteColumns(parent, node, replacer) case Comments: - return rewriteComments(parent, node, replacer, pre, post) + return a.rewriteComments(parent, node, replacer) case *Commit: - return rewriteRefOfCommit(parent, node, replacer, pre, post) + return a.rewriteRefOfCommit(parent, node, replacer) case *ComparisonExpr: - return rewriteRefOfComparisonExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfComparisonExpr(parent, node, replacer) case *ConstraintDefinition: - return rewriteRefOfConstraintDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfConstraintDefinition(parent, node, replacer) case *ConvertExpr: - return rewriteRefOfConvertExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfConvertExpr(parent, node, replacer) case *ConvertType: - return rewriteRefOfConvertType(parent, node, replacer, pre, post) + return a.rewriteRefOfConvertType(parent, node, replacer) case *ConvertUsingExpr: - return rewriteRefOfConvertUsingExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) case *CreateDatabase: - return rewriteRefOfCreateDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateDatabase(parent, node, replacer) case *CreateTable: - return rewriteRefOfCreateTable(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateTable(parent, node, replacer) case *CreateView: - return rewriteRefOfCreateView(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateView(parent, node, replacer) case *CurTimeFuncExpr: - return rewriteRefOfCurTimeFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) case *Default: - return rewriteRefOfDefault(parent, node, replacer, pre, post) + return a.rewriteRefOfDefault(parent, node, replacer) case *Delete: - return rewriteRefOfDelete(parent, node, replacer, pre, post) + return a.rewriteRefOfDelete(parent, node, replacer) case *DerivedTable: - return rewriteRefOfDerivedTable(parent, node, replacer, pre, post) + return a.rewriteRefOfDerivedTable(parent, node, replacer) case *DropColumn: - return rewriteRefOfDropColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfDropColumn(parent, node, replacer) case *DropDatabase: - return rewriteRefOfDropDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfDropDatabase(parent, node, replacer) case *DropKey: - return rewriteRefOfDropKey(parent, node, replacer, pre, post) + return a.rewriteRefOfDropKey(parent, node, replacer) case *DropTable: - return rewriteRefOfDropTable(parent, node, replacer, pre, post) + return a.rewriteRefOfDropTable(parent, node, replacer) case *DropView: - return rewriteRefOfDropView(parent, node, replacer, pre, post) + return a.rewriteRefOfDropView(parent, node, replacer) case *ExistsExpr: - return rewriteRefOfExistsExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfExistsExpr(parent, node, replacer) case *ExplainStmt: - return rewriteRefOfExplainStmt(parent, node, replacer, pre, post) + return a.rewriteRefOfExplainStmt(parent, node, replacer) case *ExplainTab: - return rewriteRefOfExplainTab(parent, node, replacer, pre, post) + return a.rewriteRefOfExplainTab(parent, node, replacer) case Exprs: - return rewriteExprs(parent, node, replacer, pre, post) + return a.rewriteExprs(parent, node, replacer) case *Flush: - return rewriteRefOfFlush(parent, node, replacer, pre, post) + return a.rewriteRefOfFlush(parent, node, replacer) case *Force: - return rewriteRefOfForce(parent, node, replacer, pre, post) + return a.rewriteRefOfForce(parent, node, replacer) case *ForeignKeyDefinition: - return rewriteRefOfForeignKeyDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) case *FuncExpr: - return rewriteRefOfFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfFuncExpr(parent, node, replacer) case GroupBy: - return rewriteGroupBy(parent, node, replacer, pre, post) + return a.rewriteGroupBy(parent, node, replacer) case *GroupConcatExpr: - return rewriteRefOfGroupConcatExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) case *IndexDefinition: - return rewriteRefOfIndexDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfIndexDefinition(parent, node, replacer) case *IndexHints: - return rewriteRefOfIndexHints(parent, node, replacer, pre, post) + return a.rewriteRefOfIndexHints(parent, node, replacer) case *IndexInfo: - return rewriteRefOfIndexInfo(parent, node, replacer, pre, post) + return a.rewriteRefOfIndexInfo(parent, node, replacer) case *Insert: - return rewriteRefOfInsert(parent, node, replacer, pre, post) + return a.rewriteRefOfInsert(parent, node, replacer) case *IntervalExpr: - return rewriteRefOfIntervalExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfIntervalExpr(parent, node, replacer) case *IsExpr: - return rewriteRefOfIsExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfIsExpr(parent, node, replacer) case IsolationLevel: - return rewriteIsolationLevel(parent, node, replacer, pre, post) + return a.rewriteIsolationLevel(parent, node, replacer) case JoinCondition: - return rewriteJoinCondition(parent, node, replacer, pre, post) + return a.rewriteJoinCondition(parent, node, replacer) case *JoinTableExpr: - return rewriteRefOfJoinTableExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfJoinTableExpr(parent, node, replacer) case *KeyState: - return rewriteRefOfKeyState(parent, node, replacer, pre, post) + return a.rewriteRefOfKeyState(parent, node, replacer) case *Limit: - return rewriteRefOfLimit(parent, node, replacer, pre, post) + return a.rewriteRefOfLimit(parent, node, replacer) case ListArg: - return rewriteListArg(parent, node, replacer, pre, post) + return a.rewriteListArg(parent, node, replacer) case *Literal: - return rewriteRefOfLiteral(parent, node, replacer, pre, post) + return a.rewriteRefOfLiteral(parent, node, replacer) case *Load: - return rewriteRefOfLoad(parent, node, replacer, pre, post) + return a.rewriteRefOfLoad(parent, node, replacer) case *LockOption: - return rewriteRefOfLockOption(parent, node, replacer, pre, post) + return a.rewriteRefOfLockOption(parent, node, replacer) case *LockTables: - return rewriteRefOfLockTables(parent, node, replacer, pre, post) + return a.rewriteRefOfLockTables(parent, node, replacer) case *MatchExpr: - return rewriteRefOfMatchExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfMatchExpr(parent, node, replacer) case *ModifyColumn: - return rewriteRefOfModifyColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfModifyColumn(parent, node, replacer) case *Nextval: - return rewriteRefOfNextval(parent, node, replacer, pre, post) + return a.rewriteRefOfNextval(parent, node, replacer) case *NotExpr: - return rewriteRefOfNotExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfNotExpr(parent, node, replacer) case *NullVal: - return rewriteRefOfNullVal(parent, node, replacer, pre, post) + return a.rewriteRefOfNullVal(parent, node, replacer) case OnDup: - return rewriteOnDup(parent, node, replacer, pre, post) + return a.rewriteOnDup(parent, node, replacer) case *OptLike: - return rewriteRefOfOptLike(parent, node, replacer, pre, post) + return a.rewriteRefOfOptLike(parent, node, replacer) case *OrExpr: - return rewriteRefOfOrExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfOrExpr(parent, node, replacer) case *Order: - return rewriteRefOfOrder(parent, node, replacer, pre, post) + return a.rewriteRefOfOrder(parent, node, replacer) case OrderBy: - return rewriteOrderBy(parent, node, replacer, pre, post) + return a.rewriteOrderBy(parent, node, replacer) case *OrderByOption: - return rewriteRefOfOrderByOption(parent, node, replacer, pre, post) + return a.rewriteRefOfOrderByOption(parent, node, replacer) case *OtherAdmin: - return rewriteRefOfOtherAdmin(parent, node, replacer, pre, post) + return a.rewriteRefOfOtherAdmin(parent, node, replacer) case *OtherRead: - return rewriteRefOfOtherRead(parent, node, replacer, pre, post) + return a.rewriteRefOfOtherRead(parent, node, replacer) case *ParenSelect: - return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfParenSelect(parent, node, replacer) case *ParenTableExpr: - return rewriteRefOfParenTableExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfParenTableExpr(parent, node, replacer) case *PartitionDefinition: - return rewriteRefOfPartitionDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfPartitionDefinition(parent, node, replacer) case *PartitionSpec: - return rewriteRefOfPartitionSpec(parent, node, replacer, pre, post) + return a.rewriteRefOfPartitionSpec(parent, node, replacer) case Partitions: - return rewritePartitions(parent, node, replacer, pre, post) + return a.rewritePartitions(parent, node, replacer) case *RangeCond: - return rewriteRefOfRangeCond(parent, node, replacer, pre, post) + return a.rewriteRefOfRangeCond(parent, node, replacer) case ReferenceAction: - return rewriteReferenceAction(parent, node, replacer, pre, post) + return a.rewriteReferenceAction(parent, node, replacer) case *Release: - return rewriteRefOfRelease(parent, node, replacer, pre, post) + return a.rewriteRefOfRelease(parent, node, replacer) case *RenameIndex: - return rewriteRefOfRenameIndex(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameIndex(parent, node, replacer) case *RenameTable: - return rewriteRefOfRenameTable(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameTable(parent, node, replacer) case *RenameTableName: - return rewriteRefOfRenameTableName(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameTableName(parent, node, replacer) case *RevertMigration: - return rewriteRefOfRevertMigration(parent, node, replacer, pre, post) + return a.rewriteRefOfRevertMigration(parent, node, replacer) case *Rollback: - return rewriteRefOfRollback(parent, node, replacer, pre, post) + return a.rewriteRefOfRollback(parent, node, replacer) case *SRollback: - return rewriteRefOfSRollback(parent, node, replacer, pre, post) + return a.rewriteRefOfSRollback(parent, node, replacer) case *Savepoint: - return rewriteRefOfSavepoint(parent, node, replacer, pre, post) + return a.rewriteRefOfSavepoint(parent, node, replacer) case *Select: - return rewriteRefOfSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfSelect(parent, node, replacer) case SelectExprs: - return rewriteSelectExprs(parent, node, replacer, pre, post) + return a.rewriteSelectExprs(parent, node, replacer) case *SelectInto: - return rewriteRefOfSelectInto(parent, node, replacer, pre, post) + return a.rewriteRefOfSelectInto(parent, node, replacer) case *Set: - return rewriteRefOfSet(parent, node, replacer, pre, post) + return a.rewriteRefOfSet(parent, node, replacer) case *SetExpr: - return rewriteRefOfSetExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfSetExpr(parent, node, replacer) case SetExprs: - return rewriteSetExprs(parent, node, replacer, pre, post) + return a.rewriteSetExprs(parent, node, replacer) case *SetTransaction: - return rewriteRefOfSetTransaction(parent, node, replacer, pre, post) + return a.rewriteRefOfSetTransaction(parent, node, replacer) case *Show: - return rewriteRefOfShow(parent, node, replacer, pre, post) + return a.rewriteRefOfShow(parent, node, replacer) case *ShowBasic: - return rewriteRefOfShowBasic(parent, node, replacer, pre, post) + return a.rewriteRefOfShowBasic(parent, node, replacer) case *ShowCreate: - return rewriteRefOfShowCreate(parent, node, replacer, pre, post) + return a.rewriteRefOfShowCreate(parent, node, replacer) case *ShowFilter: - return rewriteRefOfShowFilter(parent, node, replacer, pre, post) + return a.rewriteRefOfShowFilter(parent, node, replacer) case *ShowLegacy: - return rewriteRefOfShowLegacy(parent, node, replacer, pre, post) + return a.rewriteRefOfShowLegacy(parent, node, replacer) case *StarExpr: - return rewriteRefOfStarExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfStarExpr(parent, node, replacer) case *Stream: - return rewriteRefOfStream(parent, node, replacer, pre, post) + return a.rewriteRefOfStream(parent, node, replacer) case *Subquery: - return rewriteRefOfSubquery(parent, node, replacer, pre, post) + return a.rewriteRefOfSubquery(parent, node, replacer) case *SubstrExpr: - return rewriteRefOfSubstrExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfSubstrExpr(parent, node, replacer) case TableExprs: - return rewriteTableExprs(parent, node, replacer, pre, post) + return a.rewriteTableExprs(parent, node, replacer) case TableIdent: - return rewriteTableIdent(parent, node, replacer, pre, post) + return a.rewriteTableIdent(parent, node, replacer) case TableName: - return rewriteTableName(parent, node, replacer, pre, post) + return a.rewriteTableName(parent, node, replacer) case TableNames: - return rewriteTableNames(parent, node, replacer, pre, post) + return a.rewriteTableNames(parent, node, replacer) case TableOptions: - return rewriteTableOptions(parent, node, replacer, pre, post) + return a.rewriteTableOptions(parent, node, replacer) case *TableSpec: - return rewriteRefOfTableSpec(parent, node, replacer, pre, post) + return a.rewriteRefOfTableSpec(parent, node, replacer) case *TablespaceOperation: - return rewriteRefOfTablespaceOperation(parent, node, replacer, pre, post) + return a.rewriteRefOfTablespaceOperation(parent, node, replacer) case *TimestampFuncExpr: - return rewriteRefOfTimestampFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) case *TruncateTable: - return rewriteRefOfTruncateTable(parent, node, replacer, pre, post) + return a.rewriteRefOfTruncateTable(parent, node, replacer) case *UnaryExpr: - return rewriteRefOfUnaryExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfUnaryExpr(parent, node, replacer) case *Union: - return rewriteRefOfUnion(parent, node, replacer, pre, post) + return a.rewriteRefOfUnion(parent, node, replacer) case *UnionSelect: - return rewriteRefOfUnionSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfUnionSelect(parent, node, replacer) case *UnlockTables: - return rewriteRefOfUnlockTables(parent, node, replacer, pre, post) + return a.rewriteRefOfUnlockTables(parent, node, replacer) case *Update: - return rewriteRefOfUpdate(parent, node, replacer, pre, post) + return a.rewriteRefOfUpdate(parent, node, replacer) case *UpdateExpr: - return rewriteRefOfUpdateExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfUpdateExpr(parent, node, replacer) case UpdateExprs: - return rewriteUpdateExprs(parent, node, replacer, pre, post) + return a.rewriteUpdateExprs(parent, node, replacer) case *Use: - return rewriteRefOfUse(parent, node, replacer, pre, post) + return a.rewriteRefOfUse(parent, node, replacer) case *VStream: - return rewriteRefOfVStream(parent, node, replacer, pre, post) + return a.rewriteRefOfVStream(parent, node, replacer) case ValTuple: - return rewriteValTuple(parent, node, replacer, pre, post) + return a.rewriteValTuple(parent, node, replacer) case *Validation: - return rewriteRefOfValidation(parent, node, replacer, pre, post) + return a.rewriteRefOfValidation(parent, node, replacer) case Values: - return rewriteValues(parent, node, replacer, pre, post) + return a.rewriteValues(parent, node, replacer) case *ValuesFuncExpr: - return rewriteRefOfValuesFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) case VindexParam: - return rewriteVindexParam(parent, node, replacer, pre, post) + return a.rewriteVindexParam(parent, node, replacer) case *VindexSpec: - return rewriteRefOfVindexSpec(parent, node, replacer, pre, post) + return a.rewriteRefOfVindexSpec(parent, node, replacer) case *When: - return rewriteRefOfWhen(parent, node, replacer, pre, post) + return a.rewriteRefOfWhen(parent, node, replacer) case *Where: - return rewriteRefOfWhere(parent, node, replacer, pre, post) + return a.rewriteRefOfWhere(parent, node, replacer) case *XorExpr: - return rewriteRefOfXorExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfXorExpr(parent, node, replacer) default: // this should never happen return nil @@ -1873,7 +1873,7 @@ func VisitRefOfAddColumns(in *AddColumns, f Visit) error { } // rewriteRefOfAddColumns is part of the Rewrite implementation -func rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerFunc) error { if node == nil { return nil } @@ -1882,27 +1882,27 @@ func rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node.Columns { - if errF := rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { parent.(*AddColumns).Columns[i] = newNode.(*ColumnDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if errF := rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { parent.(*AddColumns).First = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { parent.(*AddColumns).After = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -1944,7 +1944,7 @@ func VisitRefOfAddConstraintDefinition(in *AddConstraintDefinition, f Visit) err } // rewriteRefOfAddConstraintDefinition is part of the Rewrite implementation -func rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -1953,15 +1953,15 @@ func rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefi parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2003,7 +2003,7 @@ func VisitRefOfAddIndexDefinition(in *AddIndexDefinition, f Visit) error { } // rewriteRefOfAddIndexDefinition is part of the Rewrite implementation -func rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -2012,15 +2012,15 @@ func rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, re parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2067,7 +2067,7 @@ func VisitRefOfAliasedExpr(in *AliasedExpr, f Visit) error { } // rewriteRefOfAliasedExpr is part of the Rewrite implementation -func rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -2076,20 +2076,20 @@ func rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*AliasedExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColIdent(node, node.As, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.As, func(newNode, parent SQLNode) { parent.(*AliasedExpr).As = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2146,7 +2146,7 @@ func VisitRefOfAliasedTableExpr(in *AliasedTableExpr, f Visit) error { } // rewriteRefOfAliasedTableExpr is part of the Rewrite implementation -func rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -2155,30 +2155,30 @@ func rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replac parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableIdent(node, node.As, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.As, func(newNode, parent SQLNode) { parent.(*AliasedTableExpr).As = newNode.(TableIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfIndexHints(node, node.Hints, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfIndexHints(node, node.Hints, func(newNode, parent SQLNode) { parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2217,7 +2217,7 @@ func VisitRefOfAlterCharset(in *AlterCharset, f Visit) error { } // rewriteRefOfAlterCharset is part of the Rewrite implementation -func rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer replacerFunc) error { if node == nil { return nil } @@ -2226,10 +2226,10 @@ func rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2277,7 +2277,7 @@ func VisitRefOfAlterColumn(in *AlterColumn, f Visit) error { } // rewriteRefOfAlterColumn is part of the Rewrite implementation -func rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replacerFunc) error { if node == nil { return nil } @@ -2286,20 +2286,20 @@ func rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { parent.(*AlterColumn).Column = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.DefaultVal, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.DefaultVal, func(newNode, parent SQLNode) { parent.(*AlterColumn).DefaultVal = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2341,7 +2341,7 @@ func VisitRefOfAlterDatabase(in *AlterDatabase, f Visit) error { } // rewriteRefOfAlterDatabase is part of the Rewrite implementation -func rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer replacerFunc) error { if node == nil { return nil } @@ -2350,10 +2350,10 @@ func rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer rep parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2392,7 +2392,7 @@ func VisitRefOfAlterMigration(in *AlterMigration, f Visit) error { } // rewriteRefOfAlterMigration is part of the Rewrite implementation -func rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer replacerFunc) error { if node == nil { return nil } @@ -2401,10 +2401,10 @@ func rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer r parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2459,7 +2459,7 @@ func VisitRefOfAlterTable(in *AlterTable, f Visit) error { } // rewriteRefOfAlterTable is part of the Rewrite implementation -func rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerFunc) error { if node == nil { return nil } @@ -2468,27 +2468,27 @@ func rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*AlterTable).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.AlterOptions { - if errF := rewriteAlterOption(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteAlterOption(node, el, func(newNode, parent SQLNode) { parent.(*AlterTable).AlterOptions[i] = newNode.(AlterOption) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if errF := rewriteRefOfPartitionSpec(node, node.PartitionSpec, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfPartitionSpec(node, node.PartitionSpec, func(newNode, parent SQLNode) { parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2544,7 +2544,7 @@ func VisitRefOfAlterView(in *AlterView, f Visit) error { } // rewriteRefOfAlterView is part of the Rewrite implementation -func rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFunc) error { if node == nil { return nil } @@ -2553,25 +2553,25 @@ func rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { parent.(*AlterView).ViewName = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { parent.(*AlterView).Columns = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*AlterView).Select = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2631,7 +2631,7 @@ func VisitRefOfAlterVschema(in *AlterVschema, f Visit) error { } // rewriteRefOfAlterVschema is part of the Rewrite implementation -func rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer replacerFunc) error { if node == nil { return nil } @@ -2640,32 +2640,32 @@ func rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*AlterVschema).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfVindexSpec(node, node.VindexSpec, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfVindexSpec(node, node.VindexSpec, func(newNode, parent SQLNode) { parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.VindexCols { - if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { parent.(*AlterVschema).VindexCols[i] = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if errF := rewriteRefOfAutoIncSpec(node, node.AutoIncSpec, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfAutoIncSpec(node, node.AutoIncSpec, func(newNode, parent SQLNode) { parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2712,7 +2712,7 @@ func VisitRefOfAndExpr(in *AndExpr, f Visit) error { } // rewriteRefOfAndExpr is part of the Rewrite implementation -func rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -2721,20 +2721,20 @@ func rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*AndExpr).Left = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { parent.(*AndExpr).Right = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2781,7 +2781,7 @@ func VisitRefOfAutoIncSpec(in *AutoIncSpec, f Visit) error { } // rewriteRefOfAutoIncSpec is part of the Rewrite implementation -func rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replacerFunc) error { if node == nil { return nil } @@ -2790,20 +2790,20 @@ func rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { parent.(*AutoIncSpec).Column = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.Sequence, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Sequence, func(newNode, parent SQLNode) { parent.(*AutoIncSpec).Sequence = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2841,7 +2841,7 @@ func VisitRefOfBegin(in *Begin, f Visit) error { } // rewriteRefOfBegin is part of the Rewrite implementation -func rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc) error { if node == nil { return nil } @@ -2850,10 +2850,10 @@ func rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2901,7 +2901,7 @@ func VisitRefOfBinaryExpr(in *BinaryExpr, f Visit) error { } // rewriteRefOfBinaryExpr is part of the Rewrite implementation -func rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -2910,20 +2910,20 @@ func rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*BinaryExpr).Left = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { parent.(*BinaryExpr).Right = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -2970,7 +2970,7 @@ func VisitRefOfCallProc(in *CallProc, f Visit) error { } // rewriteRefOfCallProc is part of the Rewrite implementation -func rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc) error { if node == nil { return nil } @@ -2979,20 +2979,20 @@ func rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { parent.(*CallProc).Name = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExprs(node, node.Params, func(newNode, parent SQLNode) { + if errF := a.rewriteExprs(node, node.Params, func(newNode, parent SQLNode) { parent.(*CallProc).Params = newNode.(Exprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3046,7 +3046,7 @@ func VisitRefOfCaseExpr(in *CaseExpr, f Visit) error { } // rewriteRefOfCaseExpr is part of the Rewrite implementation -func rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -3055,27 +3055,27 @@ func rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*CaseExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.Whens { - if errF := rewriteRefOfWhen(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfWhen(node, el, func(newNode, parent SQLNode) { parent.(*CaseExpr).Whens[i] = newNode.(*When) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if errF := rewriteExpr(node, node.Else, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Else, func(newNode, parent SQLNode) { parent.(*CaseExpr).Else = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3132,7 +3132,7 @@ func VisitRefOfChangeColumn(in *ChangeColumn, f Visit) error { } // rewriteRefOfChangeColumn is part of the Rewrite implementation -func rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer replacerFunc) error { if node == nil { return nil } @@ -3141,30 +3141,30 @@ func rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { parent.(*ChangeColumn).OldColumn = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { parent.(*ChangeColumn).First = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { parent.(*ChangeColumn).After = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3207,7 +3207,7 @@ func VisitRefOfCheckConstraintDefinition(in *CheckConstraintDefinition, f Visit) } // rewriteRefOfCheckConstraintDefinition is part of the Rewrite implementation -func rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraintDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraintDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -3216,15 +3216,15 @@ func rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraint parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3251,20 +3251,20 @@ func VisitColIdent(in ColIdent, f Visit) error { } // rewriteColIdent is part of the Rewrite implementation -func rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3305,7 +3305,7 @@ func VisitRefOfColName(in *ColName, f Visit) error { } // rewriteRefOfColName is part of the Rewrite implementation -func rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc) error { if node == nil { return nil } @@ -3314,20 +3314,20 @@ func rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*ColName).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.Qualifier, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Qualifier, func(newNode, parent SQLNode) { parent.(*ColName).Qualifier = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3370,7 +3370,7 @@ func VisitRefOfCollateExpr(in *CollateExpr, f Visit) error { } // rewriteRefOfCollateExpr is part of the Rewrite implementation -func rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -3379,15 +3379,15 @@ func rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*CollateExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3431,7 +3431,7 @@ func VisitRefOfColumnDefinition(in *ColumnDefinition, f Visit) error { } // rewriteRefOfColumnDefinition is part of the Rewrite implementation -func rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -3440,15 +3440,15 @@ func rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replac parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*ColumnDefinition).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3504,7 +3504,7 @@ func VisitRefOfColumnType(in *ColumnType, f Visit) error { } // rewriteRefOfColumnType is part of the Rewrite implementation -func rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerFunc) error { if node == nil { return nil } @@ -3513,20 +3513,20 @@ func rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { parent.(*ColumnType).Length = newNode.(*Literal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { parent.(*ColumnType).Scale = newNode.(*Literal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3571,7 +3571,7 @@ func VisitColumns(in Columns, f Visit) error { } // rewriteColumns is part of the Rewrite implementation -func rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc) error { if node == nil { return nil } @@ -3580,17 +3580,17 @@ func rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { parent.(Columns)[i] = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3623,7 +3623,7 @@ func VisitComments(in Comments, f Visit) error { } // rewriteComments is part of the Rewrite implementation -func rewriteComments(parent SQLNode, node Comments, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteComments(parent SQLNode, node Comments, replacer replacerFunc) error { if node == nil { return nil } @@ -3632,10 +3632,10 @@ func rewriteComments(parent SQLNode, node Comments, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3673,7 +3673,7 @@ func VisitRefOfCommit(in *Commit, f Visit) error { } // rewriteRefOfCommit is part of the Rewrite implementation -func rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc) error { if node == nil { return nil } @@ -3682,10 +3682,10 @@ func rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3738,7 +3738,7 @@ func VisitRefOfComparisonExpr(in *ComparisonExpr, f Visit) error { } // rewriteRefOfComparisonExpr is part of the Rewrite implementation -func rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -3747,25 +3747,25 @@ func rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer r parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*ComparisonExpr).Left = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { parent.(*ComparisonExpr).Right = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Escape, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Escape, func(newNode, parent SQLNode) { parent.(*ComparisonExpr).Escape = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3808,7 +3808,7 @@ func VisitRefOfConstraintDefinition(in *ConstraintDefinition, f Visit) error { } // rewriteRefOfConstraintDefinition is part of the Rewrite implementation -func rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -3817,15 +3817,15 @@ func rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { + if errF := a.rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3872,7 +3872,7 @@ func VisitRefOfConvertExpr(in *ConvertExpr, f Visit) error { } // rewriteRefOfConvertExpr is part of the Rewrite implementation -func rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -3881,20 +3881,20 @@ func rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*ConvertExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfConvertType(node, node.Type, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfConvertType(node, node.Type, func(newNode, parent SQLNode) { parent.(*ConvertExpr).Type = newNode.(*ConvertType) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -3944,7 +3944,7 @@ func VisitRefOfConvertType(in *ConvertType, f Visit) error { } // rewriteRefOfConvertType is part of the Rewrite implementation -func rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replacerFunc) error { if node == nil { return nil } @@ -3953,20 +3953,20 @@ func rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { parent.(*ConvertType).Length = newNode.(*Literal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { parent.(*ConvertType).Scale = newNode.(*Literal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4009,7 +4009,7 @@ func VisitRefOfConvertUsingExpr(in *ConvertUsingExpr, f Visit) error { } // rewriteRefOfConvertUsingExpr is part of the Rewrite implementation -func rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -4018,15 +4018,15 @@ func rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replac parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*ConvertUsingExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4073,7 +4073,7 @@ func VisitRefOfCreateDatabase(in *CreateDatabase, f Visit) error { } // rewriteRefOfCreateDatabase is part of the Rewrite implementation -func rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer replacerFunc) error { if node == nil { return nil } @@ -4082,15 +4082,15 @@ func rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer r parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*CreateDatabase).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4145,7 +4145,7 @@ func VisitRefOfCreateTable(in *CreateTable, f Visit) error { } // rewriteRefOfCreateTable is part of the Rewrite implementation -func rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replacerFunc) error { if node == nil { return nil } @@ -4154,25 +4154,25 @@ func rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*CreateTable).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfTableSpec(node, node.TableSpec, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfTableSpec(node, node.TableSpec, func(newNode, parent SQLNode) { parent.(*CreateTable).TableSpec = newNode.(*TableSpec) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfOptLike(node, node.OptLike, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfOptLike(node, node.OptLike, func(newNode, parent SQLNode) { parent.(*CreateTable).OptLike = newNode.(*OptLike) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4229,7 +4229,7 @@ func VisitRefOfCreateView(in *CreateView, f Visit) error { } // rewriteRefOfCreateView is part of the Rewrite implementation -func rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerFunc) error { if node == nil { return nil } @@ -4238,25 +4238,25 @@ func rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { parent.(*CreateView).ViewName = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { parent.(*CreateView).Columns = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*CreateView).Select = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4303,7 +4303,7 @@ func VisitRefOfCurTimeFuncExpr(in *CurTimeFuncExpr, f Visit) error { } // rewriteRefOfCurTimeFuncExpr is part of the Rewrite implementation -func rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -4312,20 +4312,20 @@ func rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Fsp, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Fsp, func(newNode, parent SQLNode) { parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4363,7 +4363,7 @@ func VisitRefOfDefault(in *Default, f Visit) error { } // rewriteRefOfDefault is part of the Rewrite implementation -func rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc) error { if node == nil { return nil } @@ -4372,10 +4372,10 @@ func rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4448,7 +4448,7 @@ func VisitRefOfDelete(in *Delete, f Visit) error { } // rewriteRefOfDelete is part of the Rewrite implementation -func rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc) error { if node == nil { return nil } @@ -4457,45 +4457,45 @@ func rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Delete).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { + if errF := a.rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { parent.(*Delete).Targets = newNode.(TableNames) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { parent.(*Delete).TableExprs = newNode.(TableExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { parent.(*Delete).Partitions = newNode.(Partitions) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { parent.(*Delete).Where = newNode.(*Where) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { parent.(*Delete).OrderBy = newNode.(OrderBy) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Delete).Limit = newNode.(*Limit) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4537,7 +4537,7 @@ func VisitRefOfDerivedTable(in *DerivedTable, f Visit) error { } // rewriteRefOfDerivedTable is part of the Rewrite implementation -func rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer replacerFunc) error { if node == nil { return nil } @@ -4546,15 +4546,15 @@ func rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*DerivedTable).Select = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4596,7 +4596,7 @@ func VisitRefOfDropColumn(in *DropColumn, f Visit) error { } // rewriteRefOfDropColumn is part of the Rewrite implementation -func rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerFunc) error { if node == nil { return nil } @@ -4605,15 +4605,15 @@ func rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*DropColumn).Name = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4657,7 +4657,7 @@ func VisitRefOfDropDatabase(in *DropDatabase, f Visit) error { } // rewriteRefOfDropDatabase is part of the Rewrite implementation -func rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer replacerFunc) error { if node == nil { return nil } @@ -4666,15 +4666,15 @@ func rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*DropDatabase).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4713,7 +4713,7 @@ func VisitRefOfDropKey(in *DropKey, f Visit) error { } // rewriteRefOfDropKey is part of the Rewrite implementation -func rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc) error { if node == nil { return nil } @@ -4722,10 +4722,10 @@ func rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4769,7 +4769,7 @@ func VisitRefOfDropTable(in *DropTable, f Visit) error { } // rewriteRefOfDropTable is part of the Rewrite implementation -func rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFunc) error { if node == nil { return nil } @@ -4778,15 +4778,15 @@ func rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { parent.(*DropTable).FromTables = newNode.(TableNames) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4829,7 +4829,7 @@ func VisitRefOfDropView(in *DropView, f Visit) error { } // rewriteRefOfDropView is part of the Rewrite implementation -func rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc) error { if node == nil { return nil } @@ -4838,15 +4838,15 @@ func rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { parent.(*DropView).FromTables = newNode.(TableNames) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4888,7 +4888,7 @@ func VisitRefOfExistsExpr(in *ExistsExpr, f Visit) error { } // rewriteRefOfExistsExpr is part of the Rewrite implementation -func rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -4897,15 +4897,15 @@ func rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { parent.(*ExistsExpr).Subquery = newNode.(*Subquery) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -4948,7 +4948,7 @@ func VisitRefOfExplainStmt(in *ExplainStmt, f Visit) error { } // rewriteRefOfExplainStmt is part of the Rewrite implementation -func rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replacerFunc) error { if node == nil { return nil } @@ -4957,15 +4957,15 @@ func rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { + if errF := a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { parent.(*ExplainStmt).Statement = newNode.(Statement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5008,7 +5008,7 @@ func VisitRefOfExplainTab(in *ExplainTab, f Visit) error { } // rewriteRefOfExplainTab is part of the Rewrite implementation -func rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerFunc) error { if node == nil { return nil } @@ -5017,15 +5017,15 @@ func rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*ExplainTab).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5070,7 +5070,7 @@ func VisitExprs(in Exprs, f Visit) error { } // rewriteExprs is part of the Rewrite implementation -func rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc) error { if node == nil { return nil } @@ -5079,17 +5079,17 @@ func rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc, pre, post A parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { parent.(Exprs)[i] = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5136,7 +5136,7 @@ func VisitRefOfFlush(in *Flush, f Visit) error { } // rewriteRefOfFlush is part of the Rewrite implementation -func rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc) error { if node == nil { return nil } @@ -5145,15 +5145,15 @@ func rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { + if errF := a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { parent.(*Flush).TableNames = newNode.(TableNames) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5191,7 +5191,7 @@ func VisitRefOfForce(in *Force, f Visit) error { } // rewriteRefOfForce is part of the Rewrite implementation -func rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc) error { if node == nil { return nil } @@ -5200,10 +5200,10 @@ func rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5263,7 +5263,7 @@ func VisitRefOfForeignKeyDefinition(in *ForeignKeyDefinition, f Visit) error { } // rewriteRefOfForeignKeyDefinition is part of the Rewrite implementation -func rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -5272,35 +5272,35 @@ func rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { parent.(*ForeignKeyDefinition).Source = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.ReferencedTable, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.ReferencedTable, func(newNode, parent SQLNode) { parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColumns(node, node.ReferencedColumns, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.ReferencedColumns, func(newNode, parent SQLNode) { parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteReferenceAction(node, node.OnDelete, func(newNode, parent SQLNode) { + if errF := a.rewriteReferenceAction(node, node.OnDelete, func(newNode, parent SQLNode) { parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteReferenceAction(node, node.OnUpdate, func(newNode, parent SQLNode) { + if errF := a.rewriteReferenceAction(node, node.OnUpdate, func(newNode, parent SQLNode) { parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5353,7 +5353,7 @@ func VisitRefOfFuncExpr(in *FuncExpr, f Visit) error { } // rewriteRefOfFuncExpr is part of the Rewrite implementation -func rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -5362,25 +5362,25 @@ func rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { parent.(*FuncExpr).Qualifier = newNode.(TableIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*FuncExpr).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*FuncExpr).Exprs = newNode.(SelectExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5425,7 +5425,7 @@ func VisitGroupBy(in GroupBy, f Visit) error { } // rewriteGroupBy is part of the Rewrite implementation -func rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc) error { if node == nil { return nil } @@ -5434,17 +5434,17 @@ func rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { parent.(GroupBy)[i] = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5498,7 +5498,7 @@ func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { } // rewriteRefOfGroupConcatExpr is part of the Rewrite implementation -func rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -5507,25 +5507,25 @@ func rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*GroupConcatExpr).Limit = newNode.(*Limit) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5571,7 +5571,7 @@ func VisitRefOfIndexDefinition(in *IndexDefinition, f Visit) error { } // rewriteRefOfIndexDefinition is part of the Rewrite implementation -func rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -5580,15 +5580,15 @@ func rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { parent.(*IndexDefinition).Info = newNode.(*IndexInfo) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5633,7 +5633,7 @@ func VisitRefOfIndexHints(in *IndexHints, f Visit) error { } // rewriteRefOfIndexHints is part of the Rewrite implementation -func rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerFunc) error { if node == nil { return nil } @@ -5642,17 +5642,17 @@ func rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node.Indexes { - if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { parent.(*IndexHints).Indexes[i] = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5704,7 +5704,7 @@ func VisitRefOfIndexInfo(in *IndexInfo, f Visit) error { } // rewriteRefOfIndexInfo is part of the Rewrite implementation -func rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFunc) error { if node == nil { return nil } @@ -5713,20 +5713,20 @@ func rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*IndexInfo).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColIdent(node, node.ConstraintName, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.ConstraintName, func(newNode, parent SQLNode) { parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5795,7 +5795,7 @@ func VisitRefOfInsert(in *Insert, f Visit) error { } // rewriteRefOfInsert is part of the Rewrite implementation -func rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc) error { if node == nil { return nil } @@ -5804,40 +5804,40 @@ func rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Insert).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*Insert).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { parent.(*Insert).Partitions = newNode.(Partitions) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { parent.(*Insert).Columns = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteInsertRows(node, node.Rows, func(newNode, parent SQLNode) { + if errF := a.rewriteInsertRows(node, node.Rows, func(newNode, parent SQLNode) { parent.(*Insert).Rows = newNode.(InsertRows) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { + if errF := a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { parent.(*Insert).OnDup = newNode.(OnDup) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5880,7 +5880,7 @@ func VisitRefOfIntervalExpr(in *IntervalExpr, f Visit) error { } // rewriteRefOfIntervalExpr is part of the Rewrite implementation -func rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -5889,15 +5889,15 @@ func rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*IntervalExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5940,7 +5940,7 @@ func VisitRefOfIsExpr(in *IsExpr, f Visit) error { } // rewriteRefOfIsExpr is part of the Rewrite implementation -func rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -5949,15 +5949,15 @@ func rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*IsExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -5989,30 +5989,30 @@ func VisitJoinCondition(in JoinCondition, f Visit) error { } // rewriteJoinCondition is part of the Rewrite implementation -func rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'On' on 'JoinCondition'") - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Using' on 'JoinCondition'") - }, pre, post); errF != nil { + }); errF != nil { return errF } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6065,7 +6065,7 @@ func VisitRefOfJoinTableExpr(in *JoinTableExpr, f Visit) error { } // rewriteRefOfJoinTableExpr is part of the Rewrite implementation -func rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -6074,25 +6074,25 @@ func rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer rep parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableExpr(node, node.RightExpr, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExpr(node, node.RightExpr, func(newNode, parent SQLNode) { parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteJoinCondition(node, node.Condition, func(newNode, parent SQLNode) { + if errF := a.rewriteJoinCondition(node, node.Condition, func(newNode, parent SQLNode) { parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6130,7 +6130,7 @@ func VisitRefOfKeyState(in *KeyState, f Visit) error { } // rewriteRefOfKeyState is part of the Rewrite implementation -func rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc) error { if node == nil { return nil } @@ -6139,10 +6139,10 @@ func rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6189,7 +6189,7 @@ func VisitRefOfLimit(in *Limit, f Visit) error { } // rewriteRefOfLimit is part of the Rewrite implementation -func rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc) error { if node == nil { return nil } @@ -6198,20 +6198,20 @@ func rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { parent.(*Limit).Offset = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Rowcount, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Rowcount, func(newNode, parent SQLNode) { parent.(*Limit).Rowcount = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6244,7 +6244,7 @@ func VisitListArg(in ListArg, f Visit) error { } // rewriteListArg is part of the Rewrite implementation -func rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc) error { if node == nil { return nil } @@ -6253,10 +6253,10 @@ func rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6295,7 +6295,7 @@ func VisitRefOfLiteral(in *Literal, f Visit) error { } // rewriteRefOfLiteral is part of the Rewrite implementation -func rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc) error { if node == nil { return nil } @@ -6304,10 +6304,10 @@ func rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6345,7 +6345,7 @@ func VisitRefOfLoad(in *Load, f Visit) error { } // rewriteRefOfLoad is part of the Rewrite implementation -func rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc) error { if node == nil { return nil } @@ -6354,10 +6354,10 @@ func rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6395,7 +6395,7 @@ func VisitRefOfLockOption(in *LockOption, f Visit) error { } // rewriteRefOfLockOption is part of the Rewrite implementation -func rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerFunc) error { if node == nil { return nil } @@ -6404,10 +6404,10 @@ func rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6446,7 +6446,7 @@ func VisitRefOfLockTables(in *LockTables, f Visit) error { } // rewriteRefOfLockTables is part of the Rewrite implementation -func rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerFunc) error { if node == nil { return nil } @@ -6455,10 +6455,10 @@ func rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6506,7 +6506,7 @@ func VisitRefOfMatchExpr(in *MatchExpr, f Visit) error { } // rewriteRefOfMatchExpr is part of the Rewrite implementation -func rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -6515,20 +6515,20 @@ func rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { parent.(*MatchExpr).Columns = newNode.(SelectExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*MatchExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6580,7 +6580,7 @@ func VisitRefOfModifyColumn(in *ModifyColumn, f Visit) error { } // rewriteRefOfModifyColumn is part of the Rewrite implementation -func rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer replacerFunc) error { if node == nil { return nil } @@ -6589,25 +6589,25 @@ func rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { parent.(*ModifyColumn).First = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { parent.(*ModifyColumn).After = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6649,7 +6649,7 @@ func VisitRefOfNextval(in *Nextval, f Visit) error { } // rewriteRefOfNextval is part of the Rewrite implementation -func rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc) error { if node == nil { return nil } @@ -6658,15 +6658,15 @@ func rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*Nextval).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6708,7 +6708,7 @@ func VisitRefOfNotExpr(in *NotExpr, f Visit) error { } // rewriteRefOfNotExpr is part of the Rewrite implementation -func rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -6717,15 +6717,15 @@ func rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*NotExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6763,7 +6763,7 @@ func VisitRefOfNullVal(in *NullVal, f Visit) error { } // rewriteRefOfNullVal is part of the Rewrite implementation -func rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc) error { if node == nil { return nil } @@ -6772,10 +6772,10 @@ func rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6820,7 +6820,7 @@ func VisitOnDup(in OnDup, f Visit) error { } // rewriteOnDup is part of the Rewrite implementation -func rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc) error { if node == nil { return nil } @@ -6829,17 +6829,17 @@ func rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc, pre, post A parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { parent.(OnDup)[i] = newNode.(*UpdateExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6881,7 +6881,7 @@ func VisitRefOfOptLike(in *OptLike, f Visit) error { } // rewriteRefOfOptLike is part of the Rewrite implementation -func rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc) error { if node == nil { return nil } @@ -6890,15 +6890,15 @@ func rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { parent.(*OptLike).LikeTable = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -6945,7 +6945,7 @@ func VisitRefOfOrExpr(in *OrExpr, f Visit) error { } // rewriteRefOfOrExpr is part of the Rewrite implementation -func rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -6954,20 +6954,20 @@ func rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*OrExpr).Left = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { parent.(*OrExpr).Right = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7010,7 +7010,7 @@ func VisitRefOfOrder(in *Order, f Visit) error { } // rewriteRefOfOrder is part of the Rewrite implementation -func rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc) error { if node == nil { return nil } @@ -7019,15 +7019,15 @@ func rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*Order).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7072,7 +7072,7 @@ func VisitOrderBy(in OrderBy, f Visit) error { } // rewriteOrderBy is part of the Rewrite implementation -func rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc) error { if node == nil { return nil } @@ -7081,17 +7081,17 @@ func rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { parent.(OrderBy)[i] = newNode.(*Order) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7133,7 +7133,7 @@ func VisitRefOfOrderByOption(in *OrderByOption, f Visit) error { } // rewriteRefOfOrderByOption is part of the Rewrite implementation -func rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer replacerFunc) error { if node == nil { return nil } @@ -7142,15 +7142,15 @@ func rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer rep parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { parent.(*OrderByOption).Cols = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7188,7 +7188,7 @@ func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { } // rewriteRefOfOtherAdmin is part of the Rewrite implementation -func rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerFunc) error { if node == nil { return nil } @@ -7197,10 +7197,10 @@ func rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7238,7 +7238,7 @@ func VisitRefOfOtherRead(in *OtherRead, f Visit) error { } // rewriteRefOfOtherRead is part of the Rewrite implementation -func rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFunc) error { if node == nil { return nil } @@ -7247,10 +7247,10 @@ func rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7292,7 +7292,7 @@ func VisitRefOfParenSelect(in *ParenSelect, f Visit) error { } // rewriteRefOfParenSelect is part of the Rewrite implementation -func rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replacerFunc) error { if node == nil { return nil } @@ -7301,15 +7301,15 @@ func rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*ParenSelect).Select = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7351,7 +7351,7 @@ func VisitRefOfParenTableExpr(in *ParenTableExpr, f Visit) error { } // rewriteRefOfParenTableExpr is part of the Rewrite implementation -func rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -7360,15 +7360,15 @@ func rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer r parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7416,7 +7416,7 @@ func VisitRefOfPartitionDefinition(in *PartitionDefinition, f Visit) error { } // rewriteRefOfPartitionDefinition is part of the Rewrite implementation -func rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -7425,20 +7425,20 @@ func rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*PartitionDefinition).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Limit, func(newNode, parent SQLNode) { parent.(*PartitionDefinition).Limit = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7500,7 +7500,7 @@ func VisitRefOfPartitionSpec(in *PartitionSpec, f Visit) error { } // rewriteRefOfPartitionSpec is part of the Rewrite implementation -func rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer replacerFunc) error { if node == nil { return nil } @@ -7509,32 +7509,32 @@ func rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer rep parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { + if errF := a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { parent.(*PartitionSpec).Names = newNode.(Partitions) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLiteral(node, node.Number, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLiteral(node, node.Number, func(newNode, parent SQLNode) { parent.(*PartitionSpec).Number = newNode.(*Literal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { parent.(*PartitionSpec).TableName = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.Definitions { - if errF := rewriteRefOfPartitionDefinition(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfPartitionDefinition(node, el, func(newNode, parent SQLNode) { parent.(*PartitionSpec).Definitions[i] = newNode.(*PartitionDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7579,7 +7579,7 @@ func VisitPartitions(in Partitions, f Visit) error { } // rewritePartitions is part of the Rewrite implementation -func rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc) error { if node == nil { return nil } @@ -7588,17 +7588,17 @@ func rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteColIdent(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { parent.(Partitions)[i] = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7651,7 +7651,7 @@ func VisitRefOfRangeCond(in *RangeCond, f Visit) error { } // rewriteRefOfRangeCond is part of the Rewrite implementation -func rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFunc) error { if node == nil { return nil } @@ -7660,25 +7660,25 @@ func rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*RangeCond).Left = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { parent.(*RangeCond).From = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { parent.(*RangeCond).To = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7720,7 +7720,7 @@ func VisitRefOfRelease(in *Release, f Visit) error { } // rewriteRefOfRelease is part of the Rewrite implementation -func rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc) error { if node == nil { return nil } @@ -7729,15 +7729,15 @@ func rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*Release).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7776,7 +7776,7 @@ func VisitRefOfRenameIndex(in *RenameIndex, f Visit) error { } // rewriteRefOfRenameIndex is part of the Rewrite implementation -func rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replacerFunc) error { if node == nil { return nil } @@ -7785,10 +7785,10 @@ func rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7827,7 +7827,7 @@ func VisitRefOfRenameTable(in *RenameTable, f Visit) error { } // rewriteRefOfRenameTable is part of the Rewrite implementation -func rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replacerFunc) error { if node == nil { return nil } @@ -7836,10 +7836,10 @@ func rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7881,7 +7881,7 @@ func VisitRefOfRenameTableName(in *RenameTableName, f Visit) error { } // rewriteRefOfRenameTableName is part of the Rewrite implementation -func rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer replacerFunc) error { if node == nil { return nil } @@ -7890,15 +7890,15 @@ func rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*RenameTableName).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7936,7 +7936,7 @@ func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { } // rewriteRefOfRevertMigration is part of the Rewrite implementation -func rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer replacerFunc) error { if node == nil { return nil } @@ -7945,10 +7945,10 @@ func rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -7986,7 +7986,7 @@ func VisitRefOfRollback(in *Rollback, f Visit) error { } // rewriteRefOfRollback is part of the Rewrite implementation -func rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc) error { if node == nil { return nil } @@ -7995,10 +7995,10 @@ func rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8040,7 +8040,7 @@ func VisitRefOfSRollback(in *SRollback, f Visit) error { } // rewriteRefOfSRollback is part of the Rewrite implementation -func rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFunc) error { if node == nil { return nil } @@ -8049,15 +8049,15 @@ func rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*SRollback).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8099,7 +8099,7 @@ func VisitRefOfSavepoint(in *Savepoint, f Visit) error { } // rewriteRefOfSavepoint is part of the Rewrite implementation -func rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFunc) error { if node == nil { return nil } @@ -8108,15 +8108,15 @@ func rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*Savepoint).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8204,7 +8204,7 @@ func VisitRefOfSelect(in *Select, f Visit) error { } // rewriteRefOfSelect is part of the Rewrite implementation -func rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc) error { if node == nil { return nil } @@ -8213,55 +8213,55 @@ func rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Select).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSelectExprs(node, node.SelectExprs, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExprs(node, node.SelectExprs, func(newNode, parent SQLNode) { parent.(*Select).SelectExprs = newNode.(SelectExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableExprs(node, node.From, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExprs(node, node.From, func(newNode, parent SQLNode) { parent.(*Select).From = newNode.(TableExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { parent.(*Select).Where = newNode.(*Where) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteGroupBy(node, node.GroupBy, func(newNode, parent SQLNode) { + if errF := a.rewriteGroupBy(node, node.GroupBy, func(newNode, parent SQLNode) { parent.(*Select).GroupBy = newNode.(GroupBy) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfWhere(node, node.Having, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfWhere(node, node.Having, func(newNode, parent SQLNode) { parent.(*Select).Having = newNode.(*Where) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { parent.(*Select).OrderBy = newNode.(OrderBy) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Select).Limit = newNode.(*Limit) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfSelectInto(node, node.Into, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfSelectInto(node, node.Into, func(newNode, parent SQLNode) { parent.(*Select).Into = newNode.(*SelectInto) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8306,7 +8306,7 @@ func VisitSelectExprs(in SelectExprs, f Visit) error { } // rewriteSelectExprs is part of the Rewrite implementation -func rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc) error { if node == nil { return nil } @@ -8315,17 +8315,17 @@ func rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { parent.(SelectExprs)[i] = newNode.(SelectExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8369,7 +8369,7 @@ func VisitRefOfSelectInto(in *SelectInto, f Visit) error { } // rewriteRefOfSelectInto is part of the Rewrite implementation -func rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerFunc) error { if node == nil { return nil } @@ -8378,10 +8378,10 @@ func rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8428,7 +8428,7 @@ func VisitRefOfSet(in *Set, f Visit) error { } // rewriteRefOfSet is part of the Rewrite implementation -func rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc) error { if node == nil { return nil } @@ -8437,20 +8437,20 @@ func rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc, pre, post parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Set).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSetExprs(node, node.Exprs, func(newNode, parent SQLNode) { + if errF := a.rewriteSetExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*Set).Exprs = newNode.(SetExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8498,7 +8498,7 @@ func VisitRefOfSetExpr(in *SetExpr, f Visit) error { } // rewriteRefOfSetExpr is part of the Rewrite implementation -func rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -8507,20 +8507,20 @@ func rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*SetExpr).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*SetExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8565,7 +8565,7 @@ func VisitSetExprs(in SetExprs, f Visit) error { } // rewriteSetExprs is part of the Rewrite implementation -func rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc) error { if node == nil { return nil } @@ -8574,17 +8574,17 @@ func rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { parent.(SetExprs)[i] = newNode.(*SetExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8639,7 +8639,7 @@ func VisitRefOfSetTransaction(in *SetTransaction, f Visit) error { } // rewriteRefOfSetTransaction is part of the Rewrite implementation -func rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer replacerFunc) error { if node == nil { return nil } @@ -8648,27 +8648,27 @@ func rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer r parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { + if errF := a.rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { parent.(*SetTransaction).SQLNode = newNode.(SQLNode) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*SetTransaction).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.Characteristics { - if errF := rewriteCharacteristic(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteCharacteristic(node, el, func(newNode, parent SQLNode) { parent.(*SetTransaction).Characteristics[i] = newNode.(Characteristic) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8710,7 +8710,7 @@ func VisitRefOfShow(in *Show, f Visit) error { } // rewriteRefOfShow is part of the Rewrite implementation -func rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc) error { if node == nil { return nil } @@ -8719,15 +8719,15 @@ func rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { + if errF := a.rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { parent.(*Show).Internal = newNode.(ShowInternal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8777,7 +8777,7 @@ func VisitRefOfShowBasic(in *ShowBasic, f Visit) error { } // rewriteRefOfShowBasic is part of the Rewrite implementation -func rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFunc) error { if node == nil { return nil } @@ -8786,20 +8786,20 @@ func rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { parent.(*ShowBasic).Tbl = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfShowFilter(node, node.Filter, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfShowFilter(node, node.Filter, func(newNode, parent SQLNode) { parent.(*ShowBasic).Filter = newNode.(*ShowFilter) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8842,7 +8842,7 @@ func VisitRefOfShowCreate(in *ShowCreate, f Visit) error { } // rewriteRefOfShowCreate is part of the Rewrite implementation -func rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerFunc) error { if node == nil { return nil } @@ -8851,15 +8851,15 @@ func rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { parent.(*ShowCreate).Op = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8902,7 +8902,7 @@ func VisitRefOfShowFilter(in *ShowFilter, f Visit) error { } // rewriteRefOfShowFilter is part of the Rewrite implementation -func rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerFunc) error { if node == nil { return nil } @@ -8911,15 +8911,15 @@ func rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { parent.(*ShowFilter).Filter = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -8976,7 +8976,7 @@ func VisitRefOfShowLegacy(in *ShowLegacy, f Visit) error { } // rewriteRefOfShowLegacy is part of the Rewrite implementation -func rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerFunc) error { if node == nil { return nil } @@ -8985,25 +8985,25 @@ func rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { parent.(*ShowLegacy).OnTable = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*ShowLegacy).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.ShowCollationFilterOpt, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.ShowCollationFilterOpt, func(newNode, parent SQLNode) { parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9045,7 +9045,7 @@ func VisitRefOfStarExpr(in *StarExpr, f Visit) error { } // rewriteRefOfStarExpr is part of the Rewrite implementation -func rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -9054,15 +9054,15 @@ func rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { parent.(*StarExpr).TableName = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9114,7 +9114,7 @@ func VisitRefOfStream(in *Stream, f Visit) error { } // rewriteRefOfStream is part of the Rewrite implementation -func rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc) error { if node == nil { return nil } @@ -9123,25 +9123,25 @@ func rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Stream).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { parent.(*Stream).SelectExpr = newNode.(SelectExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*Stream).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9183,7 +9183,7 @@ func VisitRefOfSubquery(in *Subquery, f Visit) error { } // rewriteRefOfSubquery is part of the Rewrite implementation -func rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc) error { if node == nil { return nil } @@ -9192,15 +9192,15 @@ func rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*Subquery).Select = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9257,7 +9257,7 @@ func VisitRefOfSubstrExpr(in *SubstrExpr, f Visit) error { } // rewriteRefOfSubstrExpr is part of the Rewrite implementation -func rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -9266,30 +9266,30 @@ func rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*SubstrExpr).Name = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLiteral(node, node.StrVal, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLiteral(node, node.StrVal, func(newNode, parent SQLNode) { parent.(*SubstrExpr).StrVal = newNode.(*Literal) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { parent.(*SubstrExpr).From = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { parent.(*SubstrExpr).To = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9334,7 +9334,7 @@ func VisitTableExprs(in TableExprs, f Visit) error { } // rewriteTableExprs is part of the Rewrite implementation -func rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc) error { if node == nil { return nil } @@ -9343,17 +9343,17 @@ func rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteTableExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExpr(node, el, func(newNode, parent SQLNode) { parent.(TableExprs)[i] = newNode.(TableExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9378,20 +9378,20 @@ func VisitTableIdent(in TableIdent, f Visit) error { } // rewriteTableIdent is part of the Rewrite implementation -func rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9423,30 +9423,30 @@ func VisitTableName(in TableName, f Visit) error { } // rewriteTableName is part of the Rewrite implementation -func rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Name' on 'TableName'") - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Qualifier' on 'TableName'") - }, pre, post); errF != nil { + }); errF != nil { return errF } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9491,7 +9491,7 @@ func VisitTableNames(in TableNames, f Visit) error { } // rewriteTableNames is part of the Rewrite implementation -func rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc) error { if node == nil { return nil } @@ -9500,17 +9500,17 @@ func rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteTableName(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, el, func(newNode, parent SQLNode) { parent.(TableNames)[i] = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9545,7 +9545,7 @@ func VisitTableOptions(in TableOptions, f Visit) error { } // rewriteTableOptions is part of the Rewrite implementation -func rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFunc) error { if node == nil { return nil } @@ -9554,10 +9554,10 @@ func rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9620,7 +9620,7 @@ func VisitRefOfTableSpec(in *TableSpec, f Visit) error { } // rewriteRefOfTableSpec is part of the Rewrite implementation -func rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFunc) error { if node == nil { return nil } @@ -9629,36 +9629,36 @@ func rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node.Columns { - if errF := rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { parent.(*TableSpec).Columns[i] = newNode.(*ColumnDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } } for i, el := range node.Indexes { - if errF := rewriteRefOfIndexDefinition(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfIndexDefinition(node, el, func(newNode, parent SQLNode) { parent.(*TableSpec).Indexes[i] = newNode.(*IndexDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } } for i, el := range node.Constraints { - if errF := rewriteRefOfConstraintDefinition(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfConstraintDefinition(node, el, func(newNode, parent SQLNode) { parent.(*TableSpec).Constraints[i] = newNode.(*ConstraintDefinition) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if errF := rewriteTableOptions(node, node.Options, func(newNode, parent SQLNode) { + if errF := a.rewriteTableOptions(node, node.Options, func(newNode, parent SQLNode) { parent.(*TableSpec).Options = newNode.(TableOptions) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9696,7 +9696,7 @@ func VisitRefOfTablespaceOperation(in *TablespaceOperation, f Visit) error { } // rewriteRefOfTablespaceOperation is part of the Rewrite implementation -func rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, replacer replacerFunc) error { if node == nil { return nil } @@ -9705,10 +9705,10 @@ func rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9757,7 +9757,7 @@ func VisitRefOfTimestampFuncExpr(in *TimestampFuncExpr, f Visit) error { } // rewriteRefOfTimestampFuncExpr is part of the Rewrite implementation -func rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -9766,20 +9766,20 @@ func rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, repl parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Expr2, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr2, func(newNode, parent SQLNode) { parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9821,7 +9821,7 @@ func VisitRefOfTruncateTable(in *TruncateTable, f Visit) error { } // rewriteRefOfTruncateTable is part of the Rewrite implementation -func rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer replacerFunc) error { if node == nil { return nil } @@ -9830,15 +9830,15 @@ func rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer rep parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*TruncateTable).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9881,7 +9881,7 @@ func VisitRefOfUnaryExpr(in *UnaryExpr, f Visit) error { } // rewriteRefOfUnaryExpr is part of the Rewrite implementation -func rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -9890,15 +9890,15 @@ func rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*UnaryExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -9958,7 +9958,7 @@ func VisitRefOfUnion(in *Union, f Visit) error { } // rewriteRefOfUnion is part of the Rewrite implementation -func rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc) error { if node == nil { return nil } @@ -9967,32 +9967,32 @@ func rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { parent.(*Union).FirstStatement = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.UnionSelects { - if errF := rewriteRefOfUnionSelect(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfUnionSelect(node, el, func(newNode, parent SQLNode) { parent.(*Union).UnionSelects[i] = newNode.(*UnionSelect) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { parent.(*Union).OrderBy = newNode.(OrderBy) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Union).Limit = newNode.(*Limit) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10035,7 +10035,7 @@ func VisitRefOfUnionSelect(in *UnionSelect, f Visit) error { } // rewriteRefOfUnionSelect is part of the Rewrite implementation -func rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replacerFunc) error { if node == nil { return nil } @@ -10044,15 +10044,15 @@ func rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { parent.(*UnionSelect).Statement = newNode.(SelectStatement) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10090,7 +10090,7 @@ func VisitRefOfUnlockTables(in *UnlockTables, f Visit) error { } // rewriteRefOfUnlockTables is part of the Rewrite implementation -func rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer replacerFunc) error { if node == nil { return nil } @@ -10099,10 +10099,10 @@ func rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer repla parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10170,7 +10170,7 @@ func VisitRefOfUpdate(in *Update, f Visit) error { } // rewriteRefOfUpdate is part of the Rewrite implementation -func rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc) error { if node == nil { return nil } @@ -10179,40 +10179,40 @@ func rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc, pre parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Update).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { parent.(*Update).TableExprs = newNode.(TableExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteUpdateExprs(node, node.Exprs, func(newNode, parent SQLNode) { + if errF := a.rewriteUpdateExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*Update).Exprs = newNode.(UpdateExprs) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { parent.(*Update).Where = newNode.(*Where) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { parent.(*Update).OrderBy = newNode.(OrderBy) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Update).Limit = newNode.(*Limit) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10259,7 +10259,7 @@ func VisitRefOfUpdateExpr(in *UpdateExpr, f Visit) error { } // rewriteRefOfUpdateExpr is part of the Rewrite implementation -func rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10268,20 +10268,20 @@ func rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*UpdateExpr).Name = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*UpdateExpr).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10326,7 +10326,7 @@ func VisitUpdateExprs(in UpdateExprs, f Visit) error { } // rewriteUpdateExprs is part of the Rewrite implementation -func rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc) error { if node == nil { return nil } @@ -10335,17 +10335,17 @@ func rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { parent.(UpdateExprs)[i] = newNode.(*UpdateExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10387,7 +10387,7 @@ func VisitRefOfUse(in *Use, f Visit) error { } // rewriteRefOfUse is part of the Rewrite implementation -func rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc) error { if node == nil { return nil } @@ -10396,15 +10396,15 @@ func rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc, pre, post parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { parent.(*Use).DBName = newNode.(TableIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10466,7 +10466,7 @@ func VisitRefOfVStream(in *VStream, f Visit) error { } // rewriteRefOfVStream is part of the Rewrite implementation -func rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc) error { if node == nil { return nil } @@ -10475,35 +10475,35 @@ func rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*VStream).Comments = newNode.(Comments) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { parent.(*VStream).SelectExpr = newNode.(SelectExpr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*VStream).Table = newNode.(TableName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { parent.(*VStream).Where = newNode.(*Where) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*VStream).Limit = newNode.(*Limit) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10548,7 +10548,7 @@ func VisitValTuple(in ValTuple, f Visit) error { } // rewriteValTuple is part of the Rewrite implementation -func rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc) error { if node == nil { return nil } @@ -10557,17 +10557,17 @@ func rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteExpr(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { parent.(ValTuple)[i] = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10605,7 +10605,7 @@ func VisitRefOfValidation(in *Validation, f Visit) error { } // rewriteRefOfValidation is part of the Rewrite implementation -func rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerFunc) error { if node == nil { return nil } @@ -10614,10 +10614,10 @@ func rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10662,7 +10662,7 @@ func VisitValues(in Values, f Visit) error { } // rewriteValues is part of the Rewrite implementation -func rewriteValues(parent SQLNode, node Values, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteValues(parent SQLNode, node Values, replacer replacerFunc) error { if node == nil { return nil } @@ -10671,17 +10671,17 @@ func rewriteValues(parent SQLNode, node Values, replacer replacerFunc, pre, post parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } for i, el := range node { - if errF := rewriteValTuple(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteValTuple(node, el, func(newNode, parent SQLNode) { parent.(Values)[i] = newNode.(ValTuple) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10723,7 +10723,7 @@ func VisitRefOfValuesFuncExpr(in *ValuesFuncExpr, f Visit) error { } // rewriteRefOfValuesFuncExpr is part of the Rewrite implementation -func rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10732,15 +10732,15 @@ func rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer r parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*ValuesFuncExpr).Name = newNode.(*ColName) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10769,25 +10769,25 @@ func VisitVindexParam(in VindexParam, f Visit) error { } // rewriteVindexParam is part of the Rewrite implementation -func rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc) error { var err error cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Key' on 'VindexParam'") - }, pre, post); errF != nil { + }); errF != nil { return errF } if err != nil { return err } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10841,7 +10841,7 @@ func VisitRefOfVindexSpec(in *VindexSpec, f Visit) error { } // rewriteRefOfVindexSpec is part of the Rewrite implementation -func rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerFunc) error { if node == nil { return nil } @@ -10850,27 +10850,27 @@ func rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*VindexSpec).Name = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColIdent(node, node.Type, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Type, func(newNode, parent SQLNode) { parent.(*VindexSpec).Type = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } for i, el := range node.Params { - if errF := rewriteVindexParam(node, el, func(newNode, parent SQLNode) { + if errF := a.rewriteVindexParam(node, el, func(newNode, parent SQLNode) { parent.(*VindexSpec).Params[i] = newNode.(VindexParam) - }, pre, post); errF != nil { + }); errF != nil { return errF } } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10917,7 +10917,7 @@ func VisitRefOfWhen(in *When, f Visit) error { } // rewriteRefOfWhen is part of the Rewrite implementation -func rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc) error { if node == nil { return nil } @@ -10926,20 +10926,20 @@ func rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc, pre, po parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { parent.(*When).Cond = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Val, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Val, func(newNode, parent SQLNode) { parent.(*When).Val = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -10982,7 +10982,7 @@ func VisitRefOfWhere(in *Where, f Visit) error { } // rewriteRefOfWhere is part of the Rewrite implementation -func rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc) error { if node == nil { return nil } @@ -10991,15 +10991,15 @@ func rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc, pre, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*Where).Expr = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -11046,7 +11046,7 @@ func VisitRefOfXorExpr(in *XorExpr, f Visit) error { } // rewriteRefOfXorExpr is part of the Rewrite implementation -func rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -11055,20 +11055,20 @@ func rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc, p parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*XorExpr).Left = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { parent.(*XorExpr).Right = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -11304,49 +11304,49 @@ func VisitAlterOption(in AlterOption, f Visit) error { } // rewriteAlterOption is part of the Rewrite implementation -func rewriteAlterOption(parent SQLNode, node AlterOption, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteAlterOption(parent SQLNode, node AlterOption, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AddColumns: - return rewriteRefOfAddColumns(parent, node, replacer, pre, post) + return a.rewriteRefOfAddColumns(parent, node, replacer) case *AddConstraintDefinition: - return rewriteRefOfAddConstraintDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) case *AddIndexDefinition: - return rewriteRefOfAddIndexDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) case AlgorithmValue: - return rewriteAlgorithmValue(parent, node, replacer, pre, post) + return a.rewriteAlgorithmValue(parent, node, replacer) case *AlterCharset: - return rewriteRefOfAlterCharset(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterCharset(parent, node, replacer) case *AlterColumn: - return rewriteRefOfAlterColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterColumn(parent, node, replacer) case *ChangeColumn: - return rewriteRefOfChangeColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfChangeColumn(parent, node, replacer) case *DropColumn: - return rewriteRefOfDropColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfDropColumn(parent, node, replacer) case *DropKey: - return rewriteRefOfDropKey(parent, node, replacer, pre, post) + return a.rewriteRefOfDropKey(parent, node, replacer) case *Force: - return rewriteRefOfForce(parent, node, replacer, pre, post) + return a.rewriteRefOfForce(parent, node, replacer) case *KeyState: - return rewriteRefOfKeyState(parent, node, replacer, pre, post) + return a.rewriteRefOfKeyState(parent, node, replacer) case *LockOption: - return rewriteRefOfLockOption(parent, node, replacer, pre, post) + return a.rewriteRefOfLockOption(parent, node, replacer) case *ModifyColumn: - return rewriteRefOfModifyColumn(parent, node, replacer, pre, post) + return a.rewriteRefOfModifyColumn(parent, node, replacer) case *OrderByOption: - return rewriteRefOfOrderByOption(parent, node, replacer, pre, post) + return a.rewriteRefOfOrderByOption(parent, node, replacer) case *RenameIndex: - return rewriteRefOfRenameIndex(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameIndex(parent, node, replacer) case *RenameTableName: - return rewriteRefOfRenameTableName(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameTableName(parent, node, replacer) case TableOptions: - return rewriteTableOptions(parent, node, replacer, pre, post) + return a.rewriteTableOptions(parent, node, replacer) case *TablespaceOperation: - return rewriteRefOfTablespaceOperation(parent, node, replacer, pre, post) + return a.rewriteRefOfTablespaceOperation(parent, node, replacer) case *Validation: - return rewriteRefOfValidation(parent, node, replacer, pre, post) + return a.rewriteRefOfValidation(parent, node, replacer) default: // this should never happen return nil @@ -11413,15 +11413,15 @@ func VisitCharacteristic(in Characteristic, f Visit) error { } // rewriteCharacteristic is part of the Rewrite implementation -func rewriteCharacteristic(parent SQLNode, node Characteristic, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteCharacteristic(parent SQLNode, node Characteristic, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case AccessMode: - return rewriteAccessMode(parent, node, replacer, pre, post) + return a.rewriteAccessMode(parent, node, replacer) case IsolationLevel: - return rewriteIsolationLevel(parent, node, replacer, pre, post) + return a.rewriteIsolationLevel(parent, node, replacer) default: // this should never happen return nil @@ -11498,17 +11498,17 @@ func VisitColTuple(in ColTuple, f Visit) error { } // rewriteColTuple is part of the Rewrite implementation -func rewriteColTuple(parent SQLNode, node ColTuple, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteColTuple(parent SQLNode, node ColTuple, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case ListArg: - return rewriteListArg(parent, node, replacer, pre, post) + return a.rewriteListArg(parent, node, replacer) case *Subquery: - return rewriteRefOfSubquery(parent, node, replacer, pre, post) + return a.rewriteRefOfSubquery(parent, node, replacer) case ValTuple: - return rewriteValTuple(parent, node, replacer, pre, post) + return a.rewriteValTuple(parent, node, replacer) default: // this should never happen return nil @@ -11575,15 +11575,15 @@ func VisitConstraintInfo(in ConstraintInfo, f Visit) error { } // rewriteConstraintInfo is part of the Rewrite implementation -func rewriteConstraintInfo(parent SQLNode, node ConstraintInfo, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteConstraintInfo(parent SQLNode, node ConstraintInfo, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *CheckConstraintDefinition: - return rewriteRefOfCheckConstraintDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) case *ForeignKeyDefinition: - return rewriteRefOfForeignKeyDefinition(parent, node, replacer, pre, post) + return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) default: // this should never happen return nil @@ -11660,17 +11660,17 @@ func VisitDBDDLStatement(in DBDDLStatement, f Visit) error { } // rewriteDBDDLStatement is part of the Rewrite implementation -func rewriteDBDDLStatement(parent SQLNode, node DBDDLStatement, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteDBDDLStatement(parent SQLNode, node DBDDLStatement, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AlterDatabase: - return rewriteRefOfAlterDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterDatabase(parent, node, replacer) case *CreateDatabase: - return rewriteRefOfCreateDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateDatabase(parent, node, replacer) case *DropDatabase: - return rewriteRefOfDropDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfDropDatabase(parent, node, replacer) default: // this should never happen return nil @@ -11797,27 +11797,27 @@ func VisitDDLStatement(in DDLStatement, f Visit) error { } // rewriteDDLStatement is part of the Rewrite implementation -func rewriteDDLStatement(parent SQLNode, node DDLStatement, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteDDLStatement(parent SQLNode, node DDLStatement, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AlterTable: - return rewriteRefOfAlterTable(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterTable(parent, node, replacer) case *AlterView: - return rewriteRefOfAlterView(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterView(parent, node, replacer) case *CreateTable: - return rewriteRefOfCreateTable(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateTable(parent, node, replacer) case *CreateView: - return rewriteRefOfCreateView(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateView(parent, node, replacer) case *DropTable: - return rewriteRefOfDropTable(parent, node, replacer, pre, post) + return a.rewriteRefOfDropTable(parent, node, replacer) case *DropView: - return rewriteRefOfDropView(parent, node, replacer, pre, post) + return a.rewriteRefOfDropView(parent, node, replacer) case *RenameTable: - return rewriteRefOfRenameTable(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameTable(parent, node, replacer) case *TruncateTable: - return rewriteRefOfTruncateTable(parent, node, replacer, pre, post) + return a.rewriteRefOfTruncateTable(parent, node, replacer) default: // this should never happen return nil @@ -11884,15 +11884,15 @@ func VisitExplain(in Explain, f Visit) error { } // rewriteExplain is part of the Rewrite implementation -func rewriteExplain(parent SQLNode, node Explain, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteExplain(parent SQLNode, node Explain, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *ExplainStmt: - return rewriteRefOfExplainStmt(parent, node, replacer, pre, post) + return a.rewriteRefOfExplainStmt(parent, node, replacer) case *ExplainTab: - return rewriteRefOfExplainTab(parent, node, replacer, pre, post) + return a.rewriteRefOfExplainTab(parent, node, replacer) default: // this should never happen return nil @@ -12249,73 +12249,73 @@ func VisitExpr(in Expr, f Visit) error { } // rewriteExpr is part of the Rewrite implementation -func rewriteExpr(parent SQLNode, node Expr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteExpr(parent SQLNode, node Expr, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AndExpr: - return rewriteRefOfAndExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfAndExpr(parent, node, replacer) case Argument: - return rewriteArgument(parent, node, replacer, pre, post) + return a.rewriteArgument(parent, node, replacer) case *BinaryExpr: - return rewriteRefOfBinaryExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfBinaryExpr(parent, node, replacer) case BoolVal: - return rewriteBoolVal(parent, node, replacer, pre, post) + return a.rewriteBoolVal(parent, node, replacer) case *CaseExpr: - return rewriteRefOfCaseExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfCaseExpr(parent, node, replacer) case *ColName: - return rewriteRefOfColName(parent, node, replacer, pre, post) + return a.rewriteRefOfColName(parent, node, replacer) case *CollateExpr: - return rewriteRefOfCollateExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfCollateExpr(parent, node, replacer) case *ComparisonExpr: - return rewriteRefOfComparisonExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfComparisonExpr(parent, node, replacer) case *ConvertExpr: - return rewriteRefOfConvertExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfConvertExpr(parent, node, replacer) case *ConvertUsingExpr: - return rewriteRefOfConvertUsingExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) case *CurTimeFuncExpr: - return rewriteRefOfCurTimeFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) case *Default: - return rewriteRefOfDefault(parent, node, replacer, pre, post) + return a.rewriteRefOfDefault(parent, node, replacer) case *ExistsExpr: - return rewriteRefOfExistsExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfExistsExpr(parent, node, replacer) case *FuncExpr: - return rewriteRefOfFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfFuncExpr(parent, node, replacer) case *GroupConcatExpr: - return rewriteRefOfGroupConcatExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) case *IntervalExpr: - return rewriteRefOfIntervalExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfIntervalExpr(parent, node, replacer) case *IsExpr: - return rewriteRefOfIsExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfIsExpr(parent, node, replacer) case ListArg: - return rewriteListArg(parent, node, replacer, pre, post) + return a.rewriteListArg(parent, node, replacer) case *Literal: - return rewriteRefOfLiteral(parent, node, replacer, pre, post) + return a.rewriteRefOfLiteral(parent, node, replacer) case *MatchExpr: - return rewriteRefOfMatchExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfMatchExpr(parent, node, replacer) case *NotExpr: - return rewriteRefOfNotExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfNotExpr(parent, node, replacer) case *NullVal: - return rewriteRefOfNullVal(parent, node, replacer, pre, post) + return a.rewriteRefOfNullVal(parent, node, replacer) case *OrExpr: - return rewriteRefOfOrExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfOrExpr(parent, node, replacer) case *RangeCond: - return rewriteRefOfRangeCond(parent, node, replacer, pre, post) + return a.rewriteRefOfRangeCond(parent, node, replacer) case *Subquery: - return rewriteRefOfSubquery(parent, node, replacer, pre, post) + return a.rewriteRefOfSubquery(parent, node, replacer) case *SubstrExpr: - return rewriteRefOfSubstrExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfSubstrExpr(parent, node, replacer) case *TimestampFuncExpr: - return rewriteRefOfTimestampFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) case *UnaryExpr: - return rewriteRefOfUnaryExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfUnaryExpr(parent, node, replacer) case ValTuple: - return rewriteValTuple(parent, node, replacer, pre, post) + return a.rewriteValTuple(parent, node, replacer) case *ValuesFuncExpr: - return rewriteRefOfValuesFuncExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) case *XorExpr: - return rewriteRefOfXorExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfXorExpr(parent, node, replacer) default: // this should never happen return nil @@ -12402,19 +12402,19 @@ func VisitInsertRows(in InsertRows, f Visit) error { } // rewriteInsertRows is part of the Rewrite implementation -func rewriteInsertRows(parent SQLNode, node InsertRows, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteInsertRows(parent SQLNode, node InsertRows, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *ParenSelect: - return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfParenSelect(parent, node, replacer) case *Select: - return rewriteRefOfSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfSelect(parent, node, replacer) case *Union: - return rewriteRefOfUnion(parent, node, replacer, pre, post) + return a.rewriteRefOfUnion(parent, node, replacer) case Values: - return rewriteValues(parent, node, replacer, pre, post) + return a.rewriteValues(parent, node, replacer) default: // this should never happen return nil @@ -12491,17 +12491,17 @@ func VisitSelectExpr(in SelectExpr, f Visit) error { } // rewriteSelectExpr is part of the Rewrite implementation -func rewriteSelectExpr(parent SQLNode, node SelectExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSelectExpr(parent SQLNode, node SelectExpr, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AliasedExpr: - return rewriteRefOfAliasedExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfAliasedExpr(parent, node, replacer) case *Nextval: - return rewriteRefOfNextval(parent, node, replacer, pre, post) + return a.rewriteRefOfNextval(parent, node, replacer) case *StarExpr: - return rewriteRefOfStarExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfStarExpr(parent, node, replacer) default: // this should never happen return nil @@ -12578,17 +12578,17 @@ func VisitSelectStatement(in SelectStatement, f Visit) error { } // rewriteSelectStatement is part of the Rewrite implementation -func rewriteSelectStatement(parent SQLNode, node SelectStatement, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSelectStatement(parent SQLNode, node SelectStatement, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *ParenSelect: - return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfParenSelect(parent, node, replacer) case *Select: - return rewriteRefOfSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfSelect(parent, node, replacer) case *Union: - return rewriteRefOfUnion(parent, node, replacer, pre, post) + return a.rewriteRefOfUnion(parent, node, replacer) default: // this should never happen return nil @@ -12665,17 +12665,17 @@ func VisitShowInternal(in ShowInternal, f Visit) error { } // rewriteShowInternal is part of the Rewrite implementation -func rewriteShowInternal(parent SQLNode, node ShowInternal, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteShowInternal(parent SQLNode, node ShowInternal, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *ShowBasic: - return rewriteRefOfShowBasic(parent, node, replacer, pre, post) + return a.rewriteRefOfShowBasic(parent, node, replacer) case *ShowCreate: - return rewriteRefOfShowCreate(parent, node, replacer, pre, post) + return a.rewriteRefOfShowCreate(parent, node, replacer) case *ShowLegacy: - return rewriteRefOfShowLegacy(parent, node, replacer, pre, post) + return a.rewriteRefOfShowLegacy(parent, node, replacer) default: // this should never happen return nil @@ -12742,15 +12742,15 @@ func VisitSimpleTableExpr(in SimpleTableExpr, f Visit) error { } // rewriteSimpleTableExpr is part of the Rewrite implementation -func rewriteSimpleTableExpr(parent SQLNode, node SimpleTableExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteSimpleTableExpr(parent SQLNode, node SimpleTableExpr, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *DerivedTable: - return rewriteRefOfDerivedTable(parent, node, replacer, pre, post) + return a.rewriteRefOfDerivedTable(parent, node, replacer) case TableName: - return rewriteTableName(parent, node, replacer, pre, post) + return a.rewriteTableName(parent, node, replacer) default: // this should never happen return nil @@ -13207,93 +13207,93 @@ func VisitStatement(in Statement, f Visit) error { } // rewriteStatement is part of the Rewrite implementation -func rewriteStatement(parent SQLNode, node Statement, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteStatement(parent SQLNode, node Statement, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AlterDatabase: - return rewriteRefOfAlterDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterDatabase(parent, node, replacer) case *AlterMigration: - return rewriteRefOfAlterMigration(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterMigration(parent, node, replacer) case *AlterTable: - return rewriteRefOfAlterTable(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterTable(parent, node, replacer) case *AlterView: - return rewriteRefOfAlterView(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterView(parent, node, replacer) case *AlterVschema: - return rewriteRefOfAlterVschema(parent, node, replacer, pre, post) + return a.rewriteRefOfAlterVschema(parent, node, replacer) case *Begin: - return rewriteRefOfBegin(parent, node, replacer, pre, post) + return a.rewriteRefOfBegin(parent, node, replacer) case *CallProc: - return rewriteRefOfCallProc(parent, node, replacer, pre, post) + return a.rewriteRefOfCallProc(parent, node, replacer) case *Commit: - return rewriteRefOfCommit(parent, node, replacer, pre, post) + return a.rewriteRefOfCommit(parent, node, replacer) case *CreateDatabase: - return rewriteRefOfCreateDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateDatabase(parent, node, replacer) case *CreateTable: - return rewriteRefOfCreateTable(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateTable(parent, node, replacer) case *CreateView: - return rewriteRefOfCreateView(parent, node, replacer, pre, post) + return a.rewriteRefOfCreateView(parent, node, replacer) case *Delete: - return rewriteRefOfDelete(parent, node, replacer, pre, post) + return a.rewriteRefOfDelete(parent, node, replacer) case *DropDatabase: - return rewriteRefOfDropDatabase(parent, node, replacer, pre, post) + return a.rewriteRefOfDropDatabase(parent, node, replacer) case *DropTable: - return rewriteRefOfDropTable(parent, node, replacer, pre, post) + return a.rewriteRefOfDropTable(parent, node, replacer) case *DropView: - return rewriteRefOfDropView(parent, node, replacer, pre, post) + return a.rewriteRefOfDropView(parent, node, replacer) case *ExplainStmt: - return rewriteRefOfExplainStmt(parent, node, replacer, pre, post) + return a.rewriteRefOfExplainStmt(parent, node, replacer) case *ExplainTab: - return rewriteRefOfExplainTab(parent, node, replacer, pre, post) + return a.rewriteRefOfExplainTab(parent, node, replacer) case *Flush: - return rewriteRefOfFlush(parent, node, replacer, pre, post) + return a.rewriteRefOfFlush(parent, node, replacer) case *Insert: - return rewriteRefOfInsert(parent, node, replacer, pre, post) + return a.rewriteRefOfInsert(parent, node, replacer) case *Load: - return rewriteRefOfLoad(parent, node, replacer, pre, post) + return a.rewriteRefOfLoad(parent, node, replacer) case *LockTables: - return rewriteRefOfLockTables(parent, node, replacer, pre, post) + return a.rewriteRefOfLockTables(parent, node, replacer) case *OtherAdmin: - return rewriteRefOfOtherAdmin(parent, node, replacer, pre, post) + return a.rewriteRefOfOtherAdmin(parent, node, replacer) case *OtherRead: - return rewriteRefOfOtherRead(parent, node, replacer, pre, post) + return a.rewriteRefOfOtherRead(parent, node, replacer) case *ParenSelect: - return rewriteRefOfParenSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfParenSelect(parent, node, replacer) case *Release: - return rewriteRefOfRelease(parent, node, replacer, pre, post) + return a.rewriteRefOfRelease(parent, node, replacer) case *RenameTable: - return rewriteRefOfRenameTable(parent, node, replacer, pre, post) + return a.rewriteRefOfRenameTable(parent, node, replacer) case *RevertMigration: - return rewriteRefOfRevertMigration(parent, node, replacer, pre, post) + return a.rewriteRefOfRevertMigration(parent, node, replacer) case *Rollback: - return rewriteRefOfRollback(parent, node, replacer, pre, post) + return a.rewriteRefOfRollback(parent, node, replacer) case *SRollback: - return rewriteRefOfSRollback(parent, node, replacer, pre, post) + return a.rewriteRefOfSRollback(parent, node, replacer) case *Savepoint: - return rewriteRefOfSavepoint(parent, node, replacer, pre, post) + return a.rewriteRefOfSavepoint(parent, node, replacer) case *Select: - return rewriteRefOfSelect(parent, node, replacer, pre, post) + return a.rewriteRefOfSelect(parent, node, replacer) case *Set: - return rewriteRefOfSet(parent, node, replacer, pre, post) + return a.rewriteRefOfSet(parent, node, replacer) case *SetTransaction: - return rewriteRefOfSetTransaction(parent, node, replacer, pre, post) + return a.rewriteRefOfSetTransaction(parent, node, replacer) case *Show: - return rewriteRefOfShow(parent, node, replacer, pre, post) + return a.rewriteRefOfShow(parent, node, replacer) case *Stream: - return rewriteRefOfStream(parent, node, replacer, pre, post) + return a.rewriteRefOfStream(parent, node, replacer) case *TruncateTable: - return rewriteRefOfTruncateTable(parent, node, replacer, pre, post) + return a.rewriteRefOfTruncateTable(parent, node, replacer) case *Union: - return rewriteRefOfUnion(parent, node, replacer, pre, post) + return a.rewriteRefOfUnion(parent, node, replacer) case *UnlockTables: - return rewriteRefOfUnlockTables(parent, node, replacer, pre, post) + return a.rewriteRefOfUnlockTables(parent, node, replacer) case *Update: - return rewriteRefOfUpdate(parent, node, replacer, pre, post) + return a.rewriteRefOfUpdate(parent, node, replacer) case *Use: - return rewriteRefOfUse(parent, node, replacer, pre, post) + return a.rewriteRefOfUse(parent, node, replacer) case *VStream: - return rewriteRefOfVStream(parent, node, replacer, pre, post) + return a.rewriteRefOfVStream(parent, node, replacer) default: // this should never happen return nil @@ -13370,17 +13370,17 @@ func VisitTableExpr(in TableExpr, f Visit) error { } // rewriteTableExpr is part of the Rewrite implementation -func rewriteTableExpr(parent SQLNode, node TableExpr, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteTableExpr(parent SQLNode, node TableExpr, replacer replacerFunc) error { if node == nil { return nil } switch node := node.(type) { case *AliasedTableExpr: - return rewriteRefOfAliasedTableExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) case *JoinTableExpr: - return rewriteRefOfJoinTableExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfJoinTableExpr(parent, node, replacer) case *ParenTableExpr: - return rewriteRefOfParenTableExpr(parent, node, replacer, pre, post) + return a.rewriteRefOfParenTableExpr(parent, node, replacer) default: // this should never happen return nil @@ -13394,16 +13394,16 @@ func VisitAccessMode(in AccessMode, f Visit) error { } // rewriteAccessMode is part of the Rewrite implementation -func rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13416,16 +13416,16 @@ func VisitAlgorithmValue(in AlgorithmValue, f Visit) error { } // rewriteAlgorithmValue is part of the Rewrite implementation -func rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13438,16 +13438,16 @@ func VisitArgument(in Argument, f Visit) error { } // rewriteArgument is part of the Rewrite implementation -func rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13460,16 +13460,16 @@ func VisitBoolVal(in BoolVal, f Visit) error { } // rewriteBoolVal is part of the Rewrite implementation -func rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13482,16 +13482,16 @@ func VisitIsolationLevel(in IsolationLevel, f Visit) error { } // rewriteIsolationLevel is part of the Rewrite implementation -func rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13504,16 +13504,16 @@ func VisitReferenceAction(in ReferenceAction, f Visit) error { } // rewriteReferenceAction is part of the Rewrite implementation -func rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13663,7 +13663,7 @@ func VisitRefOfColIdent(in *ColIdent, f Visit) error { } // rewriteRefOfColIdent is part of the Rewrite implementation -func rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc) error { if node == nil { return nil } @@ -13672,10 +13672,10 @@ func rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc, parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -13832,7 +13832,7 @@ func VisitRefOfJoinCondition(in *JoinCondition, f Visit) error { } // rewriteRefOfJoinCondition is part of the Rewrite implementation -func rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer replacerFunc) error { if node == nil { return nil } @@ -13841,20 +13841,20 @@ func rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer rep parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { parent.(*JoinCondition).On = newNode.(Expr) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { parent.(*JoinCondition).Using = newNode.(Columns) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -14023,7 +14023,7 @@ func VisitRefOfTableIdent(in *TableIdent, f Visit) error { } // rewriteRefOfTableIdent is part of the Rewrite implementation -func rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerFunc) error { if node == nil { return nil } @@ -14032,10 +14032,10 @@ func rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerF parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -14082,7 +14082,7 @@ func VisitRefOfTableName(in *TableName, f Visit) error { } // rewriteRefOfTableName is part of the Rewrite implementation -func rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFunc) error { if node == nil { return nil } @@ -14091,20 +14091,20 @@ func rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFun parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*TableName).Name = newNode.(TableIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if errF := rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { parent.(*TableName).Qualifier = newNode.(TableIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil @@ -14238,7 +14238,7 @@ func VisitRefOfVindexParam(in *VindexParam, f Visit) error { } // rewriteRefOfVindexParam is part of the Rewrite implementation -func rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replacerFunc, pre, post ApplyFunc) error { +func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replacerFunc) error { if node == nil { return nil } @@ -14247,15 +14247,15 @@ func rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replace parent: parent, replacer: replacer, } - if pre != nil && !pre(&cur) { + if a.pre != nil && !a.pre(&cur) { return nil } - if errF := rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { parent.(*VindexParam).Key = newNode.(ColIdent) - }, pre, post); errF != nil { + }); errF != nil { return errF } - if post != nil && !post(&cur) { + if a.post != nil && !a.post(&cur) { return errAbort } return nil diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 53371224a96..bd8dd1efbff 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -46,7 +46,12 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { parent.SQLNode = newNode } - err = rewriteSQLNode(parent, node, replacer, pre, post) + a := &application{ + pre: pre, + post: post, + } + + err = a.rewriteSQLNode(parent, node, replacer) if err != nil && err != errAbort { return nil, err } @@ -87,3 +92,9 @@ func (c *Cursor) Replace(newNode SQLNode) { } type replacerFunc func(newNode, parent SQLNode) + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor +} From b65c0b4ca1b251b31f4bd1afab21cee421d478b7 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 20 Mar 2021 08:53:08 +0100 Subject: [PATCH 10/15] sort asthelper methods Signed-off-by: Andres Taylor --- go/tools/asthelpergen/asthelpergen.go | 100 +- .../asthelpergen/integration/ast_helper.go | 1432 +- go/vt/sqlparser/ast_helper.go | 21098 ++++++++-------- 3 files changed, 10957 insertions(+), 11673 deletions(-) diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index 0fbcd19cb8d..e22c48936e1 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -23,6 +23,7 @@ import ( "io/ioutil" "log" "path" + "sort" "strings" "github.com/dave/jennifer/jen" @@ -43,38 +44,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.` -type generatorSPI interface { - addType(t types.Type) - addFunc(name string, t methodType, code jen.Code) - scope() *types.Scope - findImplementations(iff *types.Interface, impl func(types.Type) error) error - iface() *types.Interface +type ( + generatorSPI interface { + addType(t types.Type) + addFunc(name string, t methodType, code jen.Code) + scope() *types.Scope + findImplementations(iff *types.Interface, impl func(types.Type) error) error + iface() *types.Interface + } + generator2 interface { + interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error + structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error + ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error + ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error + sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error + basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error + } + // astHelperGen finds implementations of the given interface, + // and uses the supplied `generator`s to produce the output code + astHelperGen struct { + DebugTypes bool + mod *packages.Module + sizes types.Sizes + namedIface *types.Named + _iface *types.Interface + gens []generator2 + + functions methods + _scope *types.Scope + todo []types.Type + } + + method struct { + name string + code jen.Code + typ methodType + } + + methods []method +) + +func (m methods) Len() int { + return len(m) } -type generator2 interface { - interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error - structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error - ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error - ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error - sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error - basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error +func (m methods) Less(i, j int) bool { + return m[i].name < m[j].name } -// astHelperGen finds implementations of the given interface, -// and uses the supplied `generator`s to produce the output code -type astHelperGen struct { - DebugTypes bool - mod *packages.Module - sizes types.Sizes - namedIface *types.Named - _iface *types.Interface - gens []generator2 - - methods []jen.Code - _scope *types.Scope - todo []types.Type +func (m methods) Swap(i, j int) { + m[i], m[j] = m[j], m[i] } +var _ sort.Interface = (methods)(nil) + func (gen *astHelperGen) iface() *types.Interface { return gen._iface } @@ -264,18 +287,7 @@ const ( ) func (gen *astHelperGen) addFunc(name string, typ methodType, code jen.Code) { - var comment string - switch typ { - case clone: - comment = " creates a deep clone of the input." - case equals: - comment = " does deep equals between the two objects." - case visit: - comment = " will visit all parts of the AST" - case rewrite: - comment = " is part of the Rewrite implementation" - } - gen.methods = append(gen.methods, jen.Comment(name+comment), code) + gen.functions = append(gen.functions, method{name: name, code: code, typ: typ}) } func (gen *astHelperGen) createFile(pkgName string) (string, *jen.File) { @@ -334,8 +346,16 @@ func (gen *astHelperGen) createFile(pkgName string) (string, *jen.File) { alreadyDone[typeName] = true } - for _, method := range gen.methods { - out.Add(method) + sort.Sort(gen.functions) + + for _, m := range gen.functions { + switch m.typ { + case clone: + out.Add(jen.Comment(fmt.Sprintf("%s creates a deep clone of the input.", m.name))) + case equals: + out.Add(jen.Comment(fmt.Sprintf("%s does deep equals between the two objects.", m.name))) + } + out.Add(m.code) } return "ast_helper.go", out diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index e5d904b83a5..1c4cb991de7 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -22,6 +22,211 @@ import ( vterrors "vitess.io/vitess/go/vt/vterrors" ) +// CloneAST creates a deep clone of the input. +func CloneAST(in AST) AST { + if in == nil { + return nil + } + switch in := in.(type) { + case BasicType: + return in + case Bytes: + return CloneBytes(in) + case InterfaceContainer: + return CloneInterfaceContainer(in) + case InterfaceSlice: + return CloneInterfaceSlice(in) + case *Leaf: + return CloneRefOfLeaf(in) + case LeafSlice: + return CloneLeafSlice(in) + case *NoCloneType: + return CloneRefOfNoCloneType(in) + case *RefContainer: + return CloneRefOfRefContainer(in) + case *RefSliceContainer: + return CloneRefOfRefSliceContainer(in) + case *SubImpl: + return CloneRefOfSubImpl(in) + case ValueContainer: + return CloneValueContainer(in) + case ValueSliceContainer: + return CloneValueSliceContainer(in) + default: + // this should never happen + return nil + } +} + +// CloneBytes creates a deep clone of the input. +func CloneBytes(n Bytes) Bytes { + res := make(Bytes, 0, len(n)) + copy(res, n) + return res +} + +// CloneInterfaceContainer creates a deep clone of the input. +func CloneInterfaceContainer(n InterfaceContainer) InterfaceContainer { + return *CloneRefOfInterfaceContainer(&n) +} + +// CloneInterfaceSlice creates a deep clone of the input. +func CloneInterfaceSlice(n InterfaceSlice) InterfaceSlice { + res := make(InterfaceSlice, 0, len(n)) + for _, x := range n { + res = append(res, CloneAST(x)) + } + return res +} + +// CloneLeafSlice creates a deep clone of the input. +func CloneLeafSlice(n LeafSlice) LeafSlice { + res := make(LeafSlice, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfLeaf(x)) + } + return res +} + +// CloneRefOfBool creates a deep clone of the input. +func CloneRefOfBool(n *bool) *bool { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfInterfaceContainer creates a deep clone of the input. +func CloneRefOfInterfaceContainer(n *InterfaceContainer) *InterfaceContainer { + if n == nil { + return nil + } + out := *n + out.v = n.v + return &out +} + +// CloneRefOfLeaf creates a deep clone of the input. +func CloneRefOfLeaf(n *Leaf) *Leaf { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfNoCloneType creates a deep clone of the input. +func CloneRefOfNoCloneType(n *NoCloneType) *NoCloneType { + return n +} + +// CloneRefOfRefContainer creates a deep clone of the input. +func CloneRefOfRefContainer(n *RefContainer) *RefContainer { + if n == nil { + return nil + } + out := *n + out.ASTType = CloneAST(n.ASTType) + out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) + return &out +} + +// CloneRefOfRefSliceContainer creates a deep clone of the input. +func CloneRefOfRefSliceContainer(n *RefSliceContainer) *RefSliceContainer { + if n == nil { + return nil + } + out := *n + out.ASTElements = CloneSliceOfAST(n.ASTElements) + out.NotASTElements = CloneSliceOfInt(n.NotASTElements) + out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) + return &out +} + +// CloneRefOfSubImpl creates a deep clone of the input. +func CloneRefOfSubImpl(n *SubImpl) *SubImpl { + if n == nil { + return nil + } + out := *n + out.inner = CloneSubIface(n.inner) + out.field = CloneRefOfBool(n.field) + return &out +} + +// CloneRefOfValueContainer creates a deep clone of the input. +func CloneRefOfValueContainer(n *ValueContainer) *ValueContainer { + if n == nil { + return nil + } + out := *n + out.ASTType = CloneAST(n.ASTType) + out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) + return &out +} + +// CloneRefOfValueSliceContainer creates a deep clone of the input. +func CloneRefOfValueSliceContainer(n *ValueSliceContainer) *ValueSliceContainer { + if n == nil { + return nil + } + out := *n + out.ASTElements = CloneSliceOfAST(n.ASTElements) + out.NotASTElements = CloneSliceOfInt(n.NotASTElements) + out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) + return &out +} + +// CloneSliceOfAST creates a deep clone of the input. +func CloneSliceOfAST(n []AST) []AST { + res := make([]AST, 0, len(n)) + for _, x := range n { + res = append(res, CloneAST(x)) + } + return res +} + +// CloneSliceOfInt creates a deep clone of the input. +func CloneSliceOfInt(n []int) []int { + res := make([]int, 0, len(n)) + copy(res, n) + return res +} + +// CloneSliceOfRefOfLeaf creates a deep clone of the input. +func CloneSliceOfRefOfLeaf(n []*Leaf) []*Leaf { + res := make([]*Leaf, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfLeaf(x)) + } + return res +} + +// CloneSubIface creates a deep clone of the input. +func CloneSubIface(in SubIface) SubIface { + if in == nil { + return nil + } + switch in := in.(type) { + case *SubImpl: + return CloneRefOfSubImpl(in) + default: + // this should never happen + return nil + } +} + +// CloneValueContainer creates a deep clone of the input. +func CloneValueContainer(n ValueContainer) ValueContainer { + return *CloneRefOfValueContainer(&n) +} + +// CloneValueSliceContainer creates a deep clone of the input. +func CloneValueSliceContainer(n ValueSliceContainer) ValueSliceContainer { + return *CloneRefOfValueSliceContainer(&n) +} + // EqualsAST does deep equals between the two objects. func EqualsAST(inA, inB AST) bool { if inA == nil && inB == nil { @@ -109,199 +314,160 @@ func EqualsAST(inA, inB AST) bool { } } -// CloneAST creates a deep clone of the input. -func CloneAST(in AST) AST { - if in == nil { - return nil +// EqualsBytes does deep equals between the two objects. +func EqualsBytes(a, b Bytes) bool { + if len(a) != len(b) { + return false } - switch in := in.(type) { - case BasicType: - return in - case Bytes: - return CloneBytes(in) - case InterfaceContainer: - return CloneInterfaceContainer(in) - case InterfaceSlice: - return CloneInterfaceSlice(in) - case *Leaf: - return CloneRefOfLeaf(in) - case LeafSlice: - return CloneLeafSlice(in) - case *NoCloneType: - return CloneRefOfNoCloneType(in) - case *RefContainer: - return CloneRefOfRefContainer(in) - case *RefSliceContainer: - return CloneRefOfRefSliceContainer(in) - case *SubImpl: - return CloneRefOfSubImpl(in) - case ValueContainer: - return CloneValueContainer(in) - case ValueSliceContainer: - return CloneValueSliceContainer(in) - default: - // this should never happen - return nil + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } } + return true } -// VisitAST will visit all parts of the AST -func VisitAST(in AST, f Visit) error { - if in == nil { - return nil - } - switch in := in.(type) { - case BasicType: - return VisitBasicType(in, f) - case Bytes: - return VisitBytes(in, f) - case InterfaceContainer: - return VisitInterfaceContainer(in, f) - case InterfaceSlice: - return VisitInterfaceSlice(in, f) - case *Leaf: - return VisitRefOfLeaf(in, f) - case LeafSlice: - return VisitLeafSlice(in, f) - case *NoCloneType: - return VisitRefOfNoCloneType(in, f) - case *RefContainer: - return VisitRefOfRefContainer(in, f) - case *RefSliceContainer: - return VisitRefOfRefSliceContainer(in, f) - case *SubImpl: - return VisitRefOfSubImpl(in, f) - case ValueContainer: - return VisitValueContainer(in, f) - case ValueSliceContainer: - return VisitValueSliceContainer(in, f) - default: - // this should never happen - return nil - } +// EqualsInterfaceContainer does deep equals between the two objects. +func EqualsInterfaceContainer(a, b InterfaceContainer) bool { + return true } -// rewriteAST is part of the Rewrite implementation -func (a *application) rewriteAST(parent AST, node AST, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsInterfaceSlice does deep equals between the two objects. +func EqualsInterfaceSlice(a, b InterfaceSlice) bool { + if len(a) != len(b) { + return false } - switch node := node.(type) { - case BasicType: - return a.rewriteBasicType(parent, node, replacer) - case Bytes: - return a.rewriteBytes(parent, node, replacer) - case InterfaceContainer: - return a.rewriteInterfaceContainer(parent, node, replacer) - case InterfaceSlice: - return a.rewriteInterfaceSlice(parent, node, replacer) - case *Leaf: - return a.rewriteRefOfLeaf(parent, node, replacer) - case LeafSlice: - return a.rewriteLeafSlice(parent, node, replacer) - case *NoCloneType: - return a.rewriteRefOfNoCloneType(parent, node, replacer) - case *RefContainer: - return a.rewriteRefOfRefContainer(parent, node, replacer) - case *RefSliceContainer: - return a.rewriteRefOfRefSliceContainer(parent, node, replacer) - case *SubImpl: - return a.rewriteRefOfSubImpl(parent, node, replacer) - case ValueContainer: - return a.rewriteValueContainer(parent, node, replacer) - case ValueSliceContainer: - return a.rewriteValueSliceContainer(parent, node, replacer) - default: - // this should never happen - return nil + for i := 0; i < len(a); i++ { + if !EqualsAST(a[i], b[i]) { + return false + } } + return true } -// EqualsBytes does deep equals between the two objects. -func EqualsBytes(a, b Bytes) bool { +// EqualsLeafSlice does deep equals between the two objects. +func EqualsLeafSlice(a, b LeafSlice) bool { if len(a) != len(b) { return false } for i := 0; i < len(a); i++ { - if a[i] != b[i] { + if !EqualsRefOfLeaf(a[i], b[i]) { return false } } return true } -// CloneBytes creates a deep clone of the input. -func CloneBytes(n Bytes) Bytes { - res := make(Bytes, 0, len(n)) - copy(res, n) - return res +// EqualsRefOfBool does deep equals between the two objects. +func EqualsRefOfBool(a, b *bool) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b } -// VisitBytes will visit all parts of the AST -func VisitBytes(in Bytes, f Visit) error { - _, err := f(in) - return err +// EqualsRefOfInterfaceContainer does deep equals between the two objects. +func EqualsRefOfInterfaceContainer(a, b *InterfaceContainer) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true } -// rewriteBytes is part of the Rewrite implementation -func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsRefOfLeaf does deep equals between the two objects. +func EqualsRefOfLeaf(a, b *Leaf) bool { + if a == b { + return true } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if a == nil || b == nil { + return false } - if a.pre != nil && !a.pre(&cur) { - return nil + return a.v == b.v +} + +// EqualsRefOfNoCloneType does deep equals between the two objects. +func EqualsRefOfNoCloneType(a, b *NoCloneType) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.v == b.v } -// EqualsInterfaceContainer does deep equals between the two objects. -func EqualsInterfaceContainer(a, b InterfaceContainer) bool { - return true +// EqualsRefOfRefContainer does deep equals between the two objects. +func EqualsRefOfRefContainer(a, b *RefContainer) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.NotASTType == b.NotASTType && + EqualsAST(a.ASTType, b.ASTType) && + EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) } -// CloneInterfaceContainer creates a deep clone of the input. -func CloneInterfaceContainer(n InterfaceContainer) InterfaceContainer { - return *CloneRefOfInterfaceContainer(&n) +// EqualsRefOfRefSliceContainer does deep equals between the two objects. +func EqualsRefOfRefSliceContainer(a, b *RefSliceContainer) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && + EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && + EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) } -// VisitInterfaceContainer will visit all parts of the AST -func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfSubImpl does deep equals between the two objects. +func EqualsRefOfSubImpl(a, b *SubImpl) bool { + if a == b { + return true } - return nil + if a == nil || b == nil { + return false + } + return EqualsSubIface(a.inner, b.inner) && + EqualsRefOfBool(a.field, b.field) } -// rewriteInterfaceContainer is part of the Rewrite implementation -func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc) error { - var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, +// EqualsRefOfValueContainer does deep equals between the two objects. +func EqualsRefOfValueContainer(a, b *ValueContainer) bool { + if a == b { + return true } - if a.pre != nil && !a.pre(&cur) { - return nil + if a == nil || b == nil { + return false } - if err != nil { - return err + return a.NotASTType == b.NotASTType && + EqualsAST(a.ASTType, b.ASTType) && + EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) +} + +// EqualsRefOfValueSliceContainer does deep equals between the two objects. +func EqualsRefOfValueSliceContainer(a, b *ValueSliceContainer) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && + EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && + EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) } -// EqualsInterfaceSlice does deep equals between the two objects. -func EqualsInterfaceSlice(a, b InterfaceSlice) bool { +// EqualsSliceOfAST does deep equals between the two objects. +func EqualsSliceOfAST(a, b []AST) bool { if len(a) != len(b) { return false } @@ -313,131 +479,115 @@ func EqualsInterfaceSlice(a, b InterfaceSlice) bool { return true } -// CloneInterfaceSlice creates a deep clone of the input. -func CloneInterfaceSlice(n InterfaceSlice) InterfaceSlice { - res := make(InterfaceSlice, 0, len(n)) - for _, x := range n { - res = append(res, CloneAST(x)) - } - return res -} - -// VisitInterfaceSlice will visit all parts of the AST -func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsSliceOfInt does deep equals between the two objects. +func EqualsSliceOfInt(a, b []int) bool { + if len(a) != len(b) { + return false } - for _, el := range in { - if err := VisitAST(el, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false } } - return nil + return true } -// rewriteInterfaceSlice is part of the Rewrite implementation -func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsSliceOfRefOfLeaf does deep equals between the two objects. +func EqualsSliceOfRefOfLeaf(a, b []*Leaf) bool { + if len(a) != len(b) { + return false } - for i, el := range node { - if errF := a.rewriteAST(node, el, func(newNode, parent AST) { - parent.(InterfaceSlice)[i] = newNode.(AST) - }); errF != nil { - return errF + for i := 0; i < len(a); i++ { + if !EqualsRefOfLeaf(a[i], b[i]) { + return false } } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + return true } -// EqualsRefOfLeaf does deep equals between the two objects. -func EqualsRefOfLeaf(a, b *Leaf) bool { - if a == b { +// EqualsSubIface does deep equals between the two objects. +func EqualsSubIface(inA, inB SubIface) bool { + if inA == nil && inB == nil { return true } - if a == nil || b == nil { + if inA == nil || inB == nil { return false } - return a.v == b.v -} - -// CloneRefOfLeaf creates a deep clone of the input. -func CloneRefOfLeaf(n *Leaf) *Leaf { - if n == nil { - return nil + switch a := inA.(type) { + case *SubImpl: + b, ok := inB.(*SubImpl) + if !ok { + return false + } + return EqualsRefOfSubImpl(a, b) + default: + // this should never happen + return false } - out := *n - return &out } -// VisitRefOfLeaf will visit all parts of the AST -func VisitRefOfLeaf(in *Leaf, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil +// EqualsValueContainer does deep equals between the two objects. +func EqualsValueContainer(a, b ValueContainer) bool { + return a.NotASTType == b.NotASTType && + EqualsAST(a.ASTType, b.ASTType) && + EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) } -// rewriteRefOfLeaf is part of the Rewrite implementation -func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc) error { - if node == nil { +// EqualsValueSliceContainer does deep equals between the two objects. +func EqualsValueSliceContainer(a, b ValueSliceContainer) bool { + return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && + EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && + EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) +} +func VisitAST(in AST, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + switch in := in.(type) { + case BasicType: + return VisitBasicType(in, f) + case Bytes: + return VisitBytes(in, f) + case InterfaceContainer: + return VisitInterfaceContainer(in, f) + case InterfaceSlice: + return VisitInterfaceSlice(in, f) + case *Leaf: + return VisitRefOfLeaf(in, f) + case LeafSlice: + return VisitLeafSlice(in, f) + case *NoCloneType: + return VisitRefOfNoCloneType(in, f) + case *RefContainer: + return VisitRefOfRefContainer(in, f) + case *RefSliceContainer: + return VisitRefOfRefSliceContainer(in, f) + case *SubImpl: + return VisitRefOfSubImpl(in, f) + case ValueContainer: + return VisitValueContainer(in, f) + case ValueSliceContainer: + return VisitValueSliceContainer(in, f) + default: + // this should never happen return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil } - -// EqualsLeafSlice does deep equals between the two objects. -func EqualsLeafSlice(a, b LeafSlice) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfLeaf(a[i], b[i]) { - return false - } - } - return true +func VisitBasicType(in BasicType, f Visit) error { + _, err := f(in) + return err } - -// CloneLeafSlice creates a deep clone of the input. -func CloneLeafSlice(n LeafSlice) LeafSlice { - res := make(LeafSlice, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfLeaf(x)) +func VisitBytes(in Bytes, f Visit) error { + _, err := f(in) + return err +} +func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err } - return res + return nil } - -// VisitLeafSlice will visit all parts of the AST -func VisitLeafSlice(in LeafSlice, f Visit) error { +func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { if in == nil { return nil } @@ -445,57 +595,27 @@ func VisitLeafSlice(in LeafSlice, f Visit) error { return err } for _, el := range in { - if err := VisitRefOfLeaf(el, f); err != nil { + if err := VisitAST(el, f); err != nil { return err } } return nil } - -// rewriteLeafSlice is part of the Rewrite implementation -func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc) error { - if node == nil { +func VisitLeafSlice(in LeafSlice, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - for i, el := range node { - if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { - parent.(LeafSlice)[i] = newNode.(*Leaf) - }); errF != nil { - return errF + for _, el := range in { + if err := VisitRefOfLeaf(el, f); err != nil { + return err } } - if a.post != nil && !a.post(&cur) { - return errAbort - } return nil } - -// EqualsRefOfNoCloneType does deep equals between the two objects. -func EqualsRefOfNoCloneType(a, b *NoCloneType) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.v == b.v -} - -// CloneRefOfNoCloneType creates a deep clone of the input. -func CloneRefOfNoCloneType(n *NoCloneType) *NoCloneType { - return n -} - -// VisitRefOfNoCloneType will visit all parts of the AST -func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { +func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { if in == nil { return nil } @@ -504,51 +624,24 @@ func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { } return nil } - -// rewriteRefOfNoCloneType is part of the Rewrite implementation -func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitRefOfLeaf(in *Leaf, f Visit) error { + if in == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort + if cont, err := f(in); err != nil || !cont { + return err } return nil } - -// EqualsRefOfRefContainer does deep equals between the two objects. -func EqualsRefOfRefContainer(a, b *RefContainer) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.NotASTType == b.NotASTType && - EqualsAST(a.ASTType, b.ASTType) && - EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) -} - -// CloneRefOfRefContainer creates a deep clone of the input. -func CloneRefOfRefContainer(n *RefContainer) *RefContainer { - if n == nil { +func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { + if in == nil { return nil } - out := *n - out.ASTType = CloneAST(n.ASTType) - out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfRefContainer will visit all parts of the AST func VisitRefOfRefContainer(in *RefContainer, f Visit) error { if in == nil { return nil @@ -564,63 +657,53 @@ func VisitRefOfRefContainer(in *RefContainer, f Visit) error { } return nil } - -// rewriteRefOfRefContainer is part of the Rewrite implementation -func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { - parent.(*RefContainer).ASTType = newNode.(AST) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { - parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) - }); errF != nil { - return errF + for _, el := range in.ASTElements { + if err := VisitAST(el, f); err != nil { + return err + } } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in.ASTImplementationElements { + if err := VisitRefOfLeaf(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfRefSliceContainer does deep equals between the two objects. -func EqualsRefOfRefSliceContainer(a, b *RefSliceContainer) bool { - if a == b { - return true +func VisitRefOfSubImpl(in *SubImpl, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && - EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && - EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) + if err := VisitSubIface(in.inner, f); err != nil { + return err + } + return nil } - -// CloneRefOfRefSliceContainer creates a deep clone of the input. -func CloneRefOfRefSliceContainer(n *RefSliceContainer) *RefSliceContainer { - if n == nil { +func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { + if in == nil { return nil } - out := *n - out.ASTElements = CloneSliceOfAST(n.ASTElements) - out.NotASTElements = CloneSliceOfInt(n.NotASTElements) - out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitAST(in.ASTType, f); err != nil { + return err + } + if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { + return err + } + return nil } - -// VisitRefOfRefSliceContainer will visit all parts of the AST -func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { +func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { if in == nil { return nil } @@ -639,79 +722,95 @@ func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { } return nil } - -// rewriteRefOfRefSliceContainer is part of the Rewrite implementation -func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc) error { - if node == nil { +func VisitSubIface(in SubIface, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + switch in := in.(type) { + case *SubImpl: + return VisitRefOfSubImpl(in, f) + default: + // this should never happen return nil } - for i, el := range node.ASTElements { - if errF := a.rewriteAST(node, el, func(newNode, parent AST) { - parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) - }); errF != nil { - return errF - } +} +func VisitValueContainer(in ValueContainer, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err } - for i, el := range node.ASTImplementationElements { - if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { - parent.(*RefSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) - }); errF != nil { - return errF - } + if err := VisitAST(in.ASTType, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { + return err } return nil } - -// EqualsRefOfSubImpl does deep equals between the two objects. -func EqualsRefOfSubImpl(a, b *SubImpl) bool { - if a == b { - return true +func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err } - if a == nil || b == nil { - return false + for _, el := range in.ASTElements { + if err := VisitAST(el, f); err != nil { + return err + } } - return EqualsSubIface(a.inner, b.inner) && - EqualsRefOfBool(a.field, b.field) + for _, el := range in.ASTImplementationElements { + if err := VisitRefOfLeaf(el, f); err != nil { + return err + } + } + return nil } - -// CloneRefOfSubImpl creates a deep clone of the input. -func CloneRefOfSubImpl(n *SubImpl) *SubImpl { - if n == nil { +func (a *application) rewriteAST(parent AST, node AST, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case BasicType: + return a.rewriteBasicType(parent, node, replacer) + case Bytes: + return a.rewriteBytes(parent, node, replacer) + case InterfaceContainer: + return a.rewriteInterfaceContainer(parent, node, replacer) + case InterfaceSlice: + return a.rewriteInterfaceSlice(parent, node, replacer) + case *Leaf: + return a.rewriteRefOfLeaf(parent, node, replacer) + case LeafSlice: + return a.rewriteLeafSlice(parent, node, replacer) + case *NoCloneType: + return a.rewriteRefOfNoCloneType(parent, node, replacer) + case *RefContainer: + return a.rewriteRefOfRefContainer(parent, node, replacer) + case *RefSliceContainer: + return a.rewriteRefOfRefSliceContainer(parent, node, replacer) + case *SubImpl: + return a.rewriteRefOfSubImpl(parent, node, replacer) + case ValueContainer: + return a.rewriteValueContainer(parent, node, replacer) + case ValueSliceContainer: + return a.rewriteValueSliceContainer(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.inner = CloneSubIface(n.inner) - out.field = CloneRefOfBool(n.field) - return &out } - -// VisitRefOfSubImpl will visit all parts of the AST -func VisitRefOfSubImpl(in *SubImpl, f Visit) error { - if in == nil { - return nil +func (a *application) rewriteBasicType(parent AST, node BasicType, replacer replacerFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitSubIface(in.inner, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfSubImpl is part of the Rewrite implementation -func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc) error { +func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc) error { if node == nil { return nil } @@ -723,45 +822,12 @@ func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer re if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteSubIface(node, node.inner, func(newNode, parent AST) { - parent.(*SubImpl).inner = newNode.(SubIface) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsValueContainer does deep equals between the two objects. -func EqualsValueContainer(a, b ValueContainer) bool { - return a.NotASTType == b.NotASTType && - EqualsAST(a.ASTType, b.ASTType) && - EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) -} - -// CloneValueContainer creates a deep clone of the input. -func CloneValueContainer(n ValueContainer) ValueContainer { - return *CloneRefOfValueContainer(&n) -} - -// VisitValueContainer will visit all parts of the AST -func VisitValueContainer(in ValueContainer, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitAST(in.ASTType, f); err != nil { - return err - } - if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { - return err - } - return nil -} - -// rewriteValueContainer is part of the Rewrite implementation -func (a *application) rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc) error { +func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc) error { var err error cur := Cursor{ node: node, @@ -771,16 +837,6 @@ func (a *application) rewriteValueContainer(parent AST, node ValueContainer, rep if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTType' on 'ValueContainer'") - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationType' on 'ValueContainer'") - }); errF != nil { - return errF - } if err != nil { return err } @@ -789,40 +845,10 @@ func (a *application) rewriteValueContainer(parent AST, node ValueContainer, rep } return nil } - -// EqualsValueSliceContainer does deep equals between the two objects. -func EqualsValueSliceContainer(a, b ValueSliceContainer) bool { - return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && - EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && - EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) -} - -// CloneValueSliceContainer creates a deep clone of the input. -func CloneValueSliceContainer(n ValueSliceContainer) ValueSliceContainer { - return *CloneRefOfValueSliceContainer(&n) -} - -// VisitValueSliceContainer will visit all parts of the AST -func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in.ASTElements { - if err := VisitAST(el, f); err != nil { - return err - } - } - for _, el := range in.ASTImplementationElements { - if err := VisitRefOfLeaf(el, f); err != nil { - return err - } +func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc) error { + if node == nil { + return nil } - return nil -} - -// rewriteValueSliceContainer is part of the Rewrite implementation -func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc) error { - var err error cur := Cursor{ node: node, parent: parent, @@ -831,100 +857,80 @@ func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceCont if a.pre != nil && !a.pre(&cur) { return nil } - for _, el := range node.ASTElements { + for i, el := range node { if errF := a.rewriteAST(node, el, func(newNode, parent AST) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTElements' on 'ValueSliceContainer'") + parent.(InterfaceSlice)[i] = newNode.(AST) }); errF != nil { return errF } } - for _, el := range node.ASTImplementationElements { + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + for i, el := range node { if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationElements' on 'ValueSliceContainer'") + parent.(LeafSlice)[i] = newNode.(*Leaf) }); errF != nil { return errF } } - if err != nil { - return err - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsSubIface does deep equals between the two objects. -func EqualsSubIface(inA, inB SubIface) bool { - if inA == nil && inB == nil { - return true - } - if inA == nil || inB == nil { - return false +func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc) error { + if node == nil { + return nil } - switch a := inA.(type) { - case *SubImpl: - b, ok := inB.(*SubImpl) - if !ok { - return false - } - return EqualsRefOfSubImpl(a, b) - default: - // this should never happen - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } -} - -// CloneSubIface creates a deep clone of the input. -func CloneSubIface(in SubIface) SubIface { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - switch in := in.(type) { - case *SubImpl: - return CloneRefOfSubImpl(in) - default: - // this should never happen - return nil + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// VisitSubIface will visit all parts of the AST -func VisitSubIface(in SubIface, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *SubImpl: - return VisitRefOfSubImpl(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteSubIface is part of the Rewrite implementation -func (a *application) rewriteSubIface(parent AST, node SubIface, replacer replacerFunc) error { +func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *SubImpl: - return a.rewriteRefOfSubImpl(parent, node, replacer) - default: - // this should never happen - return nil - } -} - -// VisitBasicType will visit all parts of the AST -func VisitBasicType(in BasicType, f Visit) error { - _, err := f(in) - return err -} - -// rewriteBasicType is part of the Rewrite implementation -func (a *application) rewriteBasicType(parent AST, node BasicType, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, @@ -938,41 +944,34 @@ func (a *application) rewriteBasicType(parent AST, node BasicType, replacer repl } return nil } - -// EqualsRefOfInterfaceContainer does deep equals between the two objects. -func EqualsRefOfInterfaceContainer(a, b *InterfaceContainer) bool { - if a == b { - return true +func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return true -} - -// CloneRefOfInterfaceContainer creates a deep clone of the input. -func CloneRefOfInterfaceContainer(n *InterfaceContainer) *InterfaceContainer { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.v = n.v - return &out -} - -// VisitRefOfInterfaceContainer will visit all parts of the AST -func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfInterfaceContainer is part of the Rewrite implementation -func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc) error { +func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc) error { if node == nil { return nil } @@ -984,138 +983,47 @@ func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *Interface if a.pre != nil && !a.pre(&cur) { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsSliceOfAST does deep equals between the two objects. -func EqualsSliceOfAST(a, b []AST) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsAST(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSliceOfAST creates a deep clone of the input. -func CloneSliceOfAST(n []AST) []AST { - res := make([]AST, 0, len(n)) - for _, x := range n { - res = append(res, CloneAST(x)) - } - return res -} - -// EqualsSliceOfInt does deep equals between the two objects. -func EqualsSliceOfInt(a, b []int) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false + for i, el := range node.ASTElements { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) + }); errF != nil { + return errF } } - return true -} - -// CloneSliceOfInt creates a deep clone of the input. -func CloneSliceOfInt(n []int) []int { - res := make([]int, 0, len(n)) - copy(res, n) - return res -} - -// EqualsSliceOfRefOfLeaf does deep equals between the two objects. -func EqualsSliceOfRefOfLeaf(a, b []*Leaf) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfLeaf(a[i], b[i]) { - return false + for i, el := range node.ASTImplementationElements { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }); errF != nil { + return errF } } - return true -} - -// CloneSliceOfRefOfLeaf creates a deep clone of the input. -func CloneSliceOfRefOfLeaf(n []*Leaf) []*Leaf { - res := make([]*Leaf, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfLeaf(x)) - } - return res -} - -// EqualsRefOfBool does deep equals between the two objects. -func EqualsRefOfBool(a, b *bool) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false + if a.post != nil && !a.post(&cur) { + return errAbort } - return *a == *b + return nil } - -// CloneRefOfBool creates a deep clone of the input. -func CloneRefOfBool(n *bool) *bool { - if n == nil { +func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - return &out -} - -// EqualsRefOfValueContainer does deep equals between the two objects. -func EqualsRefOfValueContainer(a, b *ValueContainer) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.NotASTType == b.NotASTType && - EqualsAST(a.ASTType, b.ASTType) && - EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) -} - -// CloneRefOfValueContainer creates a deep clone of the input. -func CloneRefOfValueContainer(n *ValueContainer) *ValueContainer { - if n == nil { - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - out := *n - out.ASTType = CloneAST(n.ASTType) - out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) - return &out -} - -// VisitRefOfValueContainer will visit all parts of the AST -func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitAST(in.ASTType, f); err != nil { - return err + if errF := a.rewriteSubIface(node, node.inner, func(newNode, parent AST) { + parent.(*SubImpl).inner = newNode.(SubIface) + }); errF != nil { + return errF } - if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfValueContainer is part of the Rewrite implementation func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer replacerFunc) error { if node == nil { return nil @@ -1143,58 +1051,51 @@ func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContaine } return nil } - -// EqualsRefOfValueSliceContainer does deep equals between the two objects. -func EqualsRefOfValueSliceContainer(a, b *ValueSliceContainer) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && - EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && - EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) -} - -// CloneRefOfValueSliceContainer creates a deep clone of the input. -func CloneRefOfValueSliceContainer(n *ValueSliceContainer) *ValueSliceContainer { - if n == nil { +func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.ASTElements = CloneSliceOfAST(n.ASTElements) - out.NotASTElements = CloneSliceOfInt(n.NotASTElements) - out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) - return &out -} - -// VisitRefOfValueSliceContainer will visit all parts of the AST -func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { - if in == nil { - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - for _, el := range in.ASTElements { - if err := VisitAST(el, f); err != nil { - return err + for i, el := range node.ASTElements { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { + parent.(*ValueSliceContainer).ASTElements[i] = newNode.(AST) + }); errF != nil { + return errF } } - for _, el := range in.ASTImplementationElements { - if err := VisitRefOfLeaf(el, f); err != nil { - return err + for i, el := range node.ASTImplementationElements { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(*ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }); errF != nil { + return errF } } + if a.post != nil && !a.post(&cur) { + return errAbort + } return nil } - -// rewriteRefOfValueSliceContainer is part of the Rewrite implementation -func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc) error { +func (a *application) rewriteSubIface(parent AST, node SubIface, replacer replacerFunc) error { if node == nil { return nil } + switch node := node.(type) { + case *SubImpl: + return a.rewriteRefOfSubImpl(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc) error { + var err error cur := Cursor{ node: node, parent: parent, @@ -1203,20 +1104,51 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node.ASTElements { + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTType' on 'ValueContainer'") + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationType' on 'ValueContainer'") + }); errF != nil { + return errF + } + if err != nil { + return err + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + for _, el := range node.ASTElements { if errF := a.rewriteAST(node, el, func(newNode, parent AST) { - parent.(*ValueSliceContainer).ASTElements[i] = newNode.(AST) + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTElements' on 'ValueSliceContainer'") }); errF != nil { return errF } } - for i, el := range node.ASTImplementationElements { + for _, el := range node.ASTImplementationElements { if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { - parent.(*ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationElements' on 'ValueSliceContainer'") }); errF != nil { return errF } } + if err != nil { + return err + } if a.post != nil && !a.post(&cur) { return errAbort } diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 593a4570490..1f45bd76d86 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -22,2858 +22,1160 @@ import ( vterrors "vitess.io/vitess/go/vt/vterrors" ) -// EqualsSQLNode does deep equals between the two objects. -func EqualsSQLNode(inA, inB SQLNode) bool { - if inA == nil && inB == nil { - return true - } - if inA == nil || inB == nil { - return false +// CloneAlterOption creates a deep clone of the input. +func CloneAlterOption(in AlterOption) AlterOption { + if in == nil { + return nil } - switch a := inA.(type) { - case AccessMode: - b, ok := inB.(AccessMode) - if !ok { - return false - } - return a == b + switch in := in.(type) { case *AddColumns: - b, ok := inB.(*AddColumns) - if !ok { - return false - } - return EqualsRefOfAddColumns(a, b) + return CloneRefOfAddColumns(in) case *AddConstraintDefinition: - b, ok := inB.(*AddConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfAddConstraintDefinition(a, b) + return CloneRefOfAddConstraintDefinition(in) case *AddIndexDefinition: - b, ok := inB.(*AddIndexDefinition) - if !ok { - return false - } - return EqualsRefOfAddIndexDefinition(a, b) + return CloneRefOfAddIndexDefinition(in) case AlgorithmValue: - b, ok := inB.(AlgorithmValue) - if !ok { - return false - } - return a == b - case *AliasedExpr: - b, ok := inB.(*AliasedExpr) - if !ok { - return false - } - return EqualsRefOfAliasedExpr(a, b) - case *AliasedTableExpr: - b, ok := inB.(*AliasedTableExpr) - if !ok { - return false - } - return EqualsRefOfAliasedTableExpr(a, b) + return in case *AlterCharset: - b, ok := inB.(*AlterCharset) - if !ok { - return false - } - return EqualsRefOfAlterCharset(a, b) + return CloneRefOfAlterCharset(in) case *AlterColumn: - b, ok := inB.(*AlterColumn) - if !ok { - return false - } - return EqualsRefOfAlterColumn(a, b) + return CloneRefOfAlterColumn(in) + case *ChangeColumn: + return CloneRefOfChangeColumn(in) + case *DropColumn: + return CloneRefOfDropColumn(in) + case *DropKey: + return CloneRefOfDropKey(in) + case *Force: + return CloneRefOfForce(in) + case *KeyState: + return CloneRefOfKeyState(in) + case *LockOption: + return CloneRefOfLockOption(in) + case *ModifyColumn: + return CloneRefOfModifyColumn(in) + case *OrderByOption: + return CloneRefOfOrderByOption(in) + case *RenameIndex: + return CloneRefOfRenameIndex(in) + case *RenameTableName: + return CloneRefOfRenameTableName(in) + case TableOptions: + return CloneTableOptions(in) + case *TablespaceOperation: + return CloneRefOfTablespaceOperation(in) + case *Validation: + return CloneRefOfValidation(in) + default: + // this should never happen + return nil + } +} + +// CloneCharacteristic creates a deep clone of the input. +func CloneCharacteristic(in Characteristic) Characteristic { + if in == nil { + return nil + } + switch in := in.(type) { + case AccessMode: + return in + case IsolationLevel: + return in + default: + // this should never happen + return nil + } +} + +// CloneColIdent creates a deep clone of the input. +func CloneColIdent(n ColIdent) ColIdent { + return *CloneRefOfColIdent(&n) +} + +// CloneColTuple creates a deep clone of the input. +func CloneColTuple(in ColTuple) ColTuple { + if in == nil { + return nil + } + switch in := in.(type) { + case ListArg: + return CloneListArg(in) + case *Subquery: + return CloneRefOfSubquery(in) + case ValTuple: + return CloneValTuple(in) + default: + // this should never happen + return nil + } +} + +// CloneCollateAndCharset creates a deep clone of the input. +func CloneCollateAndCharset(n CollateAndCharset) CollateAndCharset { + return *CloneRefOfCollateAndCharset(&n) +} + +// CloneColumnType creates a deep clone of the input. +func CloneColumnType(n ColumnType) ColumnType { + return *CloneRefOfColumnType(&n) +} + +// CloneColumns creates a deep clone of the input. +func CloneColumns(n Columns) Columns { + res := make(Columns, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneComments creates a deep clone of the input. +func CloneComments(n Comments) Comments { + res := make(Comments, 0, len(n)) + copy(res, n) + return res +} + +// CloneConstraintInfo creates a deep clone of the input. +func CloneConstraintInfo(in ConstraintInfo) ConstraintInfo { + if in == nil { + return nil + } + switch in := in.(type) { + case *CheckConstraintDefinition: + return CloneRefOfCheckConstraintDefinition(in) + case *ForeignKeyDefinition: + return CloneRefOfForeignKeyDefinition(in) + default: + // this should never happen + return nil + } +} + +// CloneDBDDLStatement creates a deep clone of the input. +func CloneDBDDLStatement(in DBDDLStatement) DBDDLStatement { + if in == nil { + return nil + } + switch in := in.(type) { case *AlterDatabase: - b, ok := inB.(*AlterDatabase) - if !ok { - return false - } - return EqualsRefOfAlterDatabase(a, b) - case *AlterMigration: - b, ok := inB.(*AlterMigration) - if !ok { - return false - } - return EqualsRefOfAlterMigration(a, b) + return CloneRefOfAlterDatabase(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + default: + // this should never happen + return nil + } +} + +// CloneDDLStatement creates a deep clone of the input. +func CloneDDLStatement(in DDLStatement) DDLStatement { + if in == nil { + return nil + } + switch in := in.(type) { case *AlterTable: - b, ok := inB.(*AlterTable) - if !ok { - return false - } - return EqualsRefOfAlterTable(a, b) + return CloneRefOfAlterTable(in) case *AlterView: - b, ok := inB.(*AlterView) - if !ok { - return false - } - return EqualsRefOfAlterView(a, b) - case *AlterVschema: - b, ok := inB.(*AlterVschema) - if !ok { - return false - } - return EqualsRefOfAlterVschema(a, b) - case *AndExpr: - b, ok := inB.(*AndExpr) - if !ok { - return false - } - return EqualsRefOfAndExpr(a, b) - case Argument: - b, ok := inB.(Argument) - if !ok { - return false - } - return a == b - case *AutoIncSpec: - b, ok := inB.(*AutoIncSpec) - if !ok { - return false - } - return EqualsRefOfAutoIncSpec(a, b) - case *Begin: - b, ok := inB.(*Begin) - if !ok { - return false - } - return EqualsRefOfBegin(a, b) - case *BinaryExpr: - b, ok := inB.(*BinaryExpr) - if !ok { - return false - } - return EqualsRefOfBinaryExpr(a, b) + return CloneRefOfAlterView(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + default: + // this should never happen + return nil + } +} + +// CloneExplain creates a deep clone of the input. +func CloneExplain(in Explain) Explain { + if in == nil { + return nil + } + switch in := in.(type) { + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + default: + // this should never happen + return nil + } +} + +// CloneExpr creates a deep clone of the input. +func CloneExpr(in Expr) Expr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AndExpr: + return CloneRefOfAndExpr(in) + case Argument: + return in + case *BinaryExpr: + return CloneRefOfBinaryExpr(in) case BoolVal: - b, ok := inB.(BoolVal) - if !ok { - return false - } - return a == b - case *CallProc: - b, ok := inB.(*CallProc) - if !ok { - return false - } - return EqualsRefOfCallProc(a, b) + return in case *CaseExpr: - b, ok := inB.(*CaseExpr) - if !ok { - return false - } - return EqualsRefOfCaseExpr(a, b) - case *ChangeColumn: - b, ok := inB.(*ChangeColumn) - if !ok { - return false - } - return EqualsRefOfChangeColumn(a, b) - case *CheckConstraintDefinition: - b, ok := inB.(*CheckConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfCheckConstraintDefinition(a, b) - case ColIdent: - b, ok := inB.(ColIdent) - if !ok { - return false - } - return EqualsColIdent(a, b) + return CloneRefOfCaseExpr(in) case *ColName: - b, ok := inB.(*ColName) - if !ok { - return false - } - return EqualsRefOfColName(a, b) + return CloneRefOfColName(in) case *CollateExpr: - b, ok := inB.(*CollateExpr) - if !ok { - return false - } - return EqualsRefOfCollateExpr(a, b) - case *ColumnDefinition: - b, ok := inB.(*ColumnDefinition) - if !ok { - return false - } - return EqualsRefOfColumnDefinition(a, b) - case *ColumnType: - b, ok := inB.(*ColumnType) - if !ok { - return false - } - return EqualsRefOfColumnType(a, b) - case Columns: - b, ok := inB.(Columns) - if !ok { - return false - } - return EqualsColumns(a, b) - case Comments: - b, ok := inB.(Comments) - if !ok { - return false - } - return EqualsComments(a, b) - case *Commit: - b, ok := inB.(*Commit) - if !ok { - return false - } - return EqualsRefOfCommit(a, b) + return CloneRefOfCollateExpr(in) case *ComparisonExpr: - b, ok := inB.(*ComparisonExpr) - if !ok { - return false - } - return EqualsRefOfComparisonExpr(a, b) - case *ConstraintDefinition: - b, ok := inB.(*ConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfConstraintDefinition(a, b) + return CloneRefOfComparisonExpr(in) case *ConvertExpr: - b, ok := inB.(*ConvertExpr) - if !ok { - return false - } - return EqualsRefOfConvertExpr(a, b) - case *ConvertType: - b, ok := inB.(*ConvertType) - if !ok { - return false - } - return EqualsRefOfConvertType(a, b) + return CloneRefOfConvertExpr(in) case *ConvertUsingExpr: - b, ok := inB.(*ConvertUsingExpr) - if !ok { - return false - } - return EqualsRefOfConvertUsingExpr(a, b) - case *CreateDatabase: - b, ok := inB.(*CreateDatabase) - if !ok { - return false - } - return EqualsRefOfCreateDatabase(a, b) - case *CreateTable: - b, ok := inB.(*CreateTable) - if !ok { - return false - } - return EqualsRefOfCreateTable(a, b) - case *CreateView: - b, ok := inB.(*CreateView) - if !ok { - return false - } - return EqualsRefOfCreateView(a, b) + return CloneRefOfConvertUsingExpr(in) case *CurTimeFuncExpr: - b, ok := inB.(*CurTimeFuncExpr) - if !ok { - return false - } - return EqualsRefOfCurTimeFuncExpr(a, b) + return CloneRefOfCurTimeFuncExpr(in) case *Default: - b, ok := inB.(*Default) - if !ok { - return false - } - return EqualsRefOfDefault(a, b) - case *Delete: - b, ok := inB.(*Delete) - if !ok { - return false - } - return EqualsRefOfDelete(a, b) - case *DerivedTable: - b, ok := inB.(*DerivedTable) - if !ok { - return false - } - return EqualsRefOfDerivedTable(a, b) - case *DropColumn: - b, ok := inB.(*DropColumn) - if !ok { - return false - } - return EqualsRefOfDropColumn(a, b) - case *DropDatabase: - b, ok := inB.(*DropDatabase) - if !ok { - return false - } - return EqualsRefOfDropDatabase(a, b) - case *DropKey: - b, ok := inB.(*DropKey) - if !ok { - return false - } - return EqualsRefOfDropKey(a, b) - case *DropTable: - b, ok := inB.(*DropTable) - if !ok { - return false - } - return EqualsRefOfDropTable(a, b) - case *DropView: - b, ok := inB.(*DropView) - if !ok { - return false - } - return EqualsRefOfDropView(a, b) + return CloneRefOfDefault(in) case *ExistsExpr: - b, ok := inB.(*ExistsExpr) - if !ok { - return false - } - return EqualsRefOfExistsExpr(a, b) - case *ExplainStmt: - b, ok := inB.(*ExplainStmt) - if !ok { - return false - } - return EqualsRefOfExplainStmt(a, b) - case *ExplainTab: - b, ok := inB.(*ExplainTab) - if !ok { - return false - } - return EqualsRefOfExplainTab(a, b) - case Exprs: - b, ok := inB.(Exprs) - if !ok { - return false - } - return EqualsExprs(a, b) - case *Flush: - b, ok := inB.(*Flush) - if !ok { - return false - } - return EqualsRefOfFlush(a, b) - case *Force: - b, ok := inB.(*Force) - if !ok { - return false - } - return EqualsRefOfForce(a, b) - case *ForeignKeyDefinition: - b, ok := inB.(*ForeignKeyDefinition) - if !ok { - return false - } - return EqualsRefOfForeignKeyDefinition(a, b) + return CloneRefOfExistsExpr(in) case *FuncExpr: - b, ok := inB.(*FuncExpr) - if !ok { - return false - } - return EqualsRefOfFuncExpr(a, b) - case GroupBy: - b, ok := inB.(GroupBy) - if !ok { - return false - } - return EqualsGroupBy(a, b) + return CloneRefOfFuncExpr(in) case *GroupConcatExpr: - b, ok := inB.(*GroupConcatExpr) - if !ok { - return false - } - return EqualsRefOfGroupConcatExpr(a, b) - case *IndexDefinition: - b, ok := inB.(*IndexDefinition) - if !ok { - return false - } - return EqualsRefOfIndexDefinition(a, b) - case *IndexHints: - b, ok := inB.(*IndexHints) - if !ok { - return false - } - return EqualsRefOfIndexHints(a, b) - case *IndexInfo: - b, ok := inB.(*IndexInfo) - if !ok { - return false - } - return EqualsRefOfIndexInfo(a, b) - case *Insert: - b, ok := inB.(*Insert) - if !ok { - return false - } - return EqualsRefOfInsert(a, b) + return CloneRefOfGroupConcatExpr(in) case *IntervalExpr: - b, ok := inB.(*IntervalExpr) - if !ok { - return false - } - return EqualsRefOfIntervalExpr(a, b) + return CloneRefOfIntervalExpr(in) case *IsExpr: - b, ok := inB.(*IsExpr) - if !ok { - return false - } - return EqualsRefOfIsExpr(a, b) - case IsolationLevel: - b, ok := inB.(IsolationLevel) - if !ok { - return false - } - return a == b - case JoinCondition: - b, ok := inB.(JoinCondition) - if !ok { - return false - } - return EqualsJoinCondition(a, b) - case *JoinTableExpr: - b, ok := inB.(*JoinTableExpr) - if !ok { - return false - } - return EqualsRefOfJoinTableExpr(a, b) - case *KeyState: - b, ok := inB.(*KeyState) - if !ok { - return false - } - return EqualsRefOfKeyState(a, b) - case *Limit: - b, ok := inB.(*Limit) - if !ok { - return false - } - return EqualsRefOfLimit(a, b) + return CloneRefOfIsExpr(in) case ListArg: - b, ok := inB.(ListArg) - if !ok { - return false - } - return EqualsListArg(a, b) + return CloneListArg(in) case *Literal: - b, ok := inB.(*Literal) - if !ok { - return false - } - return EqualsRefOfLiteral(a, b) - case *Load: - b, ok := inB.(*Load) - if !ok { - return false - } - return EqualsRefOfLoad(a, b) - case *LockOption: - b, ok := inB.(*LockOption) - if !ok { - return false - } - return EqualsRefOfLockOption(a, b) - case *LockTables: - b, ok := inB.(*LockTables) - if !ok { - return false - } - return EqualsRefOfLockTables(a, b) + return CloneRefOfLiteral(in) case *MatchExpr: - b, ok := inB.(*MatchExpr) - if !ok { - return false - } - return EqualsRefOfMatchExpr(a, b) - case *ModifyColumn: - b, ok := inB.(*ModifyColumn) - if !ok { - return false - } - return EqualsRefOfModifyColumn(a, b) - case *Nextval: - b, ok := inB.(*Nextval) - if !ok { - return false - } - return EqualsRefOfNextval(a, b) + return CloneRefOfMatchExpr(in) case *NotExpr: - b, ok := inB.(*NotExpr) - if !ok { - return false - } - return EqualsRefOfNotExpr(a, b) + return CloneRefOfNotExpr(in) case *NullVal: - b, ok := inB.(*NullVal) - if !ok { - return false - } - return EqualsRefOfNullVal(a, b) - case OnDup: - b, ok := inB.(OnDup) - if !ok { - return false - } - return EqualsOnDup(a, b) - case *OptLike: - b, ok := inB.(*OptLike) - if !ok { - return false - } - return EqualsRefOfOptLike(a, b) + return CloneRefOfNullVal(in) case *OrExpr: - b, ok := inB.(*OrExpr) - if !ok { - return false - } - return EqualsRefOfOrExpr(a, b) - case *Order: - b, ok := inB.(*Order) - if !ok { - return false - } - return EqualsRefOfOrder(a, b) - case OrderBy: - b, ok := inB.(OrderBy) - if !ok { - return false - } - return EqualsOrderBy(a, b) - case *OrderByOption: - b, ok := inB.(*OrderByOption) - if !ok { - return false - } - return EqualsRefOfOrderByOption(a, b) - case *OtherAdmin: - b, ok := inB.(*OtherAdmin) - if !ok { - return false - } - return EqualsRefOfOtherAdmin(a, b) - case *OtherRead: - b, ok := inB.(*OtherRead) - if !ok { - return false - } - return EqualsRefOfOtherRead(a, b) - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false - } - return EqualsRefOfParenSelect(a, b) - case *ParenTableExpr: - b, ok := inB.(*ParenTableExpr) - if !ok { - return false - } - return EqualsRefOfParenTableExpr(a, b) - case *PartitionDefinition: - b, ok := inB.(*PartitionDefinition) - if !ok { - return false - } - return EqualsRefOfPartitionDefinition(a, b) - case *PartitionSpec: - b, ok := inB.(*PartitionSpec) - if !ok { - return false - } - return EqualsRefOfPartitionSpec(a, b) - case Partitions: - b, ok := inB.(Partitions) - if !ok { - return false - } - return EqualsPartitions(a, b) + return CloneRefOfOrExpr(in) case *RangeCond: - b, ok := inB.(*RangeCond) - if !ok { - return false - } - return EqualsRefOfRangeCond(a, b) - case ReferenceAction: - b, ok := inB.(ReferenceAction) - if !ok { - return false - } - return a == b - case *Release: - b, ok := inB.(*Release) - if !ok { - return false - } - return EqualsRefOfRelease(a, b) - case *RenameIndex: - b, ok := inB.(*RenameIndex) - if !ok { - return false - } - return EqualsRefOfRenameIndex(a, b) - case *RenameTable: - b, ok := inB.(*RenameTable) - if !ok { - return false - } - return EqualsRefOfRenameTable(a, b) - case *RenameTableName: - b, ok := inB.(*RenameTableName) - if !ok { - return false - } - return EqualsRefOfRenameTableName(a, b) - case *RevertMigration: - b, ok := inB.(*RevertMigration) - if !ok { - return false - } - return EqualsRefOfRevertMigration(a, b) - case *Rollback: - b, ok := inB.(*Rollback) - if !ok { - return false - } - return EqualsRefOfRollback(a, b) - case *SRollback: - b, ok := inB.(*SRollback) - if !ok { - return false - } - return EqualsRefOfSRollback(a, b) - case *Savepoint: - b, ok := inB.(*Savepoint) - if !ok { - return false - } - return EqualsRefOfSavepoint(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false - } - return EqualsRefOfSelect(a, b) - case SelectExprs: - b, ok := inB.(SelectExprs) - if !ok { - return false - } - return EqualsSelectExprs(a, b) - case *SelectInto: - b, ok := inB.(*SelectInto) - if !ok { - return false - } - return EqualsRefOfSelectInto(a, b) - case *Set: - b, ok := inB.(*Set) - if !ok { - return false - } - return EqualsRefOfSet(a, b) - case *SetExpr: - b, ok := inB.(*SetExpr) - if !ok { - return false - } - return EqualsRefOfSetExpr(a, b) - case SetExprs: - b, ok := inB.(SetExprs) - if !ok { - return false - } - return EqualsSetExprs(a, b) - case *SetTransaction: - b, ok := inB.(*SetTransaction) - if !ok { - return false - } - return EqualsRefOfSetTransaction(a, b) - case *Show: - b, ok := inB.(*Show) - if !ok { - return false - } - return EqualsRefOfShow(a, b) - case *ShowBasic: - b, ok := inB.(*ShowBasic) - if !ok { - return false - } - return EqualsRefOfShowBasic(a, b) - case *ShowCreate: - b, ok := inB.(*ShowCreate) - if !ok { - return false - } - return EqualsRefOfShowCreate(a, b) - case *ShowFilter: - b, ok := inB.(*ShowFilter) - if !ok { - return false - } - return EqualsRefOfShowFilter(a, b) - case *ShowLegacy: - b, ok := inB.(*ShowLegacy) - if !ok { - return false - } - return EqualsRefOfShowLegacy(a, b) - case *StarExpr: - b, ok := inB.(*StarExpr) - if !ok { - return false - } - return EqualsRefOfStarExpr(a, b) - case *Stream: - b, ok := inB.(*Stream) - if !ok { - return false - } - return EqualsRefOfStream(a, b) + return CloneRefOfRangeCond(in) case *Subquery: - b, ok := inB.(*Subquery) - if !ok { - return false - } - return EqualsRefOfSubquery(a, b) + return CloneRefOfSubquery(in) case *SubstrExpr: - b, ok := inB.(*SubstrExpr) - if !ok { - return false - } - return EqualsRefOfSubstrExpr(a, b) - case TableExprs: - b, ok := inB.(TableExprs) - if !ok { - return false - } - return EqualsTableExprs(a, b) - case TableIdent: - b, ok := inB.(TableIdent) - if !ok { - return false - } - return EqualsTableIdent(a, b) - case TableName: - b, ok := inB.(TableName) - if !ok { - return false - } - return EqualsTableName(a, b) - case TableNames: - b, ok := inB.(TableNames) - if !ok { - return false - } - return EqualsTableNames(a, b) - case TableOptions: - b, ok := inB.(TableOptions) - if !ok { - return false - } - return EqualsTableOptions(a, b) - case *TableSpec: - b, ok := inB.(*TableSpec) - if !ok { - return false - } - return EqualsRefOfTableSpec(a, b) - case *TablespaceOperation: - b, ok := inB.(*TablespaceOperation) - if !ok { - return false - } - return EqualsRefOfTablespaceOperation(a, b) + return CloneRefOfSubstrExpr(in) case *TimestampFuncExpr: - b, ok := inB.(*TimestampFuncExpr) - if !ok { - return false - } - return EqualsRefOfTimestampFuncExpr(a, b) - case *TruncateTable: - b, ok := inB.(*TruncateTable) - if !ok { - return false - } - return EqualsRefOfTruncateTable(a, b) + return CloneRefOfTimestampFuncExpr(in) case *UnaryExpr: - b, ok := inB.(*UnaryExpr) - if !ok { - return false - } - return EqualsRefOfUnaryExpr(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false - } - return EqualsRefOfUnion(a, b) - case *UnionSelect: - b, ok := inB.(*UnionSelect) - if !ok { - return false - } - return EqualsRefOfUnionSelect(a, b) - case *UnlockTables: - b, ok := inB.(*UnlockTables) - if !ok { - return false - } - return EqualsRefOfUnlockTables(a, b) - case *Update: - b, ok := inB.(*Update) - if !ok { - return false - } - return EqualsRefOfUpdate(a, b) - case *UpdateExpr: - b, ok := inB.(*UpdateExpr) - if !ok { - return false - } - return EqualsRefOfUpdateExpr(a, b) - case UpdateExprs: - b, ok := inB.(UpdateExprs) - if !ok { - return false - } - return EqualsUpdateExprs(a, b) - case *Use: - b, ok := inB.(*Use) - if !ok { - return false - } - return EqualsRefOfUse(a, b) - case *VStream: - b, ok := inB.(*VStream) - if !ok { - return false - } - return EqualsRefOfVStream(a, b) + return CloneRefOfUnaryExpr(in) case ValTuple: - b, ok := inB.(ValTuple) - if !ok { - return false - } - return EqualsValTuple(a, b) - case *Validation: - b, ok := inB.(*Validation) - if !ok { - return false - } - return EqualsRefOfValidation(a, b) - case Values: - b, ok := inB.(Values) - if !ok { - return false - } - return EqualsValues(a, b) + return CloneValTuple(in) case *ValuesFuncExpr: - b, ok := inB.(*ValuesFuncExpr) - if !ok { - return false - } - return EqualsRefOfValuesFuncExpr(a, b) - case VindexParam: - b, ok := inB.(VindexParam) - if !ok { - return false - } - return EqualsVindexParam(a, b) - case *VindexSpec: - b, ok := inB.(*VindexSpec) - if !ok { - return false - } - return EqualsRefOfVindexSpec(a, b) - case *When: - b, ok := inB.(*When) - if !ok { - return false - } - return EqualsRefOfWhen(a, b) - case *Where: - b, ok := inB.(*Where) - if !ok { - return false - } - return EqualsRefOfWhere(a, b) + return CloneRefOfValuesFuncExpr(in) case *XorExpr: - b, ok := inB.(*XorExpr) - if !ok { - return false - } - return EqualsRefOfXorExpr(a, b) + return CloneRefOfXorExpr(in) default: // this should never happen - return false + return nil } } -// CloneSQLNode creates a deep clone of the input. -func CloneSQLNode(in SQLNode) SQLNode { +// CloneExprs creates a deep clone of the input. +func CloneExprs(n Exprs) Exprs { + res := make(Exprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneGroupBy creates a deep clone of the input. +func CloneGroupBy(n GroupBy) GroupBy { + res := make(GroupBy, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneInsertRows creates a deep clone of the input. +func CloneInsertRows(in InsertRows) InsertRows { if in == nil { return nil } switch in := in.(type) { - case AccessMode: - return in - case *AddColumns: - return CloneRefOfAddColumns(in) - case *AddConstraintDefinition: - return CloneRefOfAddConstraintDefinition(in) - case *AddIndexDefinition: - return CloneRefOfAddIndexDefinition(in) - case AlgorithmValue: - return in - case *AliasedExpr: - return CloneRefOfAliasedExpr(in) - case *AliasedTableExpr: - return CloneRefOfAliasedTableExpr(in) - case *AlterCharset: - return CloneRefOfAlterCharset(in) - case *AlterColumn: - return CloneRefOfAlterColumn(in) - case *AlterDatabase: - return CloneRefOfAlterDatabase(in) - case *AlterMigration: - return CloneRefOfAlterMigration(in) - case *AlterTable: - return CloneRefOfAlterTable(in) - case *AlterView: - return CloneRefOfAlterView(in) - case *AlterVschema: - return CloneRefOfAlterVschema(in) - case *AndExpr: - return CloneRefOfAndExpr(in) - case Argument: - return in - case *AutoIncSpec: - return CloneRefOfAutoIncSpec(in) - case *Begin: - return CloneRefOfBegin(in) - case *BinaryExpr: - return CloneRefOfBinaryExpr(in) - case BoolVal: - return in - case *CallProc: - return CloneRefOfCallProc(in) - case *CaseExpr: - return CloneRefOfCaseExpr(in) - case *ChangeColumn: - return CloneRefOfChangeColumn(in) - case *CheckConstraintDefinition: - return CloneRefOfCheckConstraintDefinition(in) - case ColIdent: - return CloneColIdent(in) - case *ColName: - return CloneRefOfColName(in) - case *CollateExpr: - return CloneRefOfCollateExpr(in) - case *ColumnDefinition: - return CloneRefOfColumnDefinition(in) - case *ColumnType: - return CloneRefOfColumnType(in) - case Columns: - return CloneColumns(in) - case Comments: - return CloneComments(in) - case *Commit: - return CloneRefOfCommit(in) - case *ComparisonExpr: - return CloneRefOfComparisonExpr(in) - case *ConstraintDefinition: - return CloneRefOfConstraintDefinition(in) - case *ConvertExpr: - return CloneRefOfConvertExpr(in) - case *ConvertType: - return CloneRefOfConvertType(in) - case *ConvertUsingExpr: - return CloneRefOfConvertUsingExpr(in) - case *CreateDatabase: - return CloneRefOfCreateDatabase(in) - case *CreateTable: - return CloneRefOfCreateTable(in) - case *CreateView: - return CloneRefOfCreateView(in) - case *CurTimeFuncExpr: - return CloneRefOfCurTimeFuncExpr(in) - case *Default: - return CloneRefOfDefault(in) - case *Delete: - return CloneRefOfDelete(in) - case *DerivedTable: - return CloneRefOfDerivedTable(in) - case *DropColumn: - return CloneRefOfDropColumn(in) - case *DropDatabase: - return CloneRefOfDropDatabase(in) - case *DropKey: - return CloneRefOfDropKey(in) - case *DropTable: - return CloneRefOfDropTable(in) - case *DropView: - return CloneRefOfDropView(in) - case *ExistsExpr: - return CloneRefOfExistsExpr(in) - case *ExplainStmt: - return CloneRefOfExplainStmt(in) - case *ExplainTab: - return CloneRefOfExplainTab(in) - case Exprs: - return CloneExprs(in) - case *Flush: - return CloneRefOfFlush(in) - case *Force: - return CloneRefOfForce(in) - case *ForeignKeyDefinition: - return CloneRefOfForeignKeyDefinition(in) - case *FuncExpr: - return CloneRefOfFuncExpr(in) - case GroupBy: - return CloneGroupBy(in) - case *GroupConcatExpr: - return CloneRefOfGroupConcatExpr(in) - case *IndexDefinition: - return CloneRefOfIndexDefinition(in) - case *IndexHints: - return CloneRefOfIndexHints(in) - case *IndexInfo: - return CloneRefOfIndexInfo(in) - case *Insert: - return CloneRefOfInsert(in) - case *IntervalExpr: - return CloneRefOfIntervalExpr(in) - case *IsExpr: - return CloneRefOfIsExpr(in) - case IsolationLevel: - return in - case JoinCondition: - return CloneJoinCondition(in) - case *JoinTableExpr: - return CloneRefOfJoinTableExpr(in) - case *KeyState: - return CloneRefOfKeyState(in) - case *Limit: - return CloneRefOfLimit(in) - case ListArg: - return CloneListArg(in) - case *Literal: - return CloneRefOfLiteral(in) - case *Load: - return CloneRefOfLoad(in) - case *LockOption: - return CloneRefOfLockOption(in) - case *LockTables: - return CloneRefOfLockTables(in) - case *MatchExpr: - return CloneRefOfMatchExpr(in) - case *ModifyColumn: - return CloneRefOfModifyColumn(in) - case *Nextval: - return CloneRefOfNextval(in) - case *NotExpr: - return CloneRefOfNotExpr(in) - case *NullVal: - return CloneRefOfNullVal(in) - case OnDup: - return CloneOnDup(in) - case *OptLike: - return CloneRefOfOptLike(in) - case *OrExpr: - return CloneRefOfOrExpr(in) - case *Order: - return CloneRefOfOrder(in) - case OrderBy: - return CloneOrderBy(in) - case *OrderByOption: - return CloneRefOfOrderByOption(in) - case *OtherAdmin: - return CloneRefOfOtherAdmin(in) - case *OtherRead: - return CloneRefOfOtherRead(in) case *ParenSelect: return CloneRefOfParenSelect(in) - case *ParenTableExpr: - return CloneRefOfParenTableExpr(in) - case *PartitionDefinition: - return CloneRefOfPartitionDefinition(in) - case *PartitionSpec: - return CloneRefOfPartitionSpec(in) - case Partitions: - return ClonePartitions(in) - case *RangeCond: - return CloneRefOfRangeCond(in) - case ReferenceAction: - return in - case *Release: - return CloneRefOfRelease(in) - case *RenameIndex: - return CloneRefOfRenameIndex(in) - case *RenameTable: - return CloneRefOfRenameTable(in) - case *RenameTableName: - return CloneRefOfRenameTableName(in) - case *RevertMigration: - return CloneRefOfRevertMigration(in) - case *Rollback: - return CloneRefOfRollback(in) - case *SRollback: - return CloneRefOfSRollback(in) - case *Savepoint: - return CloneRefOfSavepoint(in) case *Select: return CloneRefOfSelect(in) - case SelectExprs: - return CloneSelectExprs(in) - case *SelectInto: - return CloneRefOfSelectInto(in) - case *Set: - return CloneRefOfSet(in) - case *SetExpr: - return CloneRefOfSetExpr(in) - case SetExprs: - return CloneSetExprs(in) - case *SetTransaction: - return CloneRefOfSetTransaction(in) - case *Show: - return CloneRefOfShow(in) - case *ShowBasic: - return CloneRefOfShowBasic(in) - case *ShowCreate: - return CloneRefOfShowCreate(in) - case *ShowFilter: - return CloneRefOfShowFilter(in) - case *ShowLegacy: - return CloneRefOfShowLegacy(in) - case *StarExpr: - return CloneRefOfStarExpr(in) - case *Stream: - return CloneRefOfStream(in) - case *Subquery: - return CloneRefOfSubquery(in) - case *SubstrExpr: - return CloneRefOfSubstrExpr(in) - case TableExprs: - return CloneTableExprs(in) - case TableIdent: - return CloneTableIdent(in) - case TableName: - return CloneTableName(in) - case TableNames: - return CloneTableNames(in) - case TableOptions: - return CloneTableOptions(in) - case *TableSpec: - return CloneRefOfTableSpec(in) - case *TablespaceOperation: - return CloneRefOfTablespaceOperation(in) - case *TimestampFuncExpr: - return CloneRefOfTimestampFuncExpr(in) - case *TruncateTable: - return CloneRefOfTruncateTable(in) - case *UnaryExpr: - return CloneRefOfUnaryExpr(in) case *Union: return CloneRefOfUnion(in) - case *UnionSelect: - return CloneRefOfUnionSelect(in) - case *UnlockTables: - return CloneRefOfUnlockTables(in) - case *Update: - return CloneRefOfUpdate(in) - case *UpdateExpr: - return CloneRefOfUpdateExpr(in) - case UpdateExprs: - return CloneUpdateExprs(in) - case *Use: - return CloneRefOfUse(in) - case *VStream: - return CloneRefOfVStream(in) - case ValTuple: - return CloneValTuple(in) - case *Validation: - return CloneRefOfValidation(in) case Values: return CloneValues(in) - case *ValuesFuncExpr: - return CloneRefOfValuesFuncExpr(in) - case VindexParam: - return CloneVindexParam(in) - case *VindexSpec: - return CloneRefOfVindexSpec(in) - case *When: - return CloneRefOfWhen(in) - case *Where: - return CloneRefOfWhere(in) - case *XorExpr: - return CloneRefOfXorExpr(in) default: // this should never happen return nil } } -// VisitSQLNode will visit all parts of the AST -func VisitSQLNode(in SQLNode, f Visit) error { - if in == nil { +// CloneJoinCondition creates a deep clone of the input. +func CloneJoinCondition(n JoinCondition) JoinCondition { + return *CloneRefOfJoinCondition(&n) +} + +// CloneListArg creates a deep clone of the input. +func CloneListArg(n ListArg) ListArg { + res := make(ListArg, 0, len(n)) + copy(res, n) + return res +} + +// CloneOnDup creates a deep clone of the input. +func CloneOnDup(n OnDup) OnDup { + res := make(OnDup, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUpdateExpr(x)) + } + return res +} + +// CloneOrderBy creates a deep clone of the input. +func CloneOrderBy(n OrderBy) OrderBy { + res := make(OrderBy, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfOrder(x)) + } + return res +} + +// ClonePartitions creates a deep clone of the input. +func ClonePartitions(n Partitions) Partitions { + res := make(Partitions, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneRefOfAddColumns creates a deep clone of the input. +func CloneRefOfAddColumns(n *AddColumns) *AddColumns { + if n == nil { return nil } - switch in := in.(type) { - case AccessMode: - return VisitAccessMode(in, f) - case *AddColumns: - return VisitRefOfAddColumns(in, f) - case *AddConstraintDefinition: - return VisitRefOfAddConstraintDefinition(in, f) - case *AddIndexDefinition: - return VisitRefOfAddIndexDefinition(in, f) - case AlgorithmValue: - return VisitAlgorithmValue(in, f) - case *AliasedExpr: - return VisitRefOfAliasedExpr(in, f) - case *AliasedTableExpr: - return VisitRefOfAliasedTableExpr(in, f) - case *AlterCharset: - return VisitRefOfAlterCharset(in, f) - case *AlterColumn: - return VisitRefOfAlterColumn(in, f) - case *AlterDatabase: - return VisitRefOfAlterDatabase(in, f) - case *AlterMigration: - return VisitRefOfAlterMigration(in, f) - case *AlterTable: - return VisitRefOfAlterTable(in, f) - case *AlterView: - return VisitRefOfAlterView(in, f) - case *AlterVschema: - return VisitRefOfAlterVschema(in, f) - case *AndExpr: - return VisitRefOfAndExpr(in, f) - case Argument: - return VisitArgument(in, f) - case *AutoIncSpec: - return VisitRefOfAutoIncSpec(in, f) - case *Begin: - return VisitRefOfBegin(in, f) - case *BinaryExpr: - return VisitRefOfBinaryExpr(in, f) - case BoolVal: - return VisitBoolVal(in, f) - case *CallProc: - return VisitRefOfCallProc(in, f) - case *CaseExpr: - return VisitRefOfCaseExpr(in, f) - case *ChangeColumn: - return VisitRefOfChangeColumn(in, f) - case *CheckConstraintDefinition: - return VisitRefOfCheckConstraintDefinition(in, f) - case ColIdent: - return VisitColIdent(in, f) - case *ColName: - return VisitRefOfColName(in, f) - case *CollateExpr: - return VisitRefOfCollateExpr(in, f) - case *ColumnDefinition: - return VisitRefOfColumnDefinition(in, f) - case *ColumnType: - return VisitRefOfColumnType(in, f) - case Columns: - return VisitColumns(in, f) - case Comments: - return VisitComments(in, f) - case *Commit: - return VisitRefOfCommit(in, f) - case *ComparisonExpr: - return VisitRefOfComparisonExpr(in, f) - case *ConstraintDefinition: - return VisitRefOfConstraintDefinition(in, f) - case *ConvertExpr: - return VisitRefOfConvertExpr(in, f) - case *ConvertType: - return VisitRefOfConvertType(in, f) - case *ConvertUsingExpr: - return VisitRefOfConvertUsingExpr(in, f) - case *CreateDatabase: - return VisitRefOfCreateDatabase(in, f) - case *CreateTable: - return VisitRefOfCreateTable(in, f) - case *CreateView: - return VisitRefOfCreateView(in, f) - case *CurTimeFuncExpr: - return VisitRefOfCurTimeFuncExpr(in, f) - case *Default: - return VisitRefOfDefault(in, f) - case *Delete: - return VisitRefOfDelete(in, f) - case *DerivedTable: - return VisitRefOfDerivedTable(in, f) - case *DropColumn: - return VisitRefOfDropColumn(in, f) - case *DropDatabase: - return VisitRefOfDropDatabase(in, f) - case *DropKey: - return VisitRefOfDropKey(in, f) - case *DropTable: - return VisitRefOfDropTable(in, f) - case *DropView: - return VisitRefOfDropView(in, f) - case *ExistsExpr: - return VisitRefOfExistsExpr(in, f) - case *ExplainStmt: - return VisitRefOfExplainStmt(in, f) - case *ExplainTab: - return VisitRefOfExplainTab(in, f) - case Exprs: - return VisitExprs(in, f) - case *Flush: - return VisitRefOfFlush(in, f) - case *Force: - return VisitRefOfForce(in, f) - case *ForeignKeyDefinition: - return VisitRefOfForeignKeyDefinition(in, f) - case *FuncExpr: - return VisitRefOfFuncExpr(in, f) - case GroupBy: - return VisitGroupBy(in, f) - case *GroupConcatExpr: - return VisitRefOfGroupConcatExpr(in, f) - case *IndexDefinition: - return VisitRefOfIndexDefinition(in, f) - case *IndexHints: - return VisitRefOfIndexHints(in, f) - case *IndexInfo: - return VisitRefOfIndexInfo(in, f) - case *Insert: - return VisitRefOfInsert(in, f) - case *IntervalExpr: - return VisitRefOfIntervalExpr(in, f) - case *IsExpr: - return VisitRefOfIsExpr(in, f) - case IsolationLevel: - return VisitIsolationLevel(in, f) - case JoinCondition: - return VisitJoinCondition(in, f) - case *JoinTableExpr: - return VisitRefOfJoinTableExpr(in, f) - case *KeyState: - return VisitRefOfKeyState(in, f) - case *Limit: - return VisitRefOfLimit(in, f) - case ListArg: - return VisitListArg(in, f) - case *Literal: - return VisitRefOfLiteral(in, f) - case *Load: - return VisitRefOfLoad(in, f) - case *LockOption: - return VisitRefOfLockOption(in, f) - case *LockTables: - return VisitRefOfLockTables(in, f) - case *MatchExpr: - return VisitRefOfMatchExpr(in, f) - case *ModifyColumn: - return VisitRefOfModifyColumn(in, f) - case *Nextval: - return VisitRefOfNextval(in, f) - case *NotExpr: - return VisitRefOfNotExpr(in, f) - case *NullVal: - return VisitRefOfNullVal(in, f) - case OnDup: - return VisitOnDup(in, f) - case *OptLike: - return VisitRefOfOptLike(in, f) - case *OrExpr: - return VisitRefOfOrExpr(in, f) - case *Order: - return VisitRefOfOrder(in, f) - case OrderBy: - return VisitOrderBy(in, f) - case *OrderByOption: - return VisitRefOfOrderByOption(in, f) - case *OtherAdmin: - return VisitRefOfOtherAdmin(in, f) - case *OtherRead: - return VisitRefOfOtherRead(in, f) - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *ParenTableExpr: - return VisitRefOfParenTableExpr(in, f) - case *PartitionDefinition: - return VisitRefOfPartitionDefinition(in, f) - case *PartitionSpec: - return VisitRefOfPartitionSpec(in, f) - case Partitions: - return VisitPartitions(in, f) - case *RangeCond: - return VisitRefOfRangeCond(in, f) - case ReferenceAction: - return VisitReferenceAction(in, f) - case *Release: - return VisitRefOfRelease(in, f) - case *RenameIndex: - return VisitRefOfRenameIndex(in, f) - case *RenameTable: - return VisitRefOfRenameTable(in, f) - case *RenameTableName: - return VisitRefOfRenameTableName(in, f) - case *RevertMigration: - return VisitRefOfRevertMigration(in, f) - case *Rollback: - return VisitRefOfRollback(in, f) - case *SRollback: - return VisitRefOfSRollback(in, f) - case *Savepoint: - return VisitRefOfSavepoint(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case SelectExprs: - return VisitSelectExprs(in, f) - case *SelectInto: - return VisitRefOfSelectInto(in, f) - case *Set: - return VisitRefOfSet(in, f) - case *SetExpr: - return VisitRefOfSetExpr(in, f) - case SetExprs: - return VisitSetExprs(in, f) - case *SetTransaction: - return VisitRefOfSetTransaction(in, f) - case *Show: - return VisitRefOfShow(in, f) - case *ShowBasic: - return VisitRefOfShowBasic(in, f) - case *ShowCreate: - return VisitRefOfShowCreate(in, f) - case *ShowFilter: - return VisitRefOfShowFilter(in, f) - case *ShowLegacy: - return VisitRefOfShowLegacy(in, f) - case *StarExpr: - return VisitRefOfStarExpr(in, f) - case *Stream: - return VisitRefOfStream(in, f) - case *Subquery: - return VisitRefOfSubquery(in, f) - case *SubstrExpr: - return VisitRefOfSubstrExpr(in, f) - case TableExprs: - return VisitTableExprs(in, f) - case TableIdent: - return VisitTableIdent(in, f) - case TableName: - return VisitTableName(in, f) - case TableNames: - return VisitTableNames(in, f) - case TableOptions: - return VisitTableOptions(in, f) - case *TableSpec: - return VisitRefOfTableSpec(in, f) - case *TablespaceOperation: - return VisitRefOfTablespaceOperation(in, f) - case *TimestampFuncExpr: - return VisitRefOfTimestampFuncExpr(in, f) - case *TruncateTable: - return VisitRefOfTruncateTable(in, f) - case *UnaryExpr: - return VisitRefOfUnaryExpr(in, f) - case *Union: - return VisitRefOfUnion(in, f) - case *UnionSelect: - return VisitRefOfUnionSelect(in, f) - case *UnlockTables: - return VisitRefOfUnlockTables(in, f) - case *Update: - return VisitRefOfUpdate(in, f) - case *UpdateExpr: - return VisitRefOfUpdateExpr(in, f) - case UpdateExprs: - return VisitUpdateExprs(in, f) - case *Use: - return VisitRefOfUse(in, f) - case *VStream: - return VisitRefOfVStream(in, f) - case ValTuple: - return VisitValTuple(in, f) - case *Validation: - return VisitRefOfValidation(in, f) - case Values: - return VisitValues(in, f) - case *ValuesFuncExpr: - return VisitRefOfValuesFuncExpr(in, f) - case VindexParam: - return VisitVindexParam(in, f) - case *VindexSpec: - return VisitRefOfVindexSpec(in, f) - case *When: - return VisitRefOfWhen(in, f) - case *Where: - return VisitRefOfWhere(in, f) - case *XorExpr: - return VisitRefOfXorExpr(in, f) - default: - // this should never happen + out := *n + out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out +} + +// CloneRefOfAddConstraintDefinition creates a deep clone of the input. +func CloneRefOfAddConstraintDefinition(n *AddConstraintDefinition) *AddConstraintDefinition { + if n == nil { return nil } + out := *n + out.ConstraintDefinition = CloneRefOfConstraintDefinition(n.ConstraintDefinition) + return &out } -// rewriteSQLNode is part of the Rewrite implementation -func (a *application) rewriteSQLNode(parent SQLNode, node SQLNode, replacer replacerFunc) error { - if node == nil { +// CloneRefOfAddIndexDefinition creates a deep clone of the input. +func CloneRefOfAddIndexDefinition(n *AddIndexDefinition) *AddIndexDefinition { + if n == nil { return nil } - switch node := node.(type) { - case AccessMode: - return a.rewriteAccessMode(parent, node, replacer) - case *AddColumns: - return a.rewriteRefOfAddColumns(parent, node, replacer) - case *AddConstraintDefinition: - return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) - case *AddIndexDefinition: - return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) - case AlgorithmValue: - return a.rewriteAlgorithmValue(parent, node, replacer) - case *AliasedExpr: - return a.rewriteRefOfAliasedExpr(parent, node, replacer) - case *AliasedTableExpr: - return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) - case *AlterCharset: - return a.rewriteRefOfAlterCharset(parent, node, replacer) - case *AlterColumn: - return a.rewriteRefOfAlterColumn(parent, node, replacer) - case *AlterDatabase: - return a.rewriteRefOfAlterDatabase(parent, node, replacer) - case *AlterMigration: - return a.rewriteRefOfAlterMigration(parent, node, replacer) - case *AlterTable: - return a.rewriteRefOfAlterTable(parent, node, replacer) - case *AlterView: - return a.rewriteRefOfAlterView(parent, node, replacer) - case *AlterVschema: - return a.rewriteRefOfAlterVschema(parent, node, replacer) - case *AndExpr: - return a.rewriteRefOfAndExpr(parent, node, replacer) - case Argument: - return a.rewriteArgument(parent, node, replacer) - case *AutoIncSpec: - return a.rewriteRefOfAutoIncSpec(parent, node, replacer) - case *Begin: - return a.rewriteRefOfBegin(parent, node, replacer) - case *BinaryExpr: - return a.rewriteRefOfBinaryExpr(parent, node, replacer) - case BoolVal: - return a.rewriteBoolVal(parent, node, replacer) - case *CallProc: - return a.rewriteRefOfCallProc(parent, node, replacer) - case *CaseExpr: - return a.rewriteRefOfCaseExpr(parent, node, replacer) - case *ChangeColumn: - return a.rewriteRefOfChangeColumn(parent, node, replacer) - case *CheckConstraintDefinition: - return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) - case ColIdent: - return a.rewriteColIdent(parent, node, replacer) - case *ColName: - return a.rewriteRefOfColName(parent, node, replacer) - case *CollateExpr: - return a.rewriteRefOfCollateExpr(parent, node, replacer) - case *ColumnDefinition: - return a.rewriteRefOfColumnDefinition(parent, node, replacer) - case *ColumnType: - return a.rewriteRefOfColumnType(parent, node, replacer) - case Columns: - return a.rewriteColumns(parent, node, replacer) - case Comments: - return a.rewriteComments(parent, node, replacer) - case *Commit: - return a.rewriteRefOfCommit(parent, node, replacer) - case *ComparisonExpr: - return a.rewriteRefOfComparisonExpr(parent, node, replacer) - case *ConstraintDefinition: - return a.rewriteRefOfConstraintDefinition(parent, node, replacer) - case *ConvertExpr: - return a.rewriteRefOfConvertExpr(parent, node, replacer) - case *ConvertType: - return a.rewriteRefOfConvertType(parent, node, replacer) - case *ConvertUsingExpr: - return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) - case *CreateDatabase: - return a.rewriteRefOfCreateDatabase(parent, node, replacer) - case *CreateTable: - return a.rewriteRefOfCreateTable(parent, node, replacer) - case *CreateView: - return a.rewriteRefOfCreateView(parent, node, replacer) - case *CurTimeFuncExpr: - return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) - case *Default: - return a.rewriteRefOfDefault(parent, node, replacer) - case *Delete: - return a.rewriteRefOfDelete(parent, node, replacer) - case *DerivedTable: - return a.rewriteRefOfDerivedTable(parent, node, replacer) - case *DropColumn: - return a.rewriteRefOfDropColumn(parent, node, replacer) - case *DropDatabase: - return a.rewriteRefOfDropDatabase(parent, node, replacer) - case *DropKey: - return a.rewriteRefOfDropKey(parent, node, replacer) - case *DropTable: - return a.rewriteRefOfDropTable(parent, node, replacer) - case *DropView: - return a.rewriteRefOfDropView(parent, node, replacer) - case *ExistsExpr: - return a.rewriteRefOfExistsExpr(parent, node, replacer) - case *ExplainStmt: - return a.rewriteRefOfExplainStmt(parent, node, replacer) - case *ExplainTab: - return a.rewriteRefOfExplainTab(parent, node, replacer) - case Exprs: - return a.rewriteExprs(parent, node, replacer) - case *Flush: - return a.rewriteRefOfFlush(parent, node, replacer) - case *Force: - return a.rewriteRefOfForce(parent, node, replacer) - case *ForeignKeyDefinition: - return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) - case *FuncExpr: - return a.rewriteRefOfFuncExpr(parent, node, replacer) - case GroupBy: - return a.rewriteGroupBy(parent, node, replacer) - case *GroupConcatExpr: - return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) - case *IndexDefinition: - return a.rewriteRefOfIndexDefinition(parent, node, replacer) - case *IndexHints: - return a.rewriteRefOfIndexHints(parent, node, replacer) - case *IndexInfo: - return a.rewriteRefOfIndexInfo(parent, node, replacer) - case *Insert: - return a.rewriteRefOfInsert(parent, node, replacer) - case *IntervalExpr: - return a.rewriteRefOfIntervalExpr(parent, node, replacer) - case *IsExpr: - return a.rewriteRefOfIsExpr(parent, node, replacer) - case IsolationLevel: - return a.rewriteIsolationLevel(parent, node, replacer) - case JoinCondition: - return a.rewriteJoinCondition(parent, node, replacer) - case *JoinTableExpr: - return a.rewriteRefOfJoinTableExpr(parent, node, replacer) - case *KeyState: - return a.rewriteRefOfKeyState(parent, node, replacer) - case *Limit: - return a.rewriteRefOfLimit(parent, node, replacer) - case ListArg: - return a.rewriteListArg(parent, node, replacer) - case *Literal: - return a.rewriteRefOfLiteral(parent, node, replacer) - case *Load: - return a.rewriteRefOfLoad(parent, node, replacer) - case *LockOption: - return a.rewriteRefOfLockOption(parent, node, replacer) - case *LockTables: - return a.rewriteRefOfLockTables(parent, node, replacer) - case *MatchExpr: - return a.rewriteRefOfMatchExpr(parent, node, replacer) - case *ModifyColumn: - return a.rewriteRefOfModifyColumn(parent, node, replacer) - case *Nextval: - return a.rewriteRefOfNextval(parent, node, replacer) - case *NotExpr: - return a.rewriteRefOfNotExpr(parent, node, replacer) - case *NullVal: - return a.rewriteRefOfNullVal(parent, node, replacer) - case OnDup: - return a.rewriteOnDup(parent, node, replacer) - case *OptLike: - return a.rewriteRefOfOptLike(parent, node, replacer) - case *OrExpr: - return a.rewriteRefOfOrExpr(parent, node, replacer) - case *Order: - return a.rewriteRefOfOrder(parent, node, replacer) - case OrderBy: - return a.rewriteOrderBy(parent, node, replacer) - case *OrderByOption: - return a.rewriteRefOfOrderByOption(parent, node, replacer) - case *OtherAdmin: - return a.rewriteRefOfOtherAdmin(parent, node, replacer) - case *OtherRead: - return a.rewriteRefOfOtherRead(parent, node, replacer) - case *ParenSelect: - return a.rewriteRefOfParenSelect(parent, node, replacer) - case *ParenTableExpr: - return a.rewriteRefOfParenTableExpr(parent, node, replacer) - case *PartitionDefinition: - return a.rewriteRefOfPartitionDefinition(parent, node, replacer) - case *PartitionSpec: - return a.rewriteRefOfPartitionSpec(parent, node, replacer) - case Partitions: - return a.rewritePartitions(parent, node, replacer) - case *RangeCond: - return a.rewriteRefOfRangeCond(parent, node, replacer) - case ReferenceAction: - return a.rewriteReferenceAction(parent, node, replacer) - case *Release: - return a.rewriteRefOfRelease(parent, node, replacer) - case *RenameIndex: - return a.rewriteRefOfRenameIndex(parent, node, replacer) - case *RenameTable: - return a.rewriteRefOfRenameTable(parent, node, replacer) - case *RenameTableName: - return a.rewriteRefOfRenameTableName(parent, node, replacer) - case *RevertMigration: - return a.rewriteRefOfRevertMigration(parent, node, replacer) - case *Rollback: - return a.rewriteRefOfRollback(parent, node, replacer) - case *SRollback: - return a.rewriteRefOfSRollback(parent, node, replacer) - case *Savepoint: - return a.rewriteRefOfSavepoint(parent, node, replacer) - case *Select: - return a.rewriteRefOfSelect(parent, node, replacer) - case SelectExprs: - return a.rewriteSelectExprs(parent, node, replacer) - case *SelectInto: - return a.rewriteRefOfSelectInto(parent, node, replacer) - case *Set: - return a.rewriteRefOfSet(parent, node, replacer) - case *SetExpr: - return a.rewriteRefOfSetExpr(parent, node, replacer) - case SetExprs: - return a.rewriteSetExprs(parent, node, replacer) - case *SetTransaction: - return a.rewriteRefOfSetTransaction(parent, node, replacer) - case *Show: - return a.rewriteRefOfShow(parent, node, replacer) - case *ShowBasic: - return a.rewriteRefOfShowBasic(parent, node, replacer) - case *ShowCreate: - return a.rewriteRefOfShowCreate(parent, node, replacer) - case *ShowFilter: - return a.rewriteRefOfShowFilter(parent, node, replacer) - case *ShowLegacy: - return a.rewriteRefOfShowLegacy(parent, node, replacer) - case *StarExpr: - return a.rewriteRefOfStarExpr(parent, node, replacer) - case *Stream: - return a.rewriteRefOfStream(parent, node, replacer) - case *Subquery: - return a.rewriteRefOfSubquery(parent, node, replacer) - case *SubstrExpr: - return a.rewriteRefOfSubstrExpr(parent, node, replacer) - case TableExprs: - return a.rewriteTableExprs(parent, node, replacer) - case TableIdent: - return a.rewriteTableIdent(parent, node, replacer) - case TableName: - return a.rewriteTableName(parent, node, replacer) - case TableNames: - return a.rewriteTableNames(parent, node, replacer) - case TableOptions: - return a.rewriteTableOptions(parent, node, replacer) - case *TableSpec: - return a.rewriteRefOfTableSpec(parent, node, replacer) - case *TablespaceOperation: - return a.rewriteRefOfTablespaceOperation(parent, node, replacer) - case *TimestampFuncExpr: - return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) - case *TruncateTable: - return a.rewriteRefOfTruncateTable(parent, node, replacer) - case *UnaryExpr: - return a.rewriteRefOfUnaryExpr(parent, node, replacer) - case *Union: - return a.rewriteRefOfUnion(parent, node, replacer) - case *UnionSelect: - return a.rewriteRefOfUnionSelect(parent, node, replacer) - case *UnlockTables: - return a.rewriteRefOfUnlockTables(parent, node, replacer) - case *Update: - return a.rewriteRefOfUpdate(parent, node, replacer) - case *UpdateExpr: - return a.rewriteRefOfUpdateExpr(parent, node, replacer) - case UpdateExprs: - return a.rewriteUpdateExprs(parent, node, replacer) - case *Use: - return a.rewriteRefOfUse(parent, node, replacer) - case *VStream: - return a.rewriteRefOfVStream(parent, node, replacer) - case ValTuple: - return a.rewriteValTuple(parent, node, replacer) - case *Validation: - return a.rewriteRefOfValidation(parent, node, replacer) - case Values: - return a.rewriteValues(parent, node, replacer) - case *ValuesFuncExpr: - return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) - case VindexParam: - return a.rewriteVindexParam(parent, node, replacer) - case *VindexSpec: - return a.rewriteRefOfVindexSpec(parent, node, replacer) - case *When: - return a.rewriteRefOfWhen(parent, node, replacer) - case *Where: - return a.rewriteRefOfWhere(parent, node, replacer) - case *XorExpr: - return a.rewriteRefOfXorExpr(parent, node, replacer) - default: - // this should never happen + out := *n + out.IndexDefinition = CloneRefOfIndexDefinition(n.IndexDefinition) + return &out +} + +// CloneRefOfAliasedExpr creates a deep clone of the input. +func CloneRefOfAliasedExpr(n *AliasedExpr) *AliasedExpr { + if n == nil { return nil } + out := *n + out.Expr = CloneExpr(n.Expr) + out.As = CloneColIdent(n.As) + return &out } -// EqualsRefOfAddColumns does deep equals between the two objects. -func EqualsRefOfAddColumns(a, b *AddColumns) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfAliasedTableExpr creates a deep clone of the input. +func CloneRefOfAliasedTableExpr(n *AliasedTableExpr) *AliasedTableExpr { + if n == nil { + return nil } - return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && - EqualsRefOfColName(a.First, b.First) && - EqualsRefOfColName(a.After, b.After) + out := *n + out.Expr = CloneSimpleTableExpr(n.Expr) + out.Partitions = ClonePartitions(n.Partitions) + out.As = CloneTableIdent(n.As) + out.Hints = CloneRefOfIndexHints(n.Hints) + return &out } -// CloneRefOfAddColumns creates a deep clone of the input. -func CloneRefOfAddColumns(n *AddColumns) *AddColumns { +// CloneRefOfAlterCharset creates a deep clone of the input. +func CloneRefOfAlterCharset(n *AlterCharset) *AlterCharset { if n == nil { return nil } out := *n - out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) - out.First = CloneRefOfColName(n.First) - out.After = CloneRefOfColName(n.After) return &out } -// VisitRefOfAddColumns will visit all parts of the AST -func VisitRefOfAddColumns(in *AddColumns, f Visit) error { - if in == nil { +// CloneRefOfAlterColumn creates a deep clone of the input. +func CloneRefOfAlterColumn(n *AlterColumn) *AlterColumn { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in.Columns { - if err := VisitRefOfColumnDefinition(el, f); err != nil { - return err - } - } - if err := VisitRefOfColName(in.First, f); err != nil { - return err - } - if err := VisitRefOfColName(in.After, f); err != nil { - return err - } - return nil + out := *n + out.Column = CloneRefOfColName(n.Column) + out.DefaultVal = CloneExpr(n.DefaultVal) + return &out } -// rewriteRefOfAddColumns is part of the Rewrite implementation -func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerFunc) error { - if node == nil { +// CloneRefOfAlterDatabase creates a deep clone of the input. +func CloneRefOfAlterDatabase(n *AlterDatabase) *AlterDatabase { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.AlterOptions = CloneSliceOfCollateAndCharset(n.AlterOptions) + return &out +} + +// CloneRefOfAlterMigration creates a deep clone of the input. +func CloneRefOfAlterMigration(n *AlterMigration) *AlterMigration { + if n == nil { return nil } - for i, el := range node.Columns { - if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { - parent.(*AddColumns).Columns[i] = newNode.(*ColumnDefinition) - }); errF != nil { - return errF - } - } - if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { - parent.(*AddColumns).First = newNode.(*ColName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { - parent.(*AddColumns).After = newNode.(*ColName) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + return &out } -// EqualsRefOfAddConstraintDefinition does deep equals between the two objects. -func EqualsRefOfAddConstraintDefinition(a, b *AddConstraintDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfAlterTable creates a deep clone of the input. +func CloneRefOfAlterTable(n *AlterTable) *AlterTable { + if n == nil { + return nil } - return EqualsRefOfConstraintDefinition(a.ConstraintDefinition, b.ConstraintDefinition) + out := *n + out.Table = CloneTableName(n.Table) + out.AlterOptions = CloneSliceOfAlterOption(n.AlterOptions) + out.PartitionSpec = CloneRefOfPartitionSpec(n.PartitionSpec) + return &out } -// CloneRefOfAddConstraintDefinition creates a deep clone of the input. -func CloneRefOfAddConstraintDefinition(n *AddConstraintDefinition) *AddConstraintDefinition { +// CloneRefOfAlterView creates a deep clone of the input. +func CloneRefOfAlterView(n *AlterView) *AlterView { if n == nil { return nil } out := *n - out.ConstraintDefinition = CloneRefOfConstraintDefinition(n.ConstraintDefinition) + out.ViewName = CloneTableName(n.ViewName) + out.Columns = CloneColumns(n.Columns) + out.Select = CloneSelectStatement(n.Select) return &out } -// VisitRefOfAddConstraintDefinition will visit all parts of the AST -func VisitRefOfAddConstraintDefinition(in *AddConstraintDefinition, f Visit) error { - if in == nil { +// CloneRefOfAlterVschema creates a deep clone of the input. +func CloneRefOfAlterVschema(n *AlterVschema) *AlterVschema { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfConstraintDefinition(in.ConstraintDefinition, f); err != nil { - return err - } - return nil + out := *n + out.Table = CloneTableName(n.Table) + out.VindexSpec = CloneRefOfVindexSpec(n.VindexSpec) + out.VindexCols = CloneSliceOfColIdent(n.VindexCols) + out.AutoIncSpec = CloneRefOfAutoIncSpec(n.AutoIncSpec) + return &out } -// rewriteRefOfAddConstraintDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefinition, replacer replacerFunc) error { - if node == nil { +// CloneRefOfAndExpr creates a deep clone of the input. +func CloneRefOfAndExpr(n *AndExpr) *AndExpr { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out +} + +// CloneRefOfAutoIncSpec creates a deep clone of the input. +func CloneRefOfAutoIncSpec(n *AutoIncSpec) *AutoIncSpec { + if n == nil { return nil } - if errF := a.rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { - parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Column = CloneColIdent(n.Column) + out.Sequence = CloneTableName(n.Sequence) + return &out } -// EqualsRefOfAddIndexDefinition does deep equals between the two objects. -func EqualsRefOfAddIndexDefinition(a, b *AddIndexDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfBegin creates a deep clone of the input. +func CloneRefOfBegin(n *Begin) *Begin { + if n == nil { + return nil } - return EqualsRefOfIndexDefinition(a.IndexDefinition, b.IndexDefinition) + out := *n + return &out } -// CloneRefOfAddIndexDefinition creates a deep clone of the input. -func CloneRefOfAddIndexDefinition(n *AddIndexDefinition) *AddIndexDefinition { +// CloneRefOfBinaryExpr creates a deep clone of the input. +func CloneRefOfBinaryExpr(n *BinaryExpr) *BinaryExpr { if n == nil { return nil } out := *n - out.IndexDefinition = CloneRefOfIndexDefinition(n.IndexDefinition) + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) return &out } -// VisitRefOfAddIndexDefinition will visit all parts of the AST -func VisitRefOfAddIndexDefinition(in *AddIndexDefinition, f Visit) error { - if in == nil { +// CloneRefOfBool creates a deep clone of the input. +func CloneRefOfBool(n *bool) *bool { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfIndexDefinition(in.IndexDefinition, f); err != nil { - return err - } - return nil + out := *n + return &out } -// rewriteRefOfAddIndexDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, replacer replacerFunc) error { - if node == nil { +// CloneRefOfCallProc creates a deep clone of the input. +func CloneRefOfCallProc(n *CallProc) *CallProc { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Name = CloneTableName(n.Name) + out.Params = CloneExprs(n.Params) + return &out +} + +// CloneRefOfCaseExpr creates a deep clone of the input. +func CloneRefOfCaseExpr(n *CaseExpr) *CaseExpr { + if n == nil { return nil } - if errF := a.rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { - parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + out.Whens = CloneSliceOfRefOfWhen(n.Whens) + out.Else = CloneExpr(n.Else) + return &out } -// EqualsRefOfAliasedExpr does deep equals between the two objects. -func EqualsRefOfAliasedExpr(a, b *AliasedExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfChangeColumn creates a deep clone of the input. +func CloneRefOfChangeColumn(n *ChangeColumn) *ChangeColumn { + if n == nil { + return nil } - return EqualsExpr(a.Expr, b.Expr) && - EqualsColIdent(a.As, b.As) + out := *n + out.OldColumn = CloneRefOfColName(n.OldColumn) + out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out } -// CloneRefOfAliasedExpr creates a deep clone of the input. -func CloneRefOfAliasedExpr(n *AliasedExpr) *AliasedExpr { +// CloneRefOfCheckConstraintDefinition creates a deep clone of the input. +func CloneRefOfCheckConstraintDefinition(n *CheckConstraintDefinition) *CheckConstraintDefinition { if n == nil { return nil } out := *n out.Expr = CloneExpr(n.Expr) - out.As = CloneColIdent(n.As) return &out } -// VisitRefOfAliasedExpr will visit all parts of the AST -func VisitRefOfAliasedExpr(in *AliasedExpr, f Visit) error { - if in == nil { +// CloneRefOfColIdent creates a deep clone of the input. +func CloneRefOfColIdent(n *ColIdent) *ColIdent { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - if err := VisitColIdent(in.As, f); err != nil { - return err - } - return nil + out := *n + return &out } -// rewriteRefOfAliasedExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replacerFunc) error { - if node == nil { +// CloneRefOfColName creates a deep clone of the input. +func CloneRefOfColName(n *ColName) *ColName { + return n +} + +// CloneRefOfCollateAndCharset creates a deep clone of the input. +func CloneRefOfCollateAndCharset(n *CollateAndCharset) *CollateAndCharset { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + return &out +} + +// CloneRefOfCollateExpr creates a deep clone of the input. +func CloneRefOfCollateExpr(n *CollateExpr) *CollateExpr { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*AliasedExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF - } - if errF := a.rewriteColIdent(node, node.As, func(newNode, parent SQLNode) { - parent.(*AliasedExpr).As = newNode.(ColIdent) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsRefOfAliasedTableExpr does deep equals between the two objects. -func EqualsRefOfAliasedTableExpr(a, b *AliasedTableExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfColumnDefinition creates a deep clone of the input. +func CloneRefOfColumnDefinition(n *ColumnDefinition) *ColumnDefinition { + if n == nil { + return nil } - return EqualsSimpleTableExpr(a.Expr, b.Expr) && - EqualsPartitions(a.Partitions, b.Partitions) && - EqualsTableIdent(a.As, b.As) && - EqualsRefOfIndexHints(a.Hints, b.Hints) + out := *n + out.Name = CloneColIdent(n.Name) + out.Type = CloneColumnType(n.Type) + return &out } -// CloneRefOfAliasedTableExpr creates a deep clone of the input. -func CloneRefOfAliasedTableExpr(n *AliasedTableExpr) *AliasedTableExpr { +// CloneRefOfColumnType creates a deep clone of the input. +func CloneRefOfColumnType(n *ColumnType) *ColumnType { if n == nil { return nil } out := *n - out.Expr = CloneSimpleTableExpr(n.Expr) - out.Partitions = ClonePartitions(n.Partitions) - out.As = CloneTableIdent(n.As) - out.Hints = CloneRefOfIndexHints(n.Hints) + out.Options = CloneRefOfColumnTypeOptions(n.Options) + out.Length = CloneRefOfLiteral(n.Length) + out.Scale = CloneRefOfLiteral(n.Scale) + out.EnumValues = CloneSliceOfString(n.EnumValues) return &out } -// VisitRefOfAliasedTableExpr will visit all parts of the AST -func VisitRefOfAliasedTableExpr(in *AliasedTableExpr, f Visit) error { - if in == nil { +// CloneRefOfColumnTypeOptions creates a deep clone of the input. +func CloneRefOfColumnTypeOptions(n *ColumnTypeOptions) *ColumnTypeOptions { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitSimpleTableExpr(in.Expr, f); err != nil { - return err - } - if err := VisitPartitions(in.Partitions, f); err != nil { - return err - } - if err := VisitTableIdent(in.As, f); err != nil { - return err - } - if err := VisitRefOfIndexHints(in.Hints, f); err != nil { - return err - } - return nil + out := *n + out.Default = CloneExpr(n.Default) + out.OnUpdate = CloneExpr(n.OnUpdate) + out.Comment = CloneRefOfLiteral(n.Comment) + return &out } -// rewriteRefOfAliasedTableExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replacer replacerFunc) error { - if node == nil { +// CloneRefOfCommit creates a deep clone of the input. +func CloneRefOfCommit(n *Commit) *Commit { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + return &out +} + +// CloneRefOfComparisonExpr creates a deep clone of the input. +func CloneRefOfComparisonExpr(n *ComparisonExpr) *ComparisonExpr { + if n == nil { return nil } - if errF := a.rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) - }); errF != nil { - return errF - } - if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) - }); errF != nil { - return errF - } - if errF := a.rewriteTableIdent(node, node.As, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).As = newNode.(TableIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfIndexHints(node, node.Hints, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + out.Escape = CloneExpr(n.Escape) + return &out } -// EqualsRefOfAlterCharset does deep equals between the two objects. -func EqualsRefOfAlterCharset(a, b *AlterCharset) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfConstraintDefinition creates a deep clone of the input. +func CloneRefOfConstraintDefinition(n *ConstraintDefinition) *ConstraintDefinition { + if n == nil { + return nil } - return a.CharacterSet == b.CharacterSet && - a.Collate == b.Collate + out := *n + out.Details = CloneConstraintInfo(n.Details) + return &out } -// CloneRefOfAlterCharset creates a deep clone of the input. -func CloneRefOfAlterCharset(n *AlterCharset) *AlterCharset { +// CloneRefOfConvertExpr creates a deep clone of the input. +func CloneRefOfConvertExpr(n *ConvertExpr) *ConvertExpr { if n == nil { return nil } out := *n + out.Expr = CloneExpr(n.Expr) + out.Type = CloneRefOfConvertType(n.Type) return &out } -// VisitRefOfAlterCharset will visit all parts of the AST -func VisitRefOfAlterCharset(in *AlterCharset, f Visit) error { - if in == nil { +// CloneRefOfConvertType creates a deep clone of the input. +func CloneRefOfConvertType(n *ConvertType) *ConvertType { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + out.Length = CloneRefOfLiteral(n.Length) + out.Scale = CloneRefOfLiteral(n.Scale) + return &out } -// rewriteRefOfAlterCharset is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer replacerFunc) error { - if node == nil { +// CloneRefOfConvertUsingExpr creates a deep clone of the input. +func CloneRefOfConvertUsingExpr(n *ConvertUsingExpr) *ConvertUsingExpr { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfCreateDatabase creates a deep clone of the input. +func CloneRefOfCreateDatabase(n *CreateDatabase) *CreateDatabase { + if n == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Comments = CloneComments(n.Comments) + out.CreateOptions = CloneSliceOfCollateAndCharset(n.CreateOptions) + return &out } -// EqualsRefOfAlterColumn does deep equals between the two objects. -func EqualsRefOfAlterColumn(a, b *AlterColumn) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfCreateTable creates a deep clone of the input. +func CloneRefOfCreateTable(n *CreateTable) *CreateTable { + if n == nil { + return nil } - return a.DropDefault == b.DropDefault && - EqualsRefOfColName(a.Column, b.Column) && - EqualsExpr(a.DefaultVal, b.DefaultVal) + out := *n + out.Table = CloneTableName(n.Table) + out.TableSpec = CloneRefOfTableSpec(n.TableSpec) + out.OptLike = CloneRefOfOptLike(n.OptLike) + return &out } -// CloneRefOfAlterColumn creates a deep clone of the input. -func CloneRefOfAlterColumn(n *AlterColumn) *AlterColumn { +// CloneRefOfCreateView creates a deep clone of the input. +func CloneRefOfCreateView(n *CreateView) *CreateView { if n == nil { return nil } out := *n - out.Column = CloneRefOfColName(n.Column) - out.DefaultVal = CloneExpr(n.DefaultVal) + out.ViewName = CloneTableName(n.ViewName) + out.Columns = CloneColumns(n.Columns) + out.Select = CloneSelectStatement(n.Select) return &out } -// VisitRefOfAlterColumn will visit all parts of the AST -func VisitRefOfAlterColumn(in *AlterColumn, f Visit) error { - if in == nil { +// CloneRefOfCurTimeFuncExpr creates a deep clone of the input. +func CloneRefOfCurTimeFuncExpr(n *CurTimeFuncExpr) *CurTimeFuncExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.Column, f); err != nil { - return err - } - if err := VisitExpr(in.DefaultVal, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.Fsp = CloneExpr(n.Fsp) + return &out } -// rewriteRefOfAlterColumn is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replacerFunc) error { - if node == nil { +// CloneRefOfDefault creates a deep clone of the input. +func CloneRefOfDefault(n *Default) *Default { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + return &out +} + +// CloneRefOfDelete creates a deep clone of the input. +func CloneRefOfDelete(n *Delete) *Delete { + if n == nil { return nil } - if errF := a.rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { - parent.(*AlterColumn).Column = newNode.(*ColName) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.DefaultVal, func(newNode, parent SQLNode) { - parent.(*AlterColumn).DefaultVal = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Comments = CloneComments(n.Comments) + out.Targets = CloneTableNames(n.Targets) + out.TableExprs = CloneTableExprs(n.TableExprs) + out.Partitions = ClonePartitions(n.Partitions) + out.Where = CloneRefOfWhere(n.Where) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out } -// EqualsRefOfAlterDatabase does deep equals between the two objects. -func EqualsRefOfAlterDatabase(a, b *AlterDatabase) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfDerivedTable creates a deep clone of the input. +func CloneRefOfDerivedTable(n *DerivedTable) *DerivedTable { + if n == nil { + return nil } - return a.DBName == b.DBName && - a.UpdateDataDirectory == b.UpdateDataDirectory && - a.FullyParsed == b.FullyParsed && - EqualsSliceOfCollateAndCharset(a.AlterOptions, b.AlterOptions) + out := *n + out.Select = CloneSelectStatement(n.Select) + return &out } -// CloneRefOfAlterDatabase creates a deep clone of the input. -func CloneRefOfAlterDatabase(n *AlterDatabase) *AlterDatabase { +// CloneRefOfDropColumn creates a deep clone of the input. +func CloneRefOfDropColumn(n *DropColumn) *DropColumn { if n == nil { return nil } out := *n - out.AlterOptions = CloneSliceOfCollateAndCharset(n.AlterOptions) + out.Name = CloneRefOfColName(n.Name) return &out } -// VisitRefOfAlterDatabase will visit all parts of the AST -func VisitRefOfAlterDatabase(in *AlterDatabase, f Visit) error { - if in == nil { +// CloneRefOfDropDatabase creates a deep clone of the input. +func CloneRefOfDropDatabase(n *DropDatabase) *DropDatabase { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + out.Comments = CloneComments(n.Comments) + return &out } -// rewriteRefOfAlterDatabase is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer replacerFunc) error { - if node == nil { +// CloneRefOfDropKey creates a deep clone of the input. +func CloneRefOfDropKey(n *DropKey) *DropKey { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + return &out +} + +// CloneRefOfDropTable creates a deep clone of the input. +func CloneRefOfDropTable(n *DropTable) *DropTable { + if n == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.FromTables = CloneTableNames(n.FromTables) + return &out } -// EqualsRefOfAlterMigration does deep equals between the two objects. -func EqualsRefOfAlterMigration(a, b *AlterMigration) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfDropView creates a deep clone of the input. +func CloneRefOfDropView(n *DropView) *DropView { + if n == nil { + return nil } - return a.UUID == b.UUID && - a.Type == b.Type + out := *n + out.FromTables = CloneTableNames(n.FromTables) + return &out } -// CloneRefOfAlterMigration creates a deep clone of the input. -func CloneRefOfAlterMigration(n *AlterMigration) *AlterMigration { +// CloneRefOfExistsExpr creates a deep clone of the input. +func CloneRefOfExistsExpr(n *ExistsExpr) *ExistsExpr { if n == nil { return nil } out := *n + out.Subquery = CloneRefOfSubquery(n.Subquery) return &out } -// VisitRefOfAlterMigration will visit all parts of the AST -func VisitRefOfAlterMigration(in *AlterMigration, f Visit) error { - if in == nil { +// CloneRefOfExplainStmt creates a deep clone of the input. +func CloneRefOfExplainStmt(n *ExplainStmt) *ExplainStmt { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + out.Statement = CloneStatement(n.Statement) + return &out } -// rewriteRefOfAlterMigration is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer replacerFunc) error { - if node == nil { +// CloneRefOfExplainTab creates a deep clone of the input. +func CloneRefOfExplainTab(n *ExplainTab) *ExplainTab { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Table = CloneTableName(n.Table) + return &out +} + +// CloneRefOfFlush creates a deep clone of the input. +func CloneRefOfFlush(n *Flush) *Flush { + if n == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.FlushOptions = CloneSliceOfString(n.FlushOptions) + out.TableNames = CloneTableNames(n.TableNames) + return &out } -// EqualsRefOfAlterTable does deep equals between the two objects. -func EqualsRefOfAlterTable(a, b *AlterTable) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfForce creates a deep clone of the input. +func CloneRefOfForce(n *Force) *Force { + if n == nil { + return nil } - return a.FullyParsed == b.FullyParsed && - EqualsTableName(a.Table, b.Table) && - EqualsSliceOfAlterOption(a.AlterOptions, b.AlterOptions) && - EqualsRefOfPartitionSpec(a.PartitionSpec, b.PartitionSpec) + out := *n + return &out } -// CloneRefOfAlterTable creates a deep clone of the input. -func CloneRefOfAlterTable(n *AlterTable) *AlterTable { +// CloneRefOfForeignKeyDefinition creates a deep clone of the input. +func CloneRefOfForeignKeyDefinition(n *ForeignKeyDefinition) *ForeignKeyDefinition { if n == nil { return nil } out := *n - out.Table = CloneTableName(n.Table) - out.AlterOptions = CloneSliceOfAlterOption(n.AlterOptions) - out.PartitionSpec = CloneRefOfPartitionSpec(n.PartitionSpec) + out.Source = CloneColumns(n.Source) + out.ReferencedTable = CloneTableName(n.ReferencedTable) + out.ReferencedColumns = CloneColumns(n.ReferencedColumns) return &out } -// VisitRefOfAlterTable will visit all parts of the AST -func VisitRefOfAlterTable(in *AlterTable, f Visit) error { - if in == nil { +// CloneRefOfFuncExpr creates a deep clone of the input. +func CloneRefOfFuncExpr(n *FuncExpr) *FuncExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - for _, el := range in.AlterOptions { - if err := VisitAlterOption(el, f); err != nil { - return err - } - } - if err := VisitRefOfPartitionSpec(in.PartitionSpec, f); err != nil { - return err - } - return nil + out := *n + out.Qualifier = CloneTableIdent(n.Qualifier) + out.Name = CloneColIdent(n.Name) + out.Exprs = CloneSelectExprs(n.Exprs) + return &out } -// rewriteRefOfAlterTable is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerFunc) error { - if node == nil { +// CloneRefOfGroupConcatExpr creates a deep clone of the input. +func CloneRefOfGroupConcatExpr(n *GroupConcatExpr) *GroupConcatExpr { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Exprs = CloneSelectExprs(n.Exprs) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out +} + +// CloneRefOfIndexColumn creates a deep clone of the input. +func CloneRefOfIndexColumn(n *IndexColumn) *IndexColumn { + if n == nil { return nil } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*AlterTable).Table = newNode.(TableName) - }); errF != nil { - return errF - } - for i, el := range node.AlterOptions { - if errF := a.rewriteAlterOption(node, el, func(newNode, parent SQLNode) { - parent.(*AlterTable).AlterOptions[i] = newNode.(AlterOption) - }); errF != nil { - return errF - } - } - if errF := a.rewriteRefOfPartitionSpec(node, node.PartitionSpec, func(newNode, parent SQLNode) { - parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Column = CloneColIdent(n.Column) + out.Length = CloneRefOfLiteral(n.Length) + return &out } -// EqualsRefOfAlterView does deep equals between the two objects. -func EqualsRefOfAlterView(a, b *AlterView) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfIndexDefinition creates a deep clone of the input. +func CloneRefOfIndexDefinition(n *IndexDefinition) *IndexDefinition { + if n == nil { + return nil } - return a.Algorithm == b.Algorithm && - a.Definer == b.Definer && - a.Security == b.Security && - a.CheckOption == b.CheckOption && - EqualsTableName(a.ViewName, b.ViewName) && - EqualsColumns(a.Columns, b.Columns) && - EqualsSelectStatement(a.Select, b.Select) + out := *n + out.Info = CloneRefOfIndexInfo(n.Info) + out.Columns = CloneSliceOfRefOfIndexColumn(n.Columns) + out.Options = CloneSliceOfRefOfIndexOption(n.Options) + return &out } -// CloneRefOfAlterView creates a deep clone of the input. -func CloneRefOfAlterView(n *AlterView) *AlterView { +// CloneRefOfIndexHints creates a deep clone of the input. +func CloneRefOfIndexHints(n *IndexHints) *IndexHints { if n == nil { return nil } out := *n - out.ViewName = CloneTableName(n.ViewName) - out.Columns = CloneColumns(n.Columns) - out.Select = CloneSelectStatement(n.Select) + out.Indexes = CloneSliceOfColIdent(n.Indexes) return &out } -// VisitRefOfAlterView will visit all parts of the AST -func VisitRefOfAlterView(in *AlterView, f Visit) error { - if in == nil { +// CloneRefOfIndexInfo creates a deep clone of the input. +func CloneRefOfIndexInfo(n *IndexInfo) *IndexInfo { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.ViewName, f); err != nil { - return err - } - if err := VisitColumns(in.Columns, f); err != nil { - return err - } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.ConstraintName = CloneColIdent(n.ConstraintName) + return &out } -// rewriteRefOfAlterView is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFunc) error { - if node == nil { +// CloneRefOfIndexOption creates a deep clone of the input. +func CloneRefOfIndexOption(n *IndexOption) *IndexOption { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Value = CloneRefOfLiteral(n.Value) + return &out +} + +// CloneRefOfInsert creates a deep clone of the input. +func CloneRefOfInsert(n *Insert) *Insert { + if n == nil { return nil } - if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { - parent.(*AlterView).ViewName = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { - parent.(*AlterView).Columns = newNode.(Columns) - }); errF != nil { - return errF - } - if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { - parent.(*AlterView).Select = newNode.(SelectStatement) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Comments = CloneComments(n.Comments) + out.Table = CloneTableName(n.Table) + out.Partitions = ClonePartitions(n.Partitions) + out.Columns = CloneColumns(n.Columns) + out.Rows = CloneInsertRows(n.Rows) + out.OnDup = CloneOnDup(n.OnDup) + return &out } -// EqualsRefOfAlterVschema does deep equals between the two objects. -func EqualsRefOfAlterVschema(a, b *AlterVschema) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfIntervalExpr creates a deep clone of the input. +func CloneRefOfIntervalExpr(n *IntervalExpr) *IntervalExpr { + if n == nil { + return nil } - return a.Action == b.Action && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfVindexSpec(a.VindexSpec, b.VindexSpec) && - EqualsSliceOfColIdent(a.VindexCols, b.VindexCols) && - EqualsRefOfAutoIncSpec(a.AutoIncSpec, b.AutoIncSpec) + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// CloneRefOfAlterVschema creates a deep clone of the input. -func CloneRefOfAlterVschema(n *AlterVschema) *AlterVschema { +// CloneRefOfIsExpr creates a deep clone of the input. +func CloneRefOfIsExpr(n *IsExpr) *IsExpr { if n == nil { return nil } out := *n - out.Table = CloneTableName(n.Table) - out.VindexSpec = CloneRefOfVindexSpec(n.VindexSpec) - out.VindexCols = CloneSliceOfColIdent(n.VindexCols) - out.AutoIncSpec = CloneRefOfAutoIncSpec(n.AutoIncSpec) + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfAlterVschema will visit all parts of the AST -func VisitRefOfAlterVschema(in *AlterVschema, f Visit) error { - if in == nil { +// CloneRefOfJoinCondition creates a deep clone of the input. +func CloneRefOfJoinCondition(n *JoinCondition) *JoinCondition { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - if err := VisitRefOfVindexSpec(in.VindexSpec, f); err != nil { - return err - } - for _, el := range in.VindexCols { - if err := VisitColIdent(el, f); err != nil { - return err - } - } - if err := VisitRefOfAutoIncSpec(in.AutoIncSpec, f); err != nil { - return err - } - return nil + out := *n + out.On = CloneExpr(n.On) + out.Using = CloneColumns(n.Using) + return &out } -// rewriteRefOfAlterVschema is part of the Rewrite implementation -func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer replacerFunc) error { - if node == nil { +// CloneRefOfJoinTableExpr creates a deep clone of the input. +func CloneRefOfJoinTableExpr(n *JoinTableExpr) *JoinTableExpr { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.LeftExpr = CloneTableExpr(n.LeftExpr) + out.RightExpr = CloneTableExpr(n.RightExpr) + out.Condition = CloneJoinCondition(n.Condition) + return &out +} + +// CloneRefOfKeyState creates a deep clone of the input. +func CloneRefOfKeyState(n *KeyState) *KeyState { + if n == nil { return nil } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*AlterVschema).Table = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfVindexSpec(node, node.VindexSpec, func(newNode, parent SQLNode) { - parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) - }); errF != nil { - return errF - } - for i, el := range node.VindexCols { - if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { - parent.(*AlterVschema).VindexCols[i] = newNode.(ColIdent) - }); errF != nil { - return errF - } - } - if errF := a.rewriteRefOfAutoIncSpec(node, node.AutoIncSpec, func(newNode, parent SQLNode) { - parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + return &out } -// EqualsRefOfAndExpr does deep equals between the two objects. -func EqualsRefOfAndExpr(a, b *AndExpr) bool { - if a == b { - return true +// CloneRefOfLimit creates a deep clone of the input. +func CloneRefOfLimit(n *Limit) *Limit { + if n == nil { + return nil } - if a == nil || b == nil { - return false + out := *n + out.Offset = CloneExpr(n.Offset) + out.Rowcount = CloneExpr(n.Rowcount) + return &out +} + +// CloneRefOfLiteral creates a deep clone of the input. +func CloneRefOfLiteral(n *Literal) *Literal { + if n == nil { + return nil } - return EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) + out := *n + return &out } -// CloneRefOfAndExpr creates a deep clone of the input. -func CloneRefOfAndExpr(n *AndExpr) *AndExpr { +// CloneRefOfLoad creates a deep clone of the input. +func CloneRefOfLoad(n *Load) *Load { if n == nil { return nil } out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) return &out } -// VisitRefOfAndExpr will visit all parts of the AST -func VisitRefOfAndExpr(in *AndExpr, f Visit) error { - if in == nil { +// CloneRefOfLockOption creates a deep clone of the input. +func CloneRefOfLockOption(n *LockOption) *LockOption { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err - } - return nil + out := *n + return &out } -// rewriteRefOfAndExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfLockTables creates a deep clone of the input. +func CloneRefOfLockTables(n *LockTables) *LockTables { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*AndExpr).Left = newNode.(Expr) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { - parent.(*AndExpr).Right = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfAutoIncSpec does deep equals between the two objects. -func EqualsRefOfAutoIncSpec(a, b *AutoIncSpec) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsColIdent(a.Column, b.Column) && - EqualsTableName(a.Sequence, b.Sequence) + out := *n + out.Tables = CloneTableAndLockTypes(n.Tables) + return &out } -// CloneRefOfAutoIncSpec creates a deep clone of the input. -func CloneRefOfAutoIncSpec(n *AutoIncSpec) *AutoIncSpec { +// CloneRefOfMatchExpr creates a deep clone of the input. +func CloneRefOfMatchExpr(n *MatchExpr) *MatchExpr { if n == nil { return nil } out := *n - out.Column = CloneColIdent(n.Column) - out.Sequence = CloneTableName(n.Sequence) + out.Columns = CloneSelectExprs(n.Columns) + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfAutoIncSpec will visit all parts of the AST -func VisitRefOfAutoIncSpec(in *AutoIncSpec, f Visit) error { - if in == nil { +// CloneRefOfModifyColumn creates a deep clone of the input. +func CloneRefOfModifyColumn(n *ModifyColumn) *ModifyColumn { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Column, f); err != nil { - return err - } - if err := VisitTableName(in.Sequence, f); err != nil { - return err - } - return nil + out := *n + out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out } -// rewriteRefOfAutoIncSpec is part of the Rewrite implementation -func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfNextval creates a deep clone of the input. +func CloneRefOfNextval(n *Nextval) *Nextval { + if n == nil { return nil } - if errF := a.rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Column = newNode.(ColIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteTableName(node, node.Sequence, func(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Sequence = newNode.(TableName) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfBegin does deep equals between the two objects. -func EqualsRefOfBegin(a, b *Begin) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// CloneRefOfBegin creates a deep clone of the input. -func CloneRefOfBegin(n *Begin) *Begin { +// CloneRefOfNotExpr creates a deep clone of the input. +func CloneRefOfNotExpr(n *NotExpr) *NotExpr { if n == nil { return nil } out := *n + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfBegin will visit all parts of the AST -func VisitRefOfBegin(in *Begin, f Visit) error { - if in == nil { +// CloneRefOfNullVal creates a deep clone of the input. +func CloneRefOfNullVal(n *NullVal) *NullVal { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + return &out } -// rewriteRefOfBegin is part of the Rewrite implementation -func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfOptLike creates a deep clone of the input. +func CloneRefOfOptLike(n *OptLike) *OptLike { + if n == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfBinaryExpr does deep equals between the two objects. -func EqualsRefOfBinaryExpr(a, b *BinaryExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Operator == b.Operator && - EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) + out := *n + out.LikeTable = CloneTableName(n.LikeTable) + return &out } -// CloneRefOfBinaryExpr creates a deep clone of the input. -func CloneRefOfBinaryExpr(n *BinaryExpr) *BinaryExpr { +// CloneRefOfOrExpr creates a deep clone of the input. +func CloneRefOfOrExpr(n *OrExpr) *OrExpr { if n == nil { return nil } @@ -2883,1805 +1185,3491 @@ func CloneRefOfBinaryExpr(n *BinaryExpr) *BinaryExpr { return &out } -// VisitRefOfBinaryExpr will visit all parts of the AST -func VisitRefOfBinaryExpr(in *BinaryExpr, f Visit) error { - if in == nil { +// CloneRefOfOrder creates a deep clone of the input. +func CloneRefOfOrder(n *Order) *Order { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// rewriteRefOfBinaryExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerFunc) error { - if node == nil { +// CloneRefOfOrderByOption creates a deep clone of the input. +func CloneRefOfOrderByOption(n *OrderByOption) *OrderByOption { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Cols = CloneColumns(n.Cols) + return &out +} + +// CloneRefOfOtherAdmin creates a deep clone of the input. +func CloneRefOfOtherAdmin(n *OtherAdmin) *OtherAdmin { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*BinaryExpr).Left = newNode.(Expr) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { - parent.(*BinaryExpr).Right = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + return &out } -// EqualsRefOfCallProc does deep equals between the two objects. -func EqualsRefOfCallProc(a, b *CallProc) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfOtherRead creates a deep clone of the input. +func CloneRefOfOtherRead(n *OtherRead) *OtherRead { + if n == nil { + return nil } - return EqualsTableName(a.Name, b.Name) && - EqualsExprs(a.Params, b.Params) + out := *n + return &out } -// CloneRefOfCallProc creates a deep clone of the input. -func CloneRefOfCallProc(n *CallProc) *CallProc { +// CloneRefOfParenSelect creates a deep clone of the input. +func CloneRefOfParenSelect(n *ParenSelect) *ParenSelect { if n == nil { return nil } out := *n - out.Name = CloneTableName(n.Name) - out.Params = CloneExprs(n.Params) + out.Select = CloneSelectStatement(n.Select) return &out } -// VisitRefOfCallProc will visit all parts of the AST -func VisitRefOfCallProc(in *CallProc, f Visit) error { - if in == nil { +// CloneRefOfParenTableExpr creates a deep clone of the input. +func CloneRefOfParenTableExpr(n *ParenTableExpr) *ParenTableExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Name, f); err != nil { - return err - } - if err := VisitExprs(in.Params, f); err != nil { - return err - } - return nil + out := *n + out.Exprs = CloneTableExprs(n.Exprs) + return &out } -// rewriteRefOfCallProc is part of the Rewrite implementation -func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfPartitionDefinition creates a deep clone of the input. +func CloneRefOfPartitionDefinition(n *PartitionDefinition) *PartitionDefinition { + if n == nil { return nil } - if errF := a.rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { - parent.(*CallProc).Name = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteExprs(node, node.Params, func(newNode, parent SQLNode) { - parent.(*CallProc).Params = newNode.(Exprs) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.Limit = CloneExpr(n.Limit) + return &out } -// EqualsRefOfCaseExpr does deep equals between the two objects. -func EqualsRefOfCaseExpr(a, b *CaseExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfPartitionSpec creates a deep clone of the input. +func CloneRefOfPartitionSpec(n *PartitionSpec) *PartitionSpec { + if n == nil { + return nil } - return EqualsExpr(a.Expr, b.Expr) && - EqualsSliceOfRefOfWhen(a.Whens, b.Whens) && - EqualsExpr(a.Else, b.Else) + out := *n + out.Names = ClonePartitions(n.Names) + out.Number = CloneRefOfLiteral(n.Number) + out.TableName = CloneTableName(n.TableName) + out.Definitions = CloneSliceOfRefOfPartitionDefinition(n.Definitions) + return &out } -// CloneRefOfCaseExpr creates a deep clone of the input. -func CloneRefOfCaseExpr(n *CaseExpr) *CaseExpr { +// CloneRefOfRangeCond creates a deep clone of the input. +func CloneRefOfRangeCond(n *RangeCond) *RangeCond { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) - out.Whens = CloneSliceOfRefOfWhen(n.Whens) - out.Else = CloneExpr(n.Else) + out.Left = CloneExpr(n.Left) + out.From = CloneExpr(n.From) + out.To = CloneExpr(n.To) return &out } -// VisitRefOfCaseExpr will visit all parts of the AST -func VisitRefOfCaseExpr(in *CaseExpr, f Visit) error { - if in == nil { +// CloneRefOfRelease creates a deep clone of the input. +func CloneRefOfRelease(n *Release) *Release { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - for _, el := range in.Whens { - if err := VisitRefOfWhen(el, f); err != nil { - return err - } - } - if err := VisitExpr(in.Else, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + return &out } -// rewriteRefOfCaseExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfRenameIndex creates a deep clone of the input. +func CloneRefOfRenameIndex(n *RenameIndex) *RenameIndex { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF - } - for i, el := range node.Whens { - if errF := a.rewriteRefOfWhen(node, el, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Whens[i] = newNode.(*When) - }); errF != nil { - return errF - } - } - if errF := a.rewriteExpr(node, node.Else, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Else = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + return &out } -// EqualsRefOfChangeColumn does deep equals between the two objects. -func EqualsRefOfChangeColumn(a, b *ChangeColumn) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfRenameTable creates a deep clone of the input. +func CloneRefOfRenameTable(n *RenameTable) *RenameTable { + if n == nil { + return nil } - return EqualsRefOfColName(a.OldColumn, b.OldColumn) && - EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && - EqualsRefOfColName(a.First, b.First) && - EqualsRefOfColName(a.After, b.After) + out := *n + out.TablePairs = CloneSliceOfRefOfRenameTablePair(n.TablePairs) + return &out } -// CloneRefOfChangeColumn creates a deep clone of the input. -func CloneRefOfChangeColumn(n *ChangeColumn) *ChangeColumn { +// CloneRefOfRenameTableName creates a deep clone of the input. +func CloneRefOfRenameTableName(n *RenameTableName) *RenameTableName { if n == nil { return nil } out := *n - out.OldColumn = CloneRefOfColName(n.OldColumn) - out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) - out.First = CloneRefOfColName(n.First) - out.After = CloneRefOfColName(n.After) + out.Table = CloneTableName(n.Table) return &out } -// VisitRefOfChangeColumn will visit all parts of the AST -func VisitRefOfChangeColumn(in *ChangeColumn, f Visit) error { - if in == nil { +// CloneRefOfRenameTablePair creates a deep clone of the input. +func CloneRefOfRenameTablePair(n *RenameTablePair) *RenameTablePair { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.OldColumn, f); err != nil { - return err - } - if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { - return err - } - if err := VisitRefOfColName(in.First, f); err != nil { - return err - } - if err := VisitRefOfColName(in.After, f); err != nil { - return err - } - return nil + out := *n + out.FromTable = CloneTableName(n.FromTable) + out.ToTable = CloneTableName(n.ToTable) + return &out } -// rewriteRefOfChangeColumn is part of the Rewrite implementation -func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer replacerFunc) error { - if node == nil { +// CloneRefOfRevertMigration creates a deep clone of the input. +func CloneRefOfRevertMigration(n *RevertMigration) *RevertMigration { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + return &out +} + +// CloneRefOfRollback creates a deep clone of the input. +func CloneRefOfRollback(n *Rollback) *Rollback { + if n == nil { return nil } - if errF := a.rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).OldColumn = newNode.(*ColName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).First = newNode.(*ColName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).After = newNode.(*ColName) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + return &out } -// EqualsRefOfCheckConstraintDefinition does deep equals between the two objects. -func EqualsRefOfCheckConstraintDefinition(a, b *CheckConstraintDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfSRollback creates a deep clone of the input. +func CloneRefOfSRollback(n *SRollback) *SRollback { + if n == nil { + return nil } - return a.Enforced == b.Enforced && - EqualsExpr(a.Expr, b.Expr) + out := *n + out.Name = CloneColIdent(n.Name) + return &out } -// CloneRefOfCheckConstraintDefinition creates a deep clone of the input. -func CloneRefOfCheckConstraintDefinition(n *CheckConstraintDefinition) *CheckConstraintDefinition { +// CloneRefOfSavepoint creates a deep clone of the input. +func CloneRefOfSavepoint(n *Savepoint) *Savepoint { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) + out.Name = CloneColIdent(n.Name) return &out } -// VisitRefOfCheckConstraintDefinition will visit all parts of the AST -func VisitRefOfCheckConstraintDefinition(in *CheckConstraintDefinition, f Visit) error { - if in == nil { +// CloneRefOfSelect creates a deep clone of the input. +func CloneRefOfSelect(n *Select) *Select { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - return nil + out := *n + out.Cache = CloneRefOfBool(n.Cache) + out.Comments = CloneComments(n.Comments) + out.SelectExprs = CloneSelectExprs(n.SelectExprs) + out.From = CloneTableExprs(n.From) + out.Where = CloneRefOfWhere(n.Where) + out.GroupBy = CloneGroupBy(n.GroupBy) + out.Having = CloneRefOfWhere(n.Having) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + out.Into = CloneRefOfSelectInto(n.Into) + return &out } -// rewriteRefOfCheckConstraintDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraintDefinition, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfSelectInto creates a deep clone of the input. +func CloneRefOfSelectInto(n *SelectInto) *SelectInto { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsColIdent does deep equals between the two objects. -func EqualsColIdent(a, b ColIdent) bool { - return a.val == b.val && - a.lowered == b.lowered && - a.at == b.at -} - -// CloneColIdent creates a deep clone of the input. -func CloneColIdent(n ColIdent) ColIdent { - return *CloneRefOfColIdent(&n) + out := *n + return &out } -// VisitColIdent will visit all parts of the AST -func VisitColIdent(in ColIdent, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err +// CloneRefOfSet creates a deep clone of the input. +func CloneRefOfSet(n *Set) *Set { + if n == nil { + return nil } - return nil + out := *n + out.Comments = CloneComments(n.Comments) + out.Exprs = CloneSetExprs(n.Exprs) + return &out } -// rewriteColIdent is part of the Rewrite implementation -func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc) error { - var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfSetExpr creates a deep clone of the input. +func CloneRefOfSetExpr(n *SetExpr) *SetExpr { + if n == nil { return nil } - if err != nil { - return err - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsRefOfColName does deep equals between the two objects. -func EqualsRefOfColName(a, b *ColName) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfSetTransaction creates a deep clone of the input. +func CloneRefOfSetTransaction(n *SetTransaction) *SetTransaction { + if n == nil { + return nil } - return EqualsColIdent(a.Name, b.Name) && - EqualsTableName(a.Qualifier, b.Qualifier) -} - -// CloneRefOfColName creates a deep clone of the input. -func CloneRefOfColName(n *ColName) *ColName { - return n + out := *n + out.SQLNode = CloneSQLNode(n.SQLNode) + out.Comments = CloneComments(n.Comments) + out.Characteristics = CloneSliceOfCharacteristic(n.Characteristics) + return &out } -// VisitRefOfColName will visit all parts of the AST -func VisitRefOfColName(in *ColName, f Visit) error { - if in == nil { +// CloneRefOfShow creates a deep clone of the input. +func CloneRefOfShow(n *Show) *Show { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitTableName(in.Qualifier, f); err != nil { - return err - } - return nil + out := *n + out.Internal = CloneShowInternal(n.Internal) + return &out } -// rewriteRefOfColName is part of the Rewrite implementation -func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfShowBasic creates a deep clone of the input. +func CloneRefOfShowBasic(n *ShowBasic) *ShowBasic { + if n == nil { return nil } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*ColName).Name = newNode.(ColIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteTableName(node, node.Qualifier, func(newNode, parent SQLNode) { - parent.(*ColName).Qualifier = newNode.(TableName) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Tbl = CloneTableName(n.Tbl) + out.Filter = CloneRefOfShowFilter(n.Filter) + return &out } -// EqualsRefOfCollateExpr does deep equals between the two objects. -func EqualsRefOfCollateExpr(a, b *CollateExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfShowCreate creates a deep clone of the input. +func CloneRefOfShowCreate(n *ShowCreate) *ShowCreate { + if n == nil { + return nil } - return a.Charset == b.Charset && - EqualsExpr(a.Expr, b.Expr) + out := *n + out.Op = CloneTableName(n.Op) + return &out } -// CloneRefOfCollateExpr creates a deep clone of the input. -func CloneRefOfCollateExpr(n *CollateExpr) *CollateExpr { +// CloneRefOfShowFilter creates a deep clone of the input. +func CloneRefOfShowFilter(n *ShowFilter) *ShowFilter { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) + out.Filter = CloneExpr(n.Filter) return &out } -// VisitRefOfCollateExpr will visit all parts of the AST -func VisitRefOfCollateExpr(in *CollateExpr, f Visit) error { - if in == nil { +// CloneRefOfShowLegacy creates a deep clone of the input. +func CloneRefOfShowLegacy(n *ShowLegacy) *ShowLegacy { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - return nil + out := *n + out.OnTable = CloneTableName(n.OnTable) + out.Table = CloneTableName(n.Table) + out.ShowTablesOpt = CloneRefOfShowTablesOpt(n.ShowTablesOpt) + out.ShowCollationFilterOpt = CloneExpr(n.ShowCollationFilterOpt) + return &out } -// rewriteRefOfCollateExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfShowTablesOpt creates a deep clone of the input. +func CloneRefOfShowTablesOpt(n *ShowTablesOpt) *ShowTablesOpt { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*CollateExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Filter = CloneRefOfShowFilter(n.Filter) + return &out } -// EqualsRefOfColumnDefinition does deep equals between the two objects. -func EqualsRefOfColumnDefinition(a, b *ColumnDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfStarExpr creates a deep clone of the input. +func CloneRefOfStarExpr(n *StarExpr) *StarExpr { + if n == nil { + return nil } - return EqualsColIdent(a.Name, b.Name) && - EqualsColumnType(a.Type, b.Type) + out := *n + out.TableName = CloneTableName(n.TableName) + return &out } -// CloneRefOfColumnDefinition creates a deep clone of the input. -func CloneRefOfColumnDefinition(n *ColumnDefinition) *ColumnDefinition { +// CloneRefOfStream creates a deep clone of the input. +func CloneRefOfStream(n *Stream) *Stream { if n == nil { return nil } out := *n - out.Name = CloneColIdent(n.Name) - out.Type = CloneColumnType(n.Type) + out.Comments = CloneComments(n.Comments) + out.SelectExpr = CloneSelectExpr(n.SelectExpr) + out.Table = CloneTableName(n.Table) return &out } -// VisitRefOfColumnDefinition will visit all parts of the AST -func VisitRefOfColumnDefinition(in *ColumnDefinition, f Visit) error { - if in == nil { +// CloneRefOfSubquery creates a deep clone of the input. +func CloneRefOfSubquery(n *Subquery) *Subquery { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - return nil + out := *n + out.Select = CloneSelectStatement(n.Select) + return &out } -// rewriteRefOfColumnDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfSubstrExpr creates a deep clone of the input. +func CloneRefOfSubstrExpr(n *SubstrExpr) *SubstrExpr { + if n == nil { return nil } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*ColumnDefinition).Name = newNode.(ColIdent) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Name = CloneRefOfColName(n.Name) + out.StrVal = CloneRefOfLiteral(n.StrVal) + out.From = CloneExpr(n.From) + out.To = CloneExpr(n.To) + return &out } -// EqualsRefOfColumnType does deep equals between the two objects. -func EqualsRefOfColumnType(a, b *ColumnType) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfTableAndLockType creates a deep clone of the input. +func CloneRefOfTableAndLockType(n *TableAndLockType) *TableAndLockType { + if n == nil { + return nil } - return a.Type == b.Type && - a.Unsigned == b.Unsigned && - a.Zerofill == b.Zerofill && - a.Charset == b.Charset && - a.Collate == b.Collate && - EqualsRefOfColumnTypeOptions(a.Options, b.Options) && - EqualsRefOfLiteral(a.Length, b.Length) && - EqualsRefOfLiteral(a.Scale, b.Scale) && - EqualsSliceOfString(a.EnumValues, b.EnumValues) + out := *n + out.Table = CloneTableExpr(n.Table) + return &out } -// CloneRefOfColumnType creates a deep clone of the input. -func CloneRefOfColumnType(n *ColumnType) *ColumnType { +// CloneRefOfTableIdent creates a deep clone of the input. +func CloneRefOfTableIdent(n *TableIdent) *TableIdent { if n == nil { return nil } out := *n - out.Options = CloneRefOfColumnTypeOptions(n.Options) - out.Length = CloneRefOfLiteral(n.Length) - out.Scale = CloneRefOfLiteral(n.Scale) - out.EnumValues = CloneSliceOfString(n.EnumValues) return &out } -// VisitRefOfColumnType will visit all parts of the AST -func VisitRefOfColumnType(in *ColumnType, f Visit) error { - if in == nil { +// CloneRefOfTableName creates a deep clone of the input. +func CloneRefOfTableName(n *TableName) *TableName { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfLiteral(in.Length, f); err != nil { - return err - } - if err := VisitRefOfLiteral(in.Scale, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneTableIdent(n.Name) + out.Qualifier = CloneTableIdent(n.Qualifier) + return &out } -// rewriteRefOfColumnType is part of the Rewrite implementation -func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfTableOption creates a deep clone of the input. +func CloneRefOfTableOption(n *TableOption) *TableOption { + if n == nil { return nil } - if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { - parent.(*ColumnType).Length = newNode.(*Literal) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { - parent.(*ColumnType).Scale = newNode.(*Literal) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Value = CloneRefOfLiteral(n.Value) + out.Tables = CloneTableNames(n.Tables) + return &out } -// EqualsColumns does deep equals between the two objects. -func EqualsColumns(a, b Columns) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsColIdent(a[i], b[i]) { - return false - } +// CloneRefOfTableSpec creates a deep clone of the input. +func CloneRefOfTableSpec(n *TableSpec) *TableSpec { + if n == nil { + return nil } - return true + out := *n + out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) + out.Indexes = CloneSliceOfRefOfIndexDefinition(n.Indexes) + out.Constraints = CloneSliceOfRefOfConstraintDefinition(n.Constraints) + out.Options = CloneTableOptions(n.Options) + return &out } -// CloneColumns creates a deep clone of the input. -func CloneColumns(n Columns) Columns { - res := make(Columns, 0, len(n)) - for _, x := range n { - res = append(res, CloneColIdent(x)) +// CloneRefOfTablespaceOperation creates a deep clone of the input. +func CloneRefOfTablespaceOperation(n *TablespaceOperation) *TablespaceOperation { + if n == nil { + return nil } - return res + out := *n + return &out } -// VisitColumns will visit all parts of the AST -func VisitColumns(in Columns, f Visit) error { - if in == nil { +// CloneRefOfTimestampFuncExpr creates a deep clone of the input. +func CloneRefOfTimestampFuncExpr(n *TimestampFuncExpr) *TimestampFuncExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in { - if err := VisitColIdent(el, f); err != nil { - return err - } - } - return nil + out := *n + out.Expr1 = CloneExpr(n.Expr1) + out.Expr2 = CloneExpr(n.Expr2) + return &out } -// rewriteColumns is part of the Rewrite implementation -func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc) error { - if node == nil { +// CloneRefOfTruncateTable creates a deep clone of the input. +func CloneRefOfTruncateTable(n *TruncateTable) *TruncateTable { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + out := *n + out.Table = CloneTableName(n.Table) + return &out +} + +// CloneRefOfUnaryExpr creates a deep clone of the input. +func CloneRefOfUnaryExpr(n *UnaryExpr) *UnaryExpr { + if n == nil { return nil } - for i, el := range node { - if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { - parent.(Columns)[i] = newNode.(ColIdent) - }); errF != nil { - return errF - } - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsComments does deep equals between the two objects. -func EqualsComments(a, b Comments) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } +// CloneRefOfUnion creates a deep clone of the input. +func CloneRefOfUnion(n *Union) *Union { + if n == nil { + return nil } - return true + out := *n + out.FirstStatement = CloneSelectStatement(n.FirstStatement) + out.UnionSelects = CloneSliceOfRefOfUnionSelect(n.UnionSelects) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out } -// CloneComments creates a deep clone of the input. -func CloneComments(n Comments) Comments { - res := make(Comments, 0, len(n)) - copy(res, n) - return res +// CloneRefOfUnionSelect creates a deep clone of the input. +func CloneRefOfUnionSelect(n *UnionSelect) *UnionSelect { + if n == nil { + return nil + } + out := *n + out.Statement = CloneSelectStatement(n.Statement) + return &out } -// VisitComments will visit all parts of the AST -func VisitComments(in Comments, f Visit) error { - _, err := f(in) - return err +// CloneRefOfUnlockTables creates a deep clone of the input. +func CloneRefOfUnlockTables(n *UnlockTables) *UnlockTables { + if n == nil { + return nil + } + out := *n + return &out } -// rewriteComments is part of the Rewrite implementation -func (a *application) rewriteComments(parent SQLNode, node Comments, replacer replacerFunc) error { - if node == nil { +// CloneRefOfUpdate creates a deep clone of the input. +func CloneRefOfUpdate(n *Update) *Update { + if n == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfCommit does deep equals between the two objects. -func EqualsRefOfCommit(a, b *Commit) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true + out := *n + out.Comments = CloneComments(n.Comments) + out.TableExprs = CloneTableExprs(n.TableExprs) + out.Exprs = CloneUpdateExprs(n.Exprs) + out.Where = CloneRefOfWhere(n.Where) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out } -// CloneRefOfCommit creates a deep clone of the input. -func CloneRefOfCommit(n *Commit) *Commit { +// CloneRefOfUpdateExpr creates a deep clone of the input. +func CloneRefOfUpdateExpr(n *UpdateExpr) *UpdateExpr { if n == nil { return nil } out := *n + out.Name = CloneRefOfColName(n.Name) + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfCommit will visit all parts of the AST -func VisitRefOfCommit(in *Commit, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// rewriteRefOfCommit is part of the Rewrite implementation -func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfUse creates a deep clone of the input. +func CloneRefOfUse(n *Use) *Use { + if n == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfComparisonExpr does deep equals between the two objects. -func EqualsRefOfComparisonExpr(a, b *ComparisonExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Operator == b.Operator && - EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) && - EqualsExpr(a.Escape, b.Escape) + out := *n + out.DBName = CloneTableIdent(n.DBName) + return &out } -// CloneRefOfComparisonExpr creates a deep clone of the input. -func CloneRefOfComparisonExpr(n *ComparisonExpr) *ComparisonExpr { +// CloneRefOfVStream creates a deep clone of the input. +func CloneRefOfVStream(n *VStream) *VStream { if n == nil { return nil } out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) - out.Escape = CloneExpr(n.Escape) + out.Comments = CloneComments(n.Comments) + out.SelectExpr = CloneSelectExpr(n.SelectExpr) + out.Table = CloneTableName(n.Table) + out.Where = CloneRefOfWhere(n.Where) + out.Limit = CloneRefOfLimit(n.Limit) return &out } -// VisitRefOfComparisonExpr will visit all parts of the AST -func VisitRefOfComparisonExpr(in *ComparisonExpr, f Visit) error { - if in == nil { +// CloneRefOfValidation creates a deep clone of the input. +func CloneRefOfValidation(n *Validation) *Validation { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err - } - if err := VisitExpr(in.Escape, f); err != nil { - return err - } - return nil + out := *n + return &out } -// rewriteRefOfComparisonExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfValuesFuncExpr creates a deep clone of the input. +func CloneRefOfValuesFuncExpr(n *ValuesFuncExpr) *ValuesFuncExpr { + if n == nil { return nil } - if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Left = newNode.(Expr) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Right = newNode.(Expr) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Escape, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Escape = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfConstraintDefinition does deep equals between the two objects. -func EqualsRefOfConstraintDefinition(a, b *ConstraintDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Name == b.Name && - EqualsConstraintInfo(a.Details, b.Details) + out := *n + out.Name = CloneRefOfColName(n.Name) + return &out } -// CloneRefOfConstraintDefinition creates a deep clone of the input. -func CloneRefOfConstraintDefinition(n *ConstraintDefinition) *ConstraintDefinition { +// CloneRefOfVindexParam creates a deep clone of the input. +func CloneRefOfVindexParam(n *VindexParam) *VindexParam { if n == nil { return nil } out := *n - out.Details = CloneConstraintInfo(n.Details) + out.Key = CloneColIdent(n.Key) return &out } -// VisitRefOfConstraintDefinition will visit all parts of the AST -func VisitRefOfConstraintDefinition(in *ConstraintDefinition, f Visit) error { - if in == nil { +// CloneRefOfVindexSpec creates a deep clone of the input. +func CloneRefOfVindexSpec(n *VindexSpec) *VindexSpec { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitConstraintInfo(in.Details, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.Type = CloneColIdent(n.Type) + out.Params = CloneSliceOfVindexParam(n.Params) + return &out } -// rewriteRefOfConstraintDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneRefOfWhen creates a deep clone of the input. +func CloneRefOfWhen(n *When) *When { + if n == nil { return nil } - if errF := a.rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { - parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfConvertExpr does deep equals between the two objects. -func EqualsRefOfConvertExpr(a, b *ConvertExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsExpr(a.Expr, b.Expr) && - EqualsRefOfConvertType(a.Type, b.Type) + out := *n + out.Cond = CloneExpr(n.Cond) + out.Val = CloneExpr(n.Val) + return &out } -// CloneRefOfConvertExpr creates a deep clone of the input. -func CloneRefOfConvertExpr(n *ConvertExpr) *ConvertExpr { +// CloneRefOfWhere creates a deep clone of the input. +func CloneRefOfWhere(n *Where) *Where { if n == nil { return nil } out := *n out.Expr = CloneExpr(n.Expr) - out.Type = CloneRefOfConvertType(n.Type) return &out } -// VisitRefOfConvertExpr will visit all parts of the AST -func VisitRefOfConvertExpr(in *ConvertExpr, f Visit) error { - if in == nil { +// CloneRefOfXorExpr creates a deep clone of the input. +func CloneRefOfXorExpr(n *XorExpr) *XorExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - if err := VisitRefOfConvertType(in.Type, f); err != nil { - return err - } - return nil + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out } -// rewriteRefOfConvertExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +// CloneSQLNode creates a deep clone of the input. +func CloneSQLNode(in SQLNode) SQLNode { + if in == nil { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*ConvertExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF + switch in := in.(type) { + case AccessMode: + return in + case *AddColumns: + return CloneRefOfAddColumns(in) + case *AddConstraintDefinition: + return CloneRefOfAddConstraintDefinition(in) + case *AddIndexDefinition: + return CloneRefOfAddIndexDefinition(in) + case AlgorithmValue: + return in + case *AliasedExpr: + return CloneRefOfAliasedExpr(in) + case *AliasedTableExpr: + return CloneRefOfAliasedTableExpr(in) + case *AlterCharset: + return CloneRefOfAlterCharset(in) + case *AlterColumn: + return CloneRefOfAlterColumn(in) + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *AlterMigration: + return CloneRefOfAlterMigration(in) + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *AlterVschema: + return CloneRefOfAlterVschema(in) + case *AndExpr: + return CloneRefOfAndExpr(in) + case Argument: + return in + case *AutoIncSpec: + return CloneRefOfAutoIncSpec(in) + case *Begin: + return CloneRefOfBegin(in) + case *BinaryExpr: + return CloneRefOfBinaryExpr(in) + case BoolVal: + return in + case *CallProc: + return CloneRefOfCallProc(in) + case *CaseExpr: + return CloneRefOfCaseExpr(in) + case *ChangeColumn: + return CloneRefOfChangeColumn(in) + case *CheckConstraintDefinition: + return CloneRefOfCheckConstraintDefinition(in) + case ColIdent: + return CloneColIdent(in) + case *ColName: + return CloneRefOfColName(in) + case *CollateExpr: + return CloneRefOfCollateExpr(in) + case *ColumnDefinition: + return CloneRefOfColumnDefinition(in) + case *ColumnType: + return CloneRefOfColumnType(in) + case Columns: + return CloneColumns(in) + case Comments: + return CloneComments(in) + case *Commit: + return CloneRefOfCommit(in) + case *ComparisonExpr: + return CloneRefOfComparisonExpr(in) + case *ConstraintDefinition: + return CloneRefOfConstraintDefinition(in) + case *ConvertExpr: + return CloneRefOfConvertExpr(in) + case *ConvertType: + return CloneRefOfConvertType(in) + case *ConvertUsingExpr: + return CloneRefOfConvertUsingExpr(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *CurTimeFuncExpr: + return CloneRefOfCurTimeFuncExpr(in) + case *Default: + return CloneRefOfDefault(in) + case *Delete: + return CloneRefOfDelete(in) + case *DerivedTable: + return CloneRefOfDerivedTable(in) + case *DropColumn: + return CloneRefOfDropColumn(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + case *DropKey: + return CloneRefOfDropKey(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *ExistsExpr: + return CloneRefOfExistsExpr(in) + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + case Exprs: + return CloneExprs(in) + case *Flush: + return CloneRefOfFlush(in) + case *Force: + return CloneRefOfForce(in) + case *ForeignKeyDefinition: + return CloneRefOfForeignKeyDefinition(in) + case *FuncExpr: + return CloneRefOfFuncExpr(in) + case GroupBy: + return CloneGroupBy(in) + case *GroupConcatExpr: + return CloneRefOfGroupConcatExpr(in) + case *IndexDefinition: + return CloneRefOfIndexDefinition(in) + case *IndexHints: + return CloneRefOfIndexHints(in) + case *IndexInfo: + return CloneRefOfIndexInfo(in) + case *Insert: + return CloneRefOfInsert(in) + case *IntervalExpr: + return CloneRefOfIntervalExpr(in) + case *IsExpr: + return CloneRefOfIsExpr(in) + case IsolationLevel: + return in + case JoinCondition: + return CloneJoinCondition(in) + case *JoinTableExpr: + return CloneRefOfJoinTableExpr(in) + case *KeyState: + return CloneRefOfKeyState(in) + case *Limit: + return CloneRefOfLimit(in) + case ListArg: + return CloneListArg(in) + case *Literal: + return CloneRefOfLiteral(in) + case *Load: + return CloneRefOfLoad(in) + case *LockOption: + return CloneRefOfLockOption(in) + case *LockTables: + return CloneRefOfLockTables(in) + case *MatchExpr: + return CloneRefOfMatchExpr(in) + case *ModifyColumn: + return CloneRefOfModifyColumn(in) + case *Nextval: + return CloneRefOfNextval(in) + case *NotExpr: + return CloneRefOfNotExpr(in) + case *NullVal: + return CloneRefOfNullVal(in) + case OnDup: + return CloneOnDup(in) + case *OptLike: + return CloneRefOfOptLike(in) + case *OrExpr: + return CloneRefOfOrExpr(in) + case *Order: + return CloneRefOfOrder(in) + case OrderBy: + return CloneOrderBy(in) + case *OrderByOption: + return CloneRefOfOrderByOption(in) + case *OtherAdmin: + return CloneRefOfOtherAdmin(in) + case *OtherRead: + return CloneRefOfOtherRead(in) + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *ParenTableExpr: + return CloneRefOfParenTableExpr(in) + case *PartitionDefinition: + return CloneRefOfPartitionDefinition(in) + case *PartitionSpec: + return CloneRefOfPartitionSpec(in) + case Partitions: + return ClonePartitions(in) + case *RangeCond: + return CloneRefOfRangeCond(in) + case ReferenceAction: + return in + case *Release: + return CloneRefOfRelease(in) + case *RenameIndex: + return CloneRefOfRenameIndex(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *RenameTableName: + return CloneRefOfRenameTableName(in) + case *RevertMigration: + return CloneRefOfRevertMigration(in) + case *Rollback: + return CloneRefOfRollback(in) + case *SRollback: + return CloneRefOfSRollback(in) + case *Savepoint: + return CloneRefOfSavepoint(in) + case *Select: + return CloneRefOfSelect(in) + case SelectExprs: + return CloneSelectExprs(in) + case *SelectInto: + return CloneRefOfSelectInto(in) + case *Set: + return CloneRefOfSet(in) + case *SetExpr: + return CloneRefOfSetExpr(in) + case SetExprs: + return CloneSetExprs(in) + case *SetTransaction: + return CloneRefOfSetTransaction(in) + case *Show: + return CloneRefOfShow(in) + case *ShowBasic: + return CloneRefOfShowBasic(in) + case *ShowCreate: + return CloneRefOfShowCreate(in) + case *ShowFilter: + return CloneRefOfShowFilter(in) + case *ShowLegacy: + return CloneRefOfShowLegacy(in) + case *StarExpr: + return CloneRefOfStarExpr(in) + case *Stream: + return CloneRefOfStream(in) + case *Subquery: + return CloneRefOfSubquery(in) + case *SubstrExpr: + return CloneRefOfSubstrExpr(in) + case TableExprs: + return CloneTableExprs(in) + case TableIdent: + return CloneTableIdent(in) + case TableName: + return CloneTableName(in) + case TableNames: + return CloneTableNames(in) + case TableOptions: + return CloneTableOptions(in) + case *TableSpec: + return CloneRefOfTableSpec(in) + case *TablespaceOperation: + return CloneRefOfTablespaceOperation(in) + case *TimestampFuncExpr: + return CloneRefOfTimestampFuncExpr(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + case *UnaryExpr: + return CloneRefOfUnaryExpr(in) + case *Union: + return CloneRefOfUnion(in) + case *UnionSelect: + return CloneRefOfUnionSelect(in) + case *UnlockTables: + return CloneRefOfUnlockTables(in) + case *Update: + return CloneRefOfUpdate(in) + case *UpdateExpr: + return CloneRefOfUpdateExpr(in) + case UpdateExprs: + return CloneUpdateExprs(in) + case *Use: + return CloneRefOfUse(in) + case *VStream: + return CloneRefOfVStream(in) + case ValTuple: + return CloneValTuple(in) + case *Validation: + return CloneRefOfValidation(in) + case Values: + return CloneValues(in) + case *ValuesFuncExpr: + return CloneRefOfValuesFuncExpr(in) + case VindexParam: + return CloneVindexParam(in) + case *VindexSpec: + return CloneRefOfVindexSpec(in) + case *When: + return CloneRefOfWhen(in) + case *Where: + return CloneRefOfWhere(in) + case *XorExpr: + return CloneRefOfXorExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneSelectExpr creates a deep clone of the input. +func CloneSelectExpr(in SelectExpr) SelectExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedExpr: + return CloneRefOfAliasedExpr(in) + case *Nextval: + return CloneRefOfNextval(in) + case *StarExpr: + return CloneRefOfStarExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneSelectExprs creates a deep clone of the input. +func CloneSelectExprs(n SelectExprs) SelectExprs { + res := make(SelectExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneSelectExpr(x)) + } + return res +} + +// CloneSelectStatement creates a deep clone of the input. +func CloneSelectStatement(in SelectStatement) SelectStatement { + if in == nil { + return nil + } + switch in := in.(type) { + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Select: + return CloneRefOfSelect(in) + case *Union: + return CloneRefOfUnion(in) + default: + // this should never happen + return nil + } +} + +// CloneSetExprs creates a deep clone of the input. +func CloneSetExprs(n SetExprs) SetExprs { + res := make(SetExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfSetExpr(x)) + } + return res +} + +// CloneShowInternal creates a deep clone of the input. +func CloneShowInternal(in ShowInternal) ShowInternal { + if in == nil { + return nil + } + switch in := in.(type) { + case *ShowBasic: + return CloneRefOfShowBasic(in) + case *ShowCreate: + return CloneRefOfShowCreate(in) + case *ShowLegacy: + return CloneRefOfShowLegacy(in) + default: + // this should never happen + return nil + } +} + +// CloneSimpleTableExpr creates a deep clone of the input. +func CloneSimpleTableExpr(in SimpleTableExpr) SimpleTableExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *DerivedTable: + return CloneRefOfDerivedTable(in) + case TableName: + return CloneTableName(in) + default: + // this should never happen + return nil + } +} + +// CloneSliceOfAlterOption creates a deep clone of the input. +func CloneSliceOfAlterOption(n []AlterOption) []AlterOption { + res := make([]AlterOption, 0, len(n)) + for _, x := range n { + res = append(res, CloneAlterOption(x)) + } + return res +} + +// CloneSliceOfCharacteristic creates a deep clone of the input. +func CloneSliceOfCharacteristic(n []Characteristic) []Characteristic { + res := make([]Characteristic, 0, len(n)) + for _, x := range n { + res = append(res, CloneCharacteristic(x)) + } + return res +} + +// CloneSliceOfColIdent creates a deep clone of the input. +func CloneSliceOfColIdent(n []ColIdent) []ColIdent { + res := make([]ColIdent, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneSliceOfCollateAndCharset creates a deep clone of the input. +func CloneSliceOfCollateAndCharset(n []CollateAndCharset) []CollateAndCharset { + res := make([]CollateAndCharset, 0, len(n)) + for _, x := range n { + res = append(res, CloneCollateAndCharset(x)) + } + return res +} + +// CloneSliceOfRefOfColumnDefinition creates a deep clone of the input. +func CloneSliceOfRefOfColumnDefinition(n []*ColumnDefinition) []*ColumnDefinition { + res := make([]*ColumnDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfColumnDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfConstraintDefinition creates a deep clone of the input. +func CloneSliceOfRefOfConstraintDefinition(n []*ConstraintDefinition) []*ConstraintDefinition { + res := make([]*ConstraintDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfConstraintDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfIndexColumn creates a deep clone of the input. +func CloneSliceOfRefOfIndexColumn(n []*IndexColumn) []*IndexColumn { + res := make([]*IndexColumn, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexColumn(x)) + } + return res +} + +// CloneSliceOfRefOfIndexDefinition creates a deep clone of the input. +func CloneSliceOfRefOfIndexDefinition(n []*IndexDefinition) []*IndexDefinition { + res := make([]*IndexDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfIndexOption creates a deep clone of the input. +func CloneSliceOfRefOfIndexOption(n []*IndexOption) []*IndexOption { + res := make([]*IndexOption, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexOption(x)) + } + return res +} + +// CloneSliceOfRefOfPartitionDefinition creates a deep clone of the input. +func CloneSliceOfRefOfPartitionDefinition(n []*PartitionDefinition) []*PartitionDefinition { + res := make([]*PartitionDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfPartitionDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfRenameTablePair creates a deep clone of the input. +func CloneSliceOfRefOfRenameTablePair(n []*RenameTablePair) []*RenameTablePair { + res := make([]*RenameTablePair, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfRenameTablePair(x)) + } + return res +} + +// CloneSliceOfRefOfUnionSelect creates a deep clone of the input. +func CloneSliceOfRefOfUnionSelect(n []*UnionSelect) []*UnionSelect { + res := make([]*UnionSelect, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUnionSelect(x)) + } + return res +} + +// CloneSliceOfRefOfWhen creates a deep clone of the input. +func CloneSliceOfRefOfWhen(n []*When) []*When { + res := make([]*When, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfWhen(x)) + } + return res +} + +// CloneSliceOfString creates a deep clone of the input. +func CloneSliceOfString(n []string) []string { + res := make([]string, 0, len(n)) + copy(res, n) + return res +} + +// CloneSliceOfVindexParam creates a deep clone of the input. +func CloneSliceOfVindexParam(n []VindexParam) []VindexParam { + res := make([]VindexParam, 0, len(n)) + for _, x := range n { + res = append(res, CloneVindexParam(x)) + } + return res +} + +// CloneStatement creates a deep clone of the input. +func CloneStatement(in Statement) Statement { + if in == nil { + return nil + } + switch in := in.(type) { + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *AlterMigration: + return CloneRefOfAlterMigration(in) + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *AlterVschema: + return CloneRefOfAlterVschema(in) + case *Begin: + return CloneRefOfBegin(in) + case *CallProc: + return CloneRefOfCallProc(in) + case *Commit: + return CloneRefOfCommit(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *Delete: + return CloneRefOfDelete(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + case *Flush: + return CloneRefOfFlush(in) + case *Insert: + return CloneRefOfInsert(in) + case *Load: + return CloneRefOfLoad(in) + case *LockTables: + return CloneRefOfLockTables(in) + case *OtherAdmin: + return CloneRefOfOtherAdmin(in) + case *OtherRead: + return CloneRefOfOtherRead(in) + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Release: + return CloneRefOfRelease(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *RevertMigration: + return CloneRefOfRevertMigration(in) + case *Rollback: + return CloneRefOfRollback(in) + case *SRollback: + return CloneRefOfSRollback(in) + case *Savepoint: + return CloneRefOfSavepoint(in) + case *Select: + return CloneRefOfSelect(in) + case *Set: + return CloneRefOfSet(in) + case *SetTransaction: + return CloneRefOfSetTransaction(in) + case *Show: + return CloneRefOfShow(in) + case *Stream: + return CloneRefOfStream(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + case *Union: + return CloneRefOfUnion(in) + case *UnlockTables: + return CloneRefOfUnlockTables(in) + case *Update: + return CloneRefOfUpdate(in) + case *Use: + return CloneRefOfUse(in) + case *VStream: + return CloneRefOfVStream(in) + default: + // this should never happen + return nil + } +} + +// CloneTableAndLockTypes creates a deep clone of the input. +func CloneTableAndLockTypes(n TableAndLockTypes) TableAndLockTypes { + res := make(TableAndLockTypes, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfTableAndLockType(x)) + } + return res +} + +// CloneTableExpr creates a deep clone of the input. +func CloneTableExpr(in TableExpr) TableExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedTableExpr: + return CloneRefOfAliasedTableExpr(in) + case *JoinTableExpr: + return CloneRefOfJoinTableExpr(in) + case *ParenTableExpr: + return CloneRefOfParenTableExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneTableExprs creates a deep clone of the input. +func CloneTableExprs(n TableExprs) TableExprs { + res := make(TableExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneTableExpr(x)) + } + return res +} + +// CloneTableIdent creates a deep clone of the input. +func CloneTableIdent(n TableIdent) TableIdent { + return *CloneRefOfTableIdent(&n) +} + +// CloneTableName creates a deep clone of the input. +func CloneTableName(n TableName) TableName { + return *CloneRefOfTableName(&n) +} + +// CloneTableNames creates a deep clone of the input. +func CloneTableNames(n TableNames) TableNames { + res := make(TableNames, 0, len(n)) + for _, x := range n { + res = append(res, CloneTableName(x)) + } + return res +} + +// CloneTableOptions creates a deep clone of the input. +func CloneTableOptions(n TableOptions) TableOptions { + res := make(TableOptions, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfTableOption(x)) + } + return res +} + +// CloneUpdateExprs creates a deep clone of the input. +func CloneUpdateExprs(n UpdateExprs) UpdateExprs { + res := make(UpdateExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUpdateExpr(x)) + } + return res +} + +// CloneValTuple creates a deep clone of the input. +func CloneValTuple(n ValTuple) ValTuple { + res := make(ValTuple, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneValues creates a deep clone of the input. +func CloneValues(n Values) Values { + res := make(Values, 0, len(n)) + for _, x := range n { + res = append(res, CloneValTuple(x)) + } + return res +} + +// CloneVindexParam creates a deep clone of the input. +func CloneVindexParam(n VindexParam) VindexParam { + return *CloneRefOfVindexParam(&n) +} + +// EqualsAlterOption does deep equals between the two objects. +func EqualsAlterOption(inA, inB AlterOption) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AddColumns: + b, ok := inB.(*AddColumns) + if !ok { + return false + } + return EqualsRefOfAddColumns(a, b) + case *AddConstraintDefinition: + b, ok := inB.(*AddConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfAddConstraintDefinition(a, b) + case *AddIndexDefinition: + b, ok := inB.(*AddIndexDefinition) + if !ok { + return false + } + return EqualsRefOfAddIndexDefinition(a, b) + case AlgorithmValue: + b, ok := inB.(AlgorithmValue) + if !ok { + return false + } + return a == b + case *AlterCharset: + b, ok := inB.(*AlterCharset) + if !ok { + return false + } + return EqualsRefOfAlterCharset(a, b) + case *AlterColumn: + b, ok := inB.(*AlterColumn) + if !ok { + return false + } + return EqualsRefOfAlterColumn(a, b) + case *ChangeColumn: + b, ok := inB.(*ChangeColumn) + if !ok { + return false + } + return EqualsRefOfChangeColumn(a, b) + case *DropColumn: + b, ok := inB.(*DropColumn) + if !ok { + return false + } + return EqualsRefOfDropColumn(a, b) + case *DropKey: + b, ok := inB.(*DropKey) + if !ok { + return false + } + return EqualsRefOfDropKey(a, b) + case *Force: + b, ok := inB.(*Force) + if !ok { + return false + } + return EqualsRefOfForce(a, b) + case *KeyState: + b, ok := inB.(*KeyState) + if !ok { + return false + } + return EqualsRefOfKeyState(a, b) + case *LockOption: + b, ok := inB.(*LockOption) + if !ok { + return false + } + return EqualsRefOfLockOption(a, b) + case *ModifyColumn: + b, ok := inB.(*ModifyColumn) + if !ok { + return false + } + return EqualsRefOfModifyColumn(a, b) + case *OrderByOption: + b, ok := inB.(*OrderByOption) + if !ok { + return false + } + return EqualsRefOfOrderByOption(a, b) + case *RenameIndex: + b, ok := inB.(*RenameIndex) + if !ok { + return false + } + return EqualsRefOfRenameIndex(a, b) + case *RenameTableName: + b, ok := inB.(*RenameTableName) + if !ok { + return false + } + return EqualsRefOfRenameTableName(a, b) + case TableOptions: + b, ok := inB.(TableOptions) + if !ok { + return false + } + return EqualsTableOptions(a, b) + case *TablespaceOperation: + b, ok := inB.(*TablespaceOperation) + if !ok { + return false + } + return EqualsRefOfTablespaceOperation(a, b) + case *Validation: + b, ok := inB.(*Validation) + if !ok { + return false + } + return EqualsRefOfValidation(a, b) + default: + // this should never happen + return false + } +} + +// EqualsCharacteristic does deep equals between the two objects. +func EqualsCharacteristic(inA, inB Characteristic) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case AccessMode: + b, ok := inB.(AccessMode) + if !ok { + return false + } + return a == b + case IsolationLevel: + b, ok := inB.(IsolationLevel) + if !ok { + return false + } + return a == b + default: + // this should never happen + return false + } +} + +// EqualsColIdent does deep equals between the two objects. +func EqualsColIdent(a, b ColIdent) bool { + return a.val == b.val && + a.lowered == b.lowered && + a.at == b.at +} + +// EqualsColTuple does deep equals between the two objects. +func EqualsColTuple(inA, inB ColTuple) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case ListArg: + b, ok := inB.(ListArg) + if !ok { + return false + } + return EqualsListArg(a, b) + case *Subquery: + b, ok := inB.(*Subquery) + if !ok { + return false + } + return EqualsRefOfSubquery(a, b) + case ValTuple: + b, ok := inB.(ValTuple) + if !ok { + return false + } + return EqualsValTuple(a, b) + default: + // this should never happen + return false + } +} + +// EqualsCollateAndCharset does deep equals between the two objects. +func EqualsCollateAndCharset(a, b CollateAndCharset) bool { + return a.IsDefault == b.IsDefault && + a.Value == b.Value && + a.Type == b.Type +} + +// EqualsColumnType does deep equals between the two objects. +func EqualsColumnType(a, b ColumnType) bool { + return a.Type == b.Type && + a.Unsigned == b.Unsigned && + a.Zerofill == b.Zerofill && + a.Charset == b.Charset && + a.Collate == b.Collate && + EqualsRefOfColumnTypeOptions(a.Options, b.Options) && + EqualsRefOfLiteral(a.Length, b.Length) && + EqualsRefOfLiteral(a.Scale, b.Scale) && + EqualsSliceOfString(a.EnumValues, b.EnumValues) +} + +// EqualsColumns does deep equals between the two objects. +func EqualsColumns(a, b Columns) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsColIdent(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsComments does deep equals between the two objects. +func EqualsComments(a, b Comments) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +// EqualsConstraintInfo does deep equals between the two objects. +func EqualsConstraintInfo(inA, inB ConstraintInfo) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *CheckConstraintDefinition: + b, ok := inB.(*CheckConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfCheckConstraintDefinition(a, b) + case *ForeignKeyDefinition: + b, ok := inB.(*ForeignKeyDefinition) + if !ok { + return false + } + return EqualsRefOfForeignKeyDefinition(a, b) + default: + // this should never happen + return false + } +} + +// EqualsDBDDLStatement does deep equals between the two objects. +func EqualsDBDDLStatement(inA, inB DBDDLStatement) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AlterDatabase: + b, ok := inB.(*AlterDatabase) + if !ok { + return false + } + return EqualsRefOfAlterDatabase(a, b) + case *CreateDatabase: + b, ok := inB.(*CreateDatabase) + if !ok { + return false + } + return EqualsRefOfCreateDatabase(a, b) + case *DropDatabase: + b, ok := inB.(*DropDatabase) + if !ok { + return false + } + return EqualsRefOfDropDatabase(a, b) + default: + // this should never happen + return false + } +} + +// EqualsDDLStatement does deep equals between the two objects. +func EqualsDDLStatement(inA, inB DDLStatement) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AlterTable: + b, ok := inB.(*AlterTable) + if !ok { + return false + } + return EqualsRefOfAlterTable(a, b) + case *AlterView: + b, ok := inB.(*AlterView) + if !ok { + return false + } + return EqualsRefOfAlterView(a, b) + case *CreateTable: + b, ok := inB.(*CreateTable) + if !ok { + return false + } + return EqualsRefOfCreateTable(a, b) + case *CreateView: + b, ok := inB.(*CreateView) + if !ok { + return false + } + return EqualsRefOfCreateView(a, b) + case *DropTable: + b, ok := inB.(*DropTable) + if !ok { + return false + } + return EqualsRefOfDropTable(a, b) + case *DropView: + b, ok := inB.(*DropView) + if !ok { + return false + } + return EqualsRefOfDropView(a, b) + case *RenameTable: + b, ok := inB.(*RenameTable) + if !ok { + return false + } + return EqualsRefOfRenameTable(a, b) + case *TruncateTable: + b, ok := inB.(*TruncateTable) + if !ok { + return false + } + return EqualsRefOfTruncateTable(a, b) + default: + // this should never happen + return false + } +} + +// EqualsExplain does deep equals between the two objects. +func EqualsExplain(inA, inB Explain) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *ExplainStmt: + b, ok := inB.(*ExplainStmt) + if !ok { + return false + } + return EqualsRefOfExplainStmt(a, b) + case *ExplainTab: + b, ok := inB.(*ExplainTab) + if !ok { + return false + } + return EqualsRefOfExplainTab(a, b) + default: + // this should never happen + return false + } +} + +// EqualsExpr does deep equals between the two objects. +func EqualsExpr(inA, inB Expr) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AndExpr: + b, ok := inB.(*AndExpr) + if !ok { + return false + } + return EqualsRefOfAndExpr(a, b) + case Argument: + b, ok := inB.(Argument) + if !ok { + return false + } + return a == b + case *BinaryExpr: + b, ok := inB.(*BinaryExpr) + if !ok { + return false + } + return EqualsRefOfBinaryExpr(a, b) + case BoolVal: + b, ok := inB.(BoolVal) + if !ok { + return false + } + return a == b + case *CaseExpr: + b, ok := inB.(*CaseExpr) + if !ok { + return false + } + return EqualsRefOfCaseExpr(a, b) + case *ColName: + b, ok := inB.(*ColName) + if !ok { + return false + } + return EqualsRefOfColName(a, b) + case *CollateExpr: + b, ok := inB.(*CollateExpr) + if !ok { + return false + } + return EqualsRefOfCollateExpr(a, b) + case *ComparisonExpr: + b, ok := inB.(*ComparisonExpr) + if !ok { + return false + } + return EqualsRefOfComparisonExpr(a, b) + case *ConvertExpr: + b, ok := inB.(*ConvertExpr) + if !ok { + return false + } + return EqualsRefOfConvertExpr(a, b) + case *ConvertUsingExpr: + b, ok := inB.(*ConvertUsingExpr) + if !ok { + return false + } + return EqualsRefOfConvertUsingExpr(a, b) + case *CurTimeFuncExpr: + b, ok := inB.(*CurTimeFuncExpr) + if !ok { + return false + } + return EqualsRefOfCurTimeFuncExpr(a, b) + case *Default: + b, ok := inB.(*Default) + if !ok { + return false + } + return EqualsRefOfDefault(a, b) + case *ExistsExpr: + b, ok := inB.(*ExistsExpr) + if !ok { + return false + } + return EqualsRefOfExistsExpr(a, b) + case *FuncExpr: + b, ok := inB.(*FuncExpr) + if !ok { + return false + } + return EqualsRefOfFuncExpr(a, b) + case *GroupConcatExpr: + b, ok := inB.(*GroupConcatExpr) + if !ok { + return false + } + return EqualsRefOfGroupConcatExpr(a, b) + case *IntervalExpr: + b, ok := inB.(*IntervalExpr) + if !ok { + return false + } + return EqualsRefOfIntervalExpr(a, b) + case *IsExpr: + b, ok := inB.(*IsExpr) + if !ok { + return false + } + return EqualsRefOfIsExpr(a, b) + case ListArg: + b, ok := inB.(ListArg) + if !ok { + return false + } + return EqualsListArg(a, b) + case *Literal: + b, ok := inB.(*Literal) + if !ok { + return false + } + return EqualsRefOfLiteral(a, b) + case *MatchExpr: + b, ok := inB.(*MatchExpr) + if !ok { + return false + } + return EqualsRefOfMatchExpr(a, b) + case *NotExpr: + b, ok := inB.(*NotExpr) + if !ok { + return false + } + return EqualsRefOfNotExpr(a, b) + case *NullVal: + b, ok := inB.(*NullVal) + if !ok { + return false + } + return EqualsRefOfNullVal(a, b) + case *OrExpr: + b, ok := inB.(*OrExpr) + if !ok { + return false + } + return EqualsRefOfOrExpr(a, b) + case *RangeCond: + b, ok := inB.(*RangeCond) + if !ok { + return false + } + return EqualsRefOfRangeCond(a, b) + case *Subquery: + b, ok := inB.(*Subquery) + if !ok { + return false + } + return EqualsRefOfSubquery(a, b) + case *SubstrExpr: + b, ok := inB.(*SubstrExpr) + if !ok { + return false + } + return EqualsRefOfSubstrExpr(a, b) + case *TimestampFuncExpr: + b, ok := inB.(*TimestampFuncExpr) + if !ok { + return false + } + return EqualsRefOfTimestampFuncExpr(a, b) + case *UnaryExpr: + b, ok := inB.(*UnaryExpr) + if !ok { + return false + } + return EqualsRefOfUnaryExpr(a, b) + case ValTuple: + b, ok := inB.(ValTuple) + if !ok { + return false + } + return EqualsValTuple(a, b) + case *ValuesFuncExpr: + b, ok := inB.(*ValuesFuncExpr) + if !ok { + return false + } + return EqualsRefOfValuesFuncExpr(a, b) + case *XorExpr: + b, ok := inB.(*XorExpr) + if !ok { + return false + } + return EqualsRefOfXorExpr(a, b) + default: + // this should never happen + return false + } +} + +// EqualsExprs does deep equals between the two objects. +func EqualsExprs(a, b Exprs) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsGroupBy does deep equals between the two objects. +func EqualsGroupBy(a, b GroupBy) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsInsertRows does deep equals between the two objects. +func EqualsInsertRows(inA, inB InsertRows) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + case Values: + b, ok := inB.(Values) + if !ok { + return false + } + return EqualsValues(a, b) + default: + // this should never happen + return false + } +} + +// EqualsJoinCondition does deep equals between the two objects. +func EqualsJoinCondition(a, b JoinCondition) bool { + return EqualsExpr(a.On, b.On) && + EqualsColumns(a.Using, b.Using) +} + +// EqualsListArg does deep equals between the two objects. +func EqualsListArg(a, b ListArg) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +// EqualsOnDup does deep equals between the two objects. +func EqualsOnDup(a, b OnDup) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfUpdateExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsOrderBy does deep equals between the two objects. +func EqualsOrderBy(a, b OrderBy) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfOrder(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsPartitions does deep equals between the two objects. +func EqualsPartitions(a, b Partitions) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsColIdent(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsRefOfAddColumns does deep equals between the two objects. +func EqualsRefOfAddColumns(a, b *AddColumns) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && + EqualsRefOfColName(a.First, b.First) && + EqualsRefOfColName(a.After, b.After) +} + +// EqualsRefOfAddConstraintDefinition does deep equals between the two objects. +func EqualsRefOfAddConstraintDefinition(a, b *AddConstraintDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfConstraintDefinition(a.ConstraintDefinition, b.ConstraintDefinition) +} + +// EqualsRefOfAddIndexDefinition does deep equals between the two objects. +func EqualsRefOfAddIndexDefinition(a, b *AddIndexDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfIndexDefinition(a.IndexDefinition, b.IndexDefinition) +} + +// EqualsRefOfAliasedExpr does deep equals between the two objects. +func EqualsRefOfAliasedExpr(a, b *AliasedExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + EqualsColIdent(a.As, b.As) +} + +// EqualsRefOfAliasedTableExpr does deep equals between the two objects. +func EqualsRefOfAliasedTableExpr(a, b *AliasedTableExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSimpleTableExpr(a.Expr, b.Expr) && + EqualsPartitions(a.Partitions, b.Partitions) && + EqualsTableIdent(a.As, b.As) && + EqualsRefOfIndexHints(a.Hints, b.Hints) +} + +// EqualsRefOfAlterCharset does deep equals between the two objects. +func EqualsRefOfAlterCharset(a, b *AlterCharset) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.CharacterSet == b.CharacterSet && + a.Collate == b.Collate +} + +// EqualsRefOfAlterColumn does deep equals between the two objects. +func EqualsRefOfAlterColumn(a, b *AlterColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DropDefault == b.DropDefault && + EqualsRefOfColName(a.Column, b.Column) && + EqualsExpr(a.DefaultVal, b.DefaultVal) +} + +// EqualsRefOfAlterDatabase does deep equals between the two objects. +func EqualsRefOfAlterDatabase(a, b *AlterDatabase) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DBName == b.DBName && + a.UpdateDataDirectory == b.UpdateDataDirectory && + a.FullyParsed == b.FullyParsed && + EqualsSliceOfCollateAndCharset(a.AlterOptions, b.AlterOptions) +} + +// EqualsRefOfAlterMigration does deep equals between the two objects. +func EqualsRefOfAlterMigration(a, b *AlterMigration) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.UUID == b.UUID && + a.Type == b.Type +} + +// EqualsRefOfAlterTable does deep equals between the two objects. +func EqualsRefOfAlterTable(a, b *AlterTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.FullyParsed == b.FullyParsed && + EqualsTableName(a.Table, b.Table) && + EqualsSliceOfAlterOption(a.AlterOptions, b.AlterOptions) && + EqualsRefOfPartitionSpec(a.PartitionSpec, b.PartitionSpec) +} + +// EqualsRefOfAlterView does deep equals between the two objects. +func EqualsRefOfAlterView(a, b *AlterView) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Algorithm == b.Algorithm && + a.Definer == b.Definer && + a.Security == b.Security && + a.CheckOption == b.CheckOption && + EqualsTableName(a.ViewName, b.ViewName) && + EqualsColumns(a.Columns, b.Columns) && + EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfAlterVschema does deep equals between the two objects. +func EqualsRefOfAlterVschema(a, b *AlterVschema) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Action == b.Action && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfVindexSpec(a.VindexSpec, b.VindexSpec) && + EqualsSliceOfColIdent(a.VindexCols, b.VindexCols) && + EqualsRefOfAutoIncSpec(a.AutoIncSpec, b.AutoIncSpec) +} + +// EqualsRefOfAndExpr does deep equals between the two objects. +func EqualsRefOfAndExpr(a, b *AndExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsRefOfAutoIncSpec does deep equals between the two objects. +func EqualsRefOfAutoIncSpec(a, b *AutoIncSpec) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Column, b.Column) && + EqualsTableName(a.Sequence, b.Sequence) +} + +// EqualsRefOfBegin does deep equals between the two objects. +func EqualsRefOfBegin(a, b *Begin) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfBinaryExpr does deep equals between the two objects. +func EqualsRefOfBinaryExpr(a, b *BinaryExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsRefOfBool does deep equals between the two objects. +func EqualsRefOfBool(a, b *bool) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +// EqualsRefOfCallProc does deep equals between the two objects. +func EqualsRefOfCallProc(a, b *CallProc) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.Name, b.Name) && + EqualsExprs(a.Params, b.Params) +} + +// EqualsRefOfCaseExpr does deep equals between the two objects. +func EqualsRefOfCaseExpr(a, b *CaseExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + EqualsSliceOfRefOfWhen(a.Whens, b.Whens) && + EqualsExpr(a.Else, b.Else) +} + +// EqualsRefOfChangeColumn does deep equals between the two objects. +func EqualsRefOfChangeColumn(a, b *ChangeColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.OldColumn, b.OldColumn) && + EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && + EqualsRefOfColName(a.First, b.First) && + EqualsRefOfColName(a.After, b.After) +} + +// EqualsRefOfCheckConstraintDefinition does deep equals between the two objects. +func EqualsRefOfCheckConstraintDefinition(a, b *CheckConstraintDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Enforced == b.Enforced && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfColIdent does deep equals between the two objects. +func EqualsRefOfColIdent(a, b *ColIdent) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.val == b.val && + a.lowered == b.lowered && + a.at == b.at +} + +// EqualsRefOfColName does deep equals between the two objects. +func EqualsRefOfColName(a, b *ColName) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsTableName(a.Qualifier, b.Qualifier) +} + +// EqualsRefOfCollateAndCharset does deep equals between the two objects. +func EqualsRefOfCollateAndCharset(a, b *CollateAndCharset) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IsDefault == b.IsDefault && + a.Value == b.Value && + a.Type == b.Type +} + +// EqualsRefOfCollateExpr does deep equals between the two objects. +func EqualsRefOfCollateExpr(a, b *CollateExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Charset == b.Charset && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfColumnDefinition does deep equals between the two objects. +func EqualsRefOfColumnDefinition(a, b *ColumnDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsColumnType(a.Type, b.Type) +} + +// EqualsRefOfColumnType does deep equals between the two objects. +func EqualsRefOfColumnType(a, b *ColumnType) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + a.Unsigned == b.Unsigned && + a.Zerofill == b.Zerofill && + a.Charset == b.Charset && + a.Collate == b.Collate && + EqualsRefOfColumnTypeOptions(a.Options, b.Options) && + EqualsRefOfLiteral(a.Length, b.Length) && + EqualsRefOfLiteral(a.Scale, b.Scale) && + EqualsSliceOfString(a.EnumValues, b.EnumValues) +} + +// EqualsRefOfColumnTypeOptions does deep equals between the two objects. +func EqualsRefOfColumnTypeOptions(a, b *ColumnTypeOptions) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.NotNull == b.NotNull && + a.Autoincrement == b.Autoincrement && + EqualsExpr(a.Default, b.Default) && + EqualsExpr(a.OnUpdate, b.OnUpdate) && + EqualsRefOfLiteral(a.Comment, b.Comment) && + a.KeyOpt == b.KeyOpt +} + +// EqualsRefOfCommit does deep equals between the two objects. +func EqualsRefOfCommit(a, b *Commit) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfComparisonExpr does deep equals between the two objects. +func EqualsRefOfComparisonExpr(a, b *ComparisonExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) && + EqualsExpr(a.Escape, b.Escape) +} + +// EqualsRefOfConstraintDefinition does deep equals between the two objects. +func EqualsRefOfConstraintDefinition(a, b *ConstraintDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + EqualsConstraintInfo(a.Details, b.Details) +} + +// EqualsRefOfConvertExpr does deep equals between the two objects. +func EqualsRefOfConvertExpr(a, b *ConvertExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + EqualsRefOfConvertType(a.Type, b.Type) +} + +// EqualsRefOfConvertType does deep equals between the two objects. +func EqualsRefOfConvertType(a, b *ConvertType) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + a.Charset == b.Charset && + EqualsRefOfLiteral(a.Length, b.Length) && + EqualsRefOfLiteral(a.Scale, b.Scale) && + a.Operator == b.Operator +} + +// EqualsRefOfConvertUsingExpr does deep equals between the two objects. +func EqualsRefOfConvertUsingExpr(a, b *ConvertUsingExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfCreateDatabase does deep equals between the two objects. +func EqualsRefOfCreateDatabase(a, b *CreateDatabase) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DBName == b.DBName && + a.IfNotExists == b.IfNotExists && + a.FullyParsed == b.FullyParsed && + EqualsComments(a.Comments, b.Comments) && + EqualsSliceOfCollateAndCharset(a.CreateOptions, b.CreateOptions) +} + +// EqualsRefOfCreateTable does deep equals between the two objects. +func EqualsRefOfCreateTable(a, b *CreateTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Temp == b.Temp && + a.IfNotExists == b.IfNotExists && + a.FullyParsed == b.FullyParsed && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfTableSpec(a.TableSpec, b.TableSpec) && + EqualsRefOfOptLike(a.OptLike, b.OptLike) +} + +// EqualsRefOfCreateView does deep equals between the two objects. +func EqualsRefOfCreateView(a, b *CreateView) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Algorithm == b.Algorithm && + a.Definer == b.Definer && + a.Security == b.Security && + a.CheckOption == b.CheckOption && + a.IsReplace == b.IsReplace && + EqualsTableName(a.ViewName, b.ViewName) && + EqualsColumns(a.Columns, b.Columns) && + EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfCurTimeFuncExpr does deep equals between the two objects. +func EqualsRefOfCurTimeFuncExpr(a, b *CurTimeFuncExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsExpr(a.Fsp, b.Fsp) +} + +// EqualsRefOfDefault does deep equals between the two objects. +func EqualsRefOfDefault(a, b *Default) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.ColName == b.ColName +} + +// EqualsRefOfDelete does deep equals between the two objects. +func EqualsRefOfDelete(a, b *Delete) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Ignore == b.Ignore && + EqualsComments(a.Comments, b.Comments) && + EqualsTableNames(a.Targets, b.Targets) && + EqualsTableExprs(a.TableExprs, b.TableExprs) && + EqualsPartitions(a.Partitions, b.Partitions) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) +} + +// EqualsRefOfDerivedTable does deep equals between the two objects. +func EqualsRefOfDerivedTable(a, b *DerivedTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfDropColumn does deep equals between the two objects. +func EqualsRefOfDropColumn(a, b *DropColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.Name, b.Name) +} + +// EqualsRefOfDropDatabase does deep equals between the two objects. +func EqualsRefOfDropDatabase(a, b *DropDatabase) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DBName == b.DBName && + a.IfExists == b.IfExists && + EqualsComments(a.Comments, b.Comments) +} + +// EqualsRefOfDropKey does deep equals between the two objects. +func EqualsRefOfDropKey(a, b *DropKey) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + a.Type == b.Type +} + +// EqualsRefOfDropTable does deep equals between the two objects. +func EqualsRefOfDropTable(a, b *DropTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Temp == b.Temp && + a.IfExists == b.IfExists && + EqualsTableNames(a.FromTables, b.FromTables) +} + +// EqualsRefOfDropView does deep equals between the two objects. +func EqualsRefOfDropView(a, b *DropView) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IfExists == b.IfExists && + EqualsTableNames(a.FromTables, b.FromTables) +} + +// EqualsRefOfExistsExpr does deep equals between the two objects. +func EqualsRefOfExistsExpr(a, b *ExistsExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfSubquery(a.Subquery, b.Subquery) +} + +// EqualsRefOfExplainStmt does deep equals between the two objects. +func EqualsRefOfExplainStmt(a, b *ExplainStmt) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + EqualsStatement(a.Statement, b.Statement) +} + +// EqualsRefOfExplainTab does deep equals between the two objects. +func EqualsRefOfExplainTab(a, b *ExplainTab) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Wild == b.Wild && + EqualsTableName(a.Table, b.Table) +} + +// EqualsRefOfFlush does deep equals between the two objects. +func EqualsRefOfFlush(a, b *Flush) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IsLocal == b.IsLocal && + a.WithLock == b.WithLock && + a.ForExport == b.ForExport && + EqualsSliceOfString(a.FlushOptions, b.FlushOptions) && + EqualsTableNames(a.TableNames, b.TableNames) +} + +// EqualsRefOfForce does deep equals between the two objects. +func EqualsRefOfForce(a, b *Force) bool { + if a == b { + return true } - if errF := a.rewriteRefOfConvertType(node, node.Type, func(newNode, parent SQLNode) { - parent.(*ConvertExpr).Type = newNode.(*ConvertType) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + return true +} + +// EqualsRefOfForeignKeyDefinition does deep equals between the two objects. +func EqualsRefOfForeignKeyDefinition(a, b *ForeignKeyDefinition) bool { + if a == b { + return true } - return nil + if a == nil || b == nil { + return false + } + return EqualsColumns(a.Source, b.Source) && + EqualsTableName(a.ReferencedTable, b.ReferencedTable) && + EqualsColumns(a.ReferencedColumns, b.ReferencedColumns) && + a.OnDelete == b.OnDelete && + a.OnUpdate == b.OnUpdate } -// EqualsRefOfConvertType does deep equals between the two objects. -func EqualsRefOfConvertType(a, b *ConvertType) bool { +// EqualsRefOfFuncExpr does deep equals between the two objects. +func EqualsRefOfFuncExpr(a, b *FuncExpr) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Type == b.Type && - a.Charset == b.Charset && - EqualsRefOfLiteral(a.Length, b.Length) && - EqualsRefOfLiteral(a.Scale, b.Scale) && - a.Operator == b.Operator + return a.Distinct == b.Distinct && + EqualsTableIdent(a.Qualifier, b.Qualifier) && + EqualsColIdent(a.Name, b.Name) && + EqualsSelectExprs(a.Exprs, b.Exprs) } -// CloneRefOfConvertType creates a deep clone of the input. -func CloneRefOfConvertType(n *ConvertType) *ConvertType { - if n == nil { - return nil +// EqualsRefOfGroupConcatExpr does deep equals between the two objects. +func EqualsRefOfGroupConcatExpr(a, b *GroupConcatExpr) bool { + if a == b { + return true } - out := *n - out.Length = CloneRefOfLiteral(n.Length) - out.Scale = CloneRefOfLiteral(n.Scale) - return &out + if a == nil || b == nil { + return false + } + return a.Distinct == b.Distinct && + a.Separator == b.Separator && + EqualsSelectExprs(a.Exprs, b.Exprs) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) } -// VisitRefOfConvertType will visit all parts of the AST -func VisitRefOfConvertType(in *ConvertType, f Visit) error { - if in == nil { - return nil +// EqualsRefOfIndexColumn does deep equals between the two objects. +func EqualsRefOfIndexColumn(a, b *IndexColumn) bool { + if a == b { + return true } - if cont, err := f(in); err != nil || !cont { - return err + if a == nil || b == nil { + return false } - if err := VisitRefOfLiteral(in.Length, f); err != nil { - return err + return EqualsColIdent(a.Column, b.Column) && + EqualsRefOfLiteral(a.Length, b.Length) && + a.Direction == b.Direction +} + +// EqualsRefOfIndexDefinition does deep equals between the two objects. +func EqualsRefOfIndexDefinition(a, b *IndexDefinition) bool { + if a == b { + return true } - if err := VisitRefOfLiteral(in.Scale, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsRefOfIndexInfo(a.Info, b.Info) && + EqualsSliceOfRefOfIndexColumn(a.Columns, b.Columns) && + EqualsSliceOfRefOfIndexOption(a.Options, b.Options) } -// rewriteRefOfConvertType is part of the Rewrite implementation -func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsRefOfIndexHints does deep equals between the two objects. +func EqualsRefOfIndexHints(a, b *IndexHints) bool { + if a == b { + return true } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if a == nil || b == nil { + return false } - if a.pre != nil && !a.pre(&cur) { - return nil + return a.Type == b.Type && + EqualsSliceOfColIdent(a.Indexes, b.Indexes) +} + +// EqualsRefOfIndexInfo does deep equals between the two objects. +func EqualsRefOfIndexInfo(a, b *IndexInfo) bool { + if a == b { + return true } - if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { - parent.(*ConvertType).Length = newNode.(*Literal) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { - parent.(*ConvertType).Scale = newNode.(*Literal) - }); errF != nil { - return errF + return a.Type == b.Type && + a.Primary == b.Primary && + a.Spatial == b.Spatial && + a.Fulltext == b.Fulltext && + a.Unique == b.Unique && + EqualsColIdent(a.Name, b.Name) && + EqualsColIdent(a.ConstraintName, b.ConstraintName) +} + +// EqualsRefOfIndexOption does deep equals between the two objects. +func EqualsRefOfIndexOption(a, b *IndexOption) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.Name == b.Name && + a.String == b.String && + EqualsRefOfLiteral(a.Value, b.Value) } -// EqualsRefOfConvertUsingExpr does deep equals between the two objects. -func EqualsRefOfConvertUsingExpr(a, b *ConvertUsingExpr) bool { +// EqualsRefOfInsert does deep equals between the two objects. +func EqualsRefOfInsert(a, b *Insert) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Type == b.Type && + return a.Action == b.Action && + EqualsComments(a.Comments, b.Comments) && + a.Ignore == b.Ignore && + EqualsTableName(a.Table, b.Table) && + EqualsPartitions(a.Partitions, b.Partitions) && + EqualsColumns(a.Columns, b.Columns) && + EqualsInsertRows(a.Rows, b.Rows) && + EqualsOnDup(a.OnDup, b.OnDup) +} + +// EqualsRefOfIntervalExpr does deep equals between the two objects. +func EqualsRefOfIntervalExpr(a, b *IntervalExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Unit == b.Unit && EqualsExpr(a.Expr, b.Expr) } -// CloneRefOfConvertUsingExpr creates a deep clone of the input. -func CloneRefOfConvertUsingExpr(n *ConvertUsingExpr) *ConvertUsingExpr { - if n == nil { - return nil +// EqualsRefOfIsExpr does deep equals between the two objects. +func EqualsRefOfIsExpr(a, b *IsExpr) bool { + if a == b { + return true } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Expr, b.Expr) } -// VisitRefOfConvertUsingExpr will visit all parts of the AST -func VisitRefOfConvertUsingExpr(in *ConvertUsingExpr, f Visit) error { - if in == nil { - return nil +// EqualsRefOfJoinCondition does deep equals between the two objects. +func EqualsRefOfJoinCondition(a, b *JoinCondition) bool { + if a == b { + return true } - if cont, err := f(in); err != nil || !cont { - return err + if a == nil || b == nil { + return false } - if err := VisitExpr(in.Expr, f); err != nil { - return err + return EqualsExpr(a.On, b.On) && + EqualsColumns(a.Using, b.Using) +} + +// EqualsRefOfJoinTableExpr does deep equals between the two objects. +func EqualsRefOfJoinTableExpr(a, b *JoinTableExpr) bool { + if a == b { + return true } - return nil + if a == nil || b == nil { + return false + } + return EqualsTableExpr(a.LeftExpr, b.LeftExpr) && + a.Join == b.Join && + EqualsTableExpr(a.RightExpr, b.RightExpr) && + EqualsJoinCondition(a.Condition, b.Condition) } -// rewriteRefOfConvertUsingExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsRefOfKeyState does deep equals between the two objects. +func EqualsRefOfKeyState(a, b *KeyState) bool { + if a == b { + return true } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if a == nil || b == nil { + return false } - if a.pre != nil && !a.pre(&cur) { - return nil + return a.Enable == b.Enable +} + +// EqualsRefOfLimit does deep equals between the two objects. +func EqualsRefOfLimit(a, b *Limit) bool { + if a == b { + return true } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*ConvertUsingExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + return EqualsExpr(a.Offset, b.Offset) && + EqualsExpr(a.Rowcount, b.Rowcount) +} + +// EqualsRefOfLiteral does deep equals between the two objects. +func EqualsRefOfLiteral(a, b *Literal) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false } - return nil + return a.Val == b.Val && + a.Type == b.Type } -// EqualsRefOfCreateDatabase does deep equals between the two objects. -func EqualsRefOfCreateDatabase(a, b *CreateDatabase) bool { +// EqualsRefOfLoad does deep equals between the two objects. +func EqualsRefOfLoad(a, b *Load) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.DBName == b.DBName && - a.IfNotExists == b.IfNotExists && - a.FullyParsed == b.FullyParsed && - EqualsComments(a.Comments, b.Comments) && - EqualsSliceOfCollateAndCharset(a.CreateOptions, b.CreateOptions) + return true } -// CloneRefOfCreateDatabase creates a deep clone of the input. -func CloneRefOfCreateDatabase(n *CreateDatabase) *CreateDatabase { - if n == nil { - return nil +// EqualsRefOfLockOption does deep equals between the two objects. +func EqualsRefOfLockOption(a, b *LockOption) bool { + if a == b { + return true } - out := *n - out.Comments = CloneComments(n.Comments) - out.CreateOptions = CloneSliceOfCollateAndCharset(n.CreateOptions) - return &out + if a == nil || b == nil { + return false + } + return a.Type == b.Type } -// VisitRefOfCreateDatabase will visit all parts of the AST -func VisitRefOfCreateDatabase(in *CreateDatabase, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfLockTables does deep equals between the two objects. +func EqualsRefOfLockTables(a, b *LockTables) bool { + if a == b { + return true } - if err := VisitComments(in.Comments, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsTableAndLockTypes(a.Tables, b.Tables) } -// rewriteRefOfCreateDatabase is part of the Rewrite implementation -func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*CreateDatabase).Comments = newNode.(Comments) - }); errF != nil { - return errF +// EqualsRefOfMatchExpr does deep equals between the two objects. +func EqualsRefOfMatchExpr(a, b *MatchExpr) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsSelectExprs(a.Columns, b.Columns) && + EqualsExpr(a.Expr, b.Expr) && + a.Option == b.Option } -// EqualsRefOfCreateTable does deep equals between the two objects. -func EqualsRefOfCreateTable(a, b *CreateTable) bool { +// EqualsRefOfModifyColumn does deep equals between the two objects. +func EqualsRefOfModifyColumn(a, b *ModifyColumn) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Temp == b.Temp && - a.IfNotExists == b.IfNotExists && - a.FullyParsed == b.FullyParsed && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfTableSpec(a.TableSpec, b.TableSpec) && - EqualsRefOfOptLike(a.OptLike, b.OptLike) + return EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && + EqualsRefOfColName(a.First, b.First) && + EqualsRefOfColName(a.After, b.After) } -// CloneRefOfCreateTable creates a deep clone of the input. -func CloneRefOfCreateTable(n *CreateTable) *CreateTable { - if n == nil { - return nil +// EqualsRefOfNextval does deep equals between the two objects. +func EqualsRefOfNextval(a, b *Nextval) bool { + if a == b { + return true } - out := *n - out.Table = CloneTableName(n.Table) - out.TableSpec = CloneRefOfTableSpec(n.TableSpec) - out.OptLike = CloneRefOfOptLike(n.OptLike) - return &out + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) } -// VisitRefOfCreateTable will visit all parts of the AST -func VisitRefOfCreateTable(in *CreateTable, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - if err := VisitRefOfTableSpec(in.TableSpec, f); err != nil { - return err +// EqualsRefOfNotExpr does deep equals between the two objects. +func EqualsRefOfNotExpr(a, b *NotExpr) bool { + if a == b { + return true } - if err := VisitRefOfOptLike(in.OptLike, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsExpr(a.Expr, b.Expr) } -// rewriteRefOfCreateTable is part of the Rewrite implementation -func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*CreateTable).Table = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfTableSpec(node, node.TableSpec, func(newNode, parent SQLNode) { - parent.(*CreateTable).TableSpec = newNode.(*TableSpec) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfOptLike(node, node.OptLike, func(newNode, parent SQLNode) { - parent.(*CreateTable).OptLike = newNode.(*OptLike) - }); errF != nil { - return errF +// EqualsRefOfNullVal does deep equals between the two objects. +func EqualsRefOfNullVal(a, b *NullVal) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return true } -// EqualsRefOfCreateView does deep equals between the two objects. -func EqualsRefOfCreateView(a, b *CreateView) bool { +// EqualsRefOfOptLike does deep equals between the two objects. +func EqualsRefOfOptLike(a, b *OptLike) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Algorithm == b.Algorithm && - a.Definer == b.Definer && - a.Security == b.Security && - a.CheckOption == b.CheckOption && - a.IsReplace == b.IsReplace && - EqualsTableName(a.ViewName, b.ViewName) && - EqualsColumns(a.Columns, b.Columns) && - EqualsSelectStatement(a.Select, b.Select) + return EqualsTableName(a.LikeTable, b.LikeTable) } -// CloneRefOfCreateView creates a deep clone of the input. -func CloneRefOfCreateView(n *CreateView) *CreateView { - if n == nil { - return nil +// EqualsRefOfOrExpr does deep equals between the two objects. +func EqualsRefOfOrExpr(a, b *OrExpr) bool { + if a == b { + return true } - out := *n - out.ViewName = CloneTableName(n.ViewName) - out.Columns = CloneColumns(n.Columns) - out.Select = CloneSelectStatement(n.Select) - return &out + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) } -// VisitRefOfCreateView will visit all parts of the AST -func VisitRefOfCreateView(in *CreateView, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.ViewName, f); err != nil { - return err - } - if err := VisitColumns(in.Columns, f); err != nil { - return err +// EqualsRefOfOrder does deep equals between the two objects. +func EqualsRefOfOrder(a, b *Order) bool { + if a == b { + return true } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsExpr(a.Expr, b.Expr) && + a.Direction == b.Direction } -// rewriteRefOfCreateView is part of the Rewrite implementation -func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { - parent.(*CreateView).ViewName = newNode.(TableName) - }); errF != nil { - return errF +// EqualsRefOfOrderByOption does deep equals between the two objects. +func EqualsRefOfOrderByOption(a, b *OrderByOption) bool { + if a == b { + return true } - if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { - parent.(*CreateView).Columns = newNode.(Columns) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { - parent.(*CreateView).Select = newNode.(SelectStatement) - }); errF != nil { - return errF + return EqualsColumns(a.Cols, b.Cols) +} + +// EqualsRefOfOtherAdmin does deep equals between the two objects. +func EqualsRefOfOtherAdmin(a, b *OtherAdmin) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return true } -// EqualsRefOfCurTimeFuncExpr does deep equals between the two objects. -func EqualsRefOfCurTimeFuncExpr(a, b *CurTimeFuncExpr) bool { +// EqualsRefOfOtherRead does deep equals between the two objects. +func EqualsRefOfOtherRead(a, b *OtherRead) bool { if a == b { return true } if a == nil || b == nil { return false } - return EqualsColIdent(a.Name, b.Name) && - EqualsExpr(a.Fsp, b.Fsp) + return true } -// CloneRefOfCurTimeFuncExpr creates a deep clone of the input. -func CloneRefOfCurTimeFuncExpr(n *CurTimeFuncExpr) *CurTimeFuncExpr { - if n == nil { - return nil +// EqualsRefOfParenSelect does deep equals between the two objects. +func EqualsRefOfParenSelect(a, b *ParenSelect) bool { + if a == b { + return true } - out := *n - out.Name = CloneColIdent(n.Name) - out.Fsp = CloneExpr(n.Fsp) - return &out + if a == nil || b == nil { + return false + } + return EqualsSelectStatement(a.Select, b.Select) } -// VisitRefOfCurTimeFuncExpr will visit all parts of the AST -func VisitRefOfCurTimeFuncExpr(in *CurTimeFuncExpr, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err +// EqualsRefOfParenTableExpr does deep equals between the two objects. +func EqualsRefOfParenTableExpr(a, b *ParenTableExpr) bool { + if a == b { + return true } - if err := VisitExpr(in.Fsp, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsTableExprs(a.Exprs, b.Exprs) } -// rewriteRefOfCurTimeFuncExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsRefOfPartitionDefinition does deep equals between the two objects. +func EqualsRefOfPartitionDefinition(a, b *PartitionDefinition) bool { + if a == b { + return true } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if errF := a.rewriteExpr(node, node.Fsp, func(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) - }); errF != nil { - return errF + return a.Maxvalue == b.Maxvalue && + EqualsColIdent(a.Name, b.Name) && + EqualsExpr(a.Limit, b.Limit) +} + +// EqualsRefOfPartitionSpec does deep equals between the two objects. +func EqualsRefOfPartitionSpec(a, b *PartitionSpec) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.IsAll == b.IsAll && + a.WithoutValidation == b.WithoutValidation && + a.Action == b.Action && + EqualsPartitions(a.Names, b.Names) && + EqualsRefOfLiteral(a.Number, b.Number) && + EqualsTableName(a.TableName, b.TableName) && + EqualsSliceOfRefOfPartitionDefinition(a.Definitions, b.Definitions) } -// EqualsRefOfDefault does deep equals between the two objects. -func EqualsRefOfDefault(a, b *Default) bool { +// EqualsRefOfRangeCond does deep equals between the two objects. +func EqualsRefOfRangeCond(a, b *RangeCond) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.ColName == b.ColName + return a.Operator == b.Operator && + EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.From, b.From) && + EqualsExpr(a.To, b.To) } -// CloneRefOfDefault creates a deep clone of the input. -func CloneRefOfDefault(n *Default) *Default { - if n == nil { - return nil +// EqualsRefOfRelease does deep equals between the two objects. +func EqualsRefOfRelease(a, b *Release) bool { + if a == b { + return true } - out := *n - return &out + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) } -// VisitRefOfDefault will visit all parts of the AST -func VisitRefOfDefault(in *Default, f Visit) error { - if in == nil { - return nil +// EqualsRefOfRenameIndex does deep equals between the two objects. +func EqualsRefOfRenameIndex(a, b *RenameIndex) bool { + if a == b { + return true } - if cont, err := f(in); err != nil || !cont { - return err + if a == nil || b == nil { + return false } - return nil + return a.OldName == b.OldName && + a.NewName == b.NewName } -// rewriteRefOfDefault is part of the Rewrite implementation -func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsRefOfRenameTable does deep equals between the two objects. +func EqualsRefOfRenameTable(a, b *RenameTable) bool { + if a == b { + return true } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if a == nil || b == nil { + return false } - if a.pre != nil && !a.pre(&cur) { - return nil + return EqualsSliceOfRefOfRenameTablePair(a.TablePairs, b.TablePairs) +} + +// EqualsRefOfRenameTableName does deep equals between the two objects. +func EqualsRefOfRenameTableName(a, b *RenameTableName) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsTableName(a.Table, b.Table) } -// EqualsRefOfDelete does deep equals between the two objects. -func EqualsRefOfDelete(a, b *Delete) bool { +// EqualsRefOfRenameTablePair does deep equals between the two objects. +func EqualsRefOfRenameTablePair(a, b *RenameTablePair) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Ignore == b.Ignore && - EqualsComments(a.Comments, b.Comments) && - EqualsTableNames(a.Targets, b.Targets) && - EqualsTableExprs(a.TableExprs, b.TableExprs) && - EqualsPartitions(a.Partitions, b.Partitions) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) + return EqualsTableName(a.FromTable, b.FromTable) && + EqualsTableName(a.ToTable, b.ToTable) } -// CloneRefOfDelete creates a deep clone of the input. -func CloneRefOfDelete(n *Delete) *Delete { - if n == nil { - return nil +// EqualsRefOfRevertMigration does deep equals between the two objects. +func EqualsRefOfRevertMigration(a, b *RevertMigration) bool { + if a == b { + return true } - out := *n - out.Comments = CloneComments(n.Comments) - out.Targets = CloneTableNames(n.Targets) - out.TableExprs = CloneTableExprs(n.TableExprs) - out.Partitions = ClonePartitions(n.Partitions) - out.Where = CloneRefOfWhere(n.Where) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out + if a == nil || b == nil { + return false + } + return a.UUID == b.UUID } -// VisitRefOfDelete will visit all parts of the AST -func VisitRefOfDelete(in *Delete, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfRollback does deep equals between the two objects. +func EqualsRefOfRollback(a, b *Rollback) bool { + if a == b { + return true } - if err := VisitComments(in.Comments, f); err != nil { - return err + if a == nil || b == nil { + return false } - if err := VisitTableNames(in.Targets, f); err != nil { - return err + return true +} + +// EqualsRefOfSRollback does deep equals between the two objects. +func EqualsRefOfSRollback(a, b *SRollback) bool { + if a == b { + return true } - if err := VisitTableExprs(in.TableExprs, f); err != nil { - return err + if a == nil || b == nil { + return false } - if err := VisitPartitions(in.Partitions, f); err != nil { - return err + return EqualsColIdent(a.Name, b.Name) +} + +// EqualsRefOfSavepoint does deep equals between the two objects. +func EqualsRefOfSavepoint(a, b *Savepoint) bool { + if a == b { + return true } - if err := VisitRefOfWhere(in.Where, f); err != nil { - return err + if a == nil || b == nil { + return false } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + return EqualsColIdent(a.Name, b.Name) +} + +// EqualsRefOfSelect does deep equals between the two objects. +func EqualsRefOfSelect(a, b *Select) bool { + if a == b { + return true } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return a.Distinct == b.Distinct && + a.StraightJoinHint == b.StraightJoinHint && + a.SQLCalcFoundRows == b.SQLCalcFoundRows && + EqualsRefOfBool(a.Cache, b.Cache) && + EqualsComments(a.Comments, b.Comments) && + EqualsSelectExprs(a.SelectExprs, b.SelectExprs) && + EqualsTableExprs(a.From, b.From) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsGroupBy(a.GroupBy, b.GroupBy) && + EqualsRefOfWhere(a.Having, b.Having) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) && + a.Lock == b.Lock && + EqualsRefOfSelectInto(a.Into, b.Into) } -// rewriteRefOfDelete is part of the Rewrite implementation -func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*Delete).Comments = newNode.(Comments) - }); errF != nil { - return errF +// EqualsRefOfSelectInto does deep equals between the two objects. +func EqualsRefOfSelectInto(a, b *SelectInto) bool { + if a == b { + return true } - if errF := a.rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { - parent.(*Delete).Targets = newNode.(TableNames) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { - parent.(*Delete).TableExprs = newNode.(TableExprs) - }); errF != nil { - return errF + return a.FileName == b.FileName && + a.Charset == b.Charset && + a.FormatOption == b.FormatOption && + a.ExportOption == b.ExportOption && + a.Manifest == b.Manifest && + a.Overwrite == b.Overwrite && + a.Type == b.Type +} + +// EqualsRefOfSet does deep equals between the two objects. +func EqualsRefOfSet(a, b *Set) bool { + if a == b { + return true } - if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { - parent.(*Delete).Partitions = newNode.(Partitions) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { - parent.(*Delete).Where = newNode.(*Where) - }); errF != nil { - return errF + return EqualsComments(a.Comments, b.Comments) && + EqualsSetExprs(a.Exprs, b.Exprs) +} + +// EqualsRefOfSetExpr does deep equals between the two objects. +func EqualsRefOfSetExpr(a, b *SetExpr) bool { + if a == b { + return true } - if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Delete).OrderBy = newNode.(OrderBy) - }); errF != nil { - return errF + if a == nil || b == nil { + return false } - if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*Delete).Limit = newNode.(*Limit) - }); errF != nil { - return errF + return a.Scope == b.Scope && + EqualsColIdent(a.Name, b.Name) && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfSetTransaction does deep equals between the two objects. +func EqualsRefOfSetTransaction(a, b *SetTransaction) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsSQLNode(a.SQLNode, b.SQLNode) && + EqualsComments(a.Comments, b.Comments) && + a.Scope == b.Scope && + EqualsSliceOfCharacteristic(a.Characteristics, b.Characteristics) } -// EqualsRefOfDerivedTable does deep equals between the two objects. -func EqualsRefOfDerivedTable(a, b *DerivedTable) bool { +// EqualsRefOfShow does deep equals between the two objects. +func EqualsRefOfShow(a, b *Show) bool { if a == b { return true } if a == nil || b == nil { return false } - return EqualsSelectStatement(a.Select, b.Select) + return EqualsShowInternal(a.Internal, b.Internal) } -// CloneRefOfDerivedTable creates a deep clone of the input. -func CloneRefOfDerivedTable(n *DerivedTable) *DerivedTable { - if n == nil { - return nil +// EqualsRefOfShowBasic does deep equals between the two objects. +func EqualsRefOfShowBasic(a, b *ShowBasic) bool { + if a == b { + return true } - out := *n - out.Select = CloneSelectStatement(n.Select) - return &out + if a == nil || b == nil { + return false + } + return a.Full == b.Full && + a.DbName == b.DbName && + a.Command == b.Command && + EqualsTableName(a.Tbl, b.Tbl) && + EqualsRefOfShowFilter(a.Filter, b.Filter) } -// VisitRefOfDerivedTable will visit all parts of the AST -func VisitRefOfDerivedTable(in *DerivedTable, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfShowCreate does deep equals between the two objects. +func EqualsRefOfShowCreate(a, b *ShowCreate) bool { + if a == b { + return true } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return a.Command == b.Command && + EqualsTableName(a.Op, b.Op) } -// rewriteRefOfDerivedTable is part of the Rewrite implementation -func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, +// EqualsRefOfShowFilter does deep equals between the two objects. +func EqualsRefOfShowFilter(a, b *ShowFilter) bool { + if a == b { + return true } - if a.pre != nil && !a.pre(&cur) { - return nil + if a == nil || b == nil { + return false } - if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { - parent.(*DerivedTable).Select = newNode.(SelectStatement) - }); errF != nil { - return errF + return a.Like == b.Like && + EqualsExpr(a.Filter, b.Filter) +} + +// EqualsRefOfShowLegacy does deep equals between the two objects. +func EqualsRefOfShowLegacy(a, b *ShowLegacy) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.Extended == b.Extended && + a.Type == b.Type && + EqualsTableName(a.OnTable, b.OnTable) && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfShowTablesOpt(a.ShowTablesOpt, b.ShowTablesOpt) && + a.Scope == b.Scope && + EqualsExpr(a.ShowCollationFilterOpt, b.ShowCollationFilterOpt) } -// EqualsRefOfDropColumn does deep equals between the two objects. -func EqualsRefOfDropColumn(a, b *DropColumn) bool { +// EqualsRefOfShowTablesOpt does deep equals between the two objects. +func EqualsRefOfShowTablesOpt(a, b *ShowTablesOpt) bool { if a == b { return true } if a == nil || b == nil { return false } - return EqualsRefOfColName(a.Name, b.Name) + return a.Full == b.Full && + a.DbName == b.DbName && + EqualsRefOfShowFilter(a.Filter, b.Filter) } -// CloneRefOfDropColumn creates a deep clone of the input. -func CloneRefOfDropColumn(n *DropColumn) *DropColumn { - if n == nil { - return nil +// EqualsRefOfStarExpr does deep equals between the two objects. +func EqualsRefOfStarExpr(a, b *StarExpr) bool { + if a == b { + return true } - out := *n - out.Name = CloneRefOfColName(n.Name) - return &out + if a == nil || b == nil { + return false + } + return EqualsTableName(a.TableName, b.TableName) } -// VisitRefOfDropColumn will visit all parts of the AST -func VisitRefOfDropColumn(in *DropColumn, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfStream does deep equals between the two objects. +func EqualsRefOfStream(a, b *Stream) bool { + if a == b { + return true } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsComments(a.Comments, b.Comments) && + EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && + EqualsTableName(a.Table, b.Table) } -// rewriteRefOfDropColumn is part of the Rewrite implementation -func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, +// EqualsRefOfSubquery does deep equals between the two objects. +func EqualsRefOfSubquery(a, b *Subquery) bool { + if a == b { + return true } - if a.pre != nil && !a.pre(&cur) { - return nil + if a == nil || b == nil { + return false } - if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { - parent.(*DropColumn).Name = newNode.(*ColName) - }); errF != nil { - return errF + return EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfSubstrExpr does deep equals between the two objects. +func EqualsRefOfSubstrExpr(a, b *SubstrExpr) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsRefOfColName(a.Name, b.Name) && + EqualsRefOfLiteral(a.StrVal, b.StrVal) && + EqualsExpr(a.From, b.From) && + EqualsExpr(a.To, b.To) } -// EqualsRefOfDropDatabase does deep equals between the two objects. -func EqualsRefOfDropDatabase(a, b *DropDatabase) bool { +// EqualsRefOfTableAndLockType does deep equals between the two objects. +func EqualsRefOfTableAndLockType(a, b *TableAndLockType) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.DBName == b.DBName && - a.IfExists == b.IfExists && - EqualsComments(a.Comments, b.Comments) -} - -// CloneRefOfDropDatabase creates a deep clone of the input. -func CloneRefOfDropDatabase(n *DropDatabase) *DropDatabase { - if n == nil { - return nil - } - out := *n - out.Comments = CloneComments(n.Comments) - return &out + return EqualsTableExpr(a.Table, b.Table) && + a.Lock == b.Lock } -// VisitRefOfDropDatabase will visit all parts of the AST -func VisitRefOfDropDatabase(in *DropDatabase, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfTableIdent does deep equals between the two objects. +func EqualsRefOfTableIdent(a, b *TableIdent) bool { + if a == b { + return true } - if err := VisitComments(in.Comments, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return a.v == b.v } -// rewriteRefOfDropDatabase is part of the Rewrite implementation -func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*DropDatabase).Comments = newNode.(Comments) - }); errF != nil { - return errF +// EqualsRefOfTableName does deep equals between the two objects. +func EqualsRefOfTableName(a, b *TableName) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsTableIdent(a.Name, b.Name) && + EqualsTableIdent(a.Qualifier, b.Qualifier) } -// EqualsRefOfDropKey does deep equals between the two objects. -func EqualsRefOfDropKey(a, b *DropKey) bool { +// EqualsRefOfTableOption does deep equals between the two objects. +func EqualsRefOfTableOption(a, b *TableOption) bool { if a == b { return true } @@ -4689,1292 +4677,2198 @@ func EqualsRefOfDropKey(a, b *DropKey) bool { return false } return a.Name == b.Name && - a.Type == b.Type + a.String == b.String && + EqualsRefOfLiteral(a.Value, b.Value) && + EqualsTableNames(a.Tables, b.Tables) } -// CloneRefOfDropKey creates a deep clone of the input. -func CloneRefOfDropKey(n *DropKey) *DropKey { - if n == nil { - return nil +// EqualsRefOfTableSpec does deep equals between the two objects. +func EqualsRefOfTableSpec(a, b *TableSpec) bool { + if a == b { + return true } - out := *n - return &out + if a == nil || b == nil { + return false + } + return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && + EqualsSliceOfRefOfIndexDefinition(a.Indexes, b.Indexes) && + EqualsSliceOfRefOfConstraintDefinition(a.Constraints, b.Constraints) && + EqualsTableOptions(a.Options, b.Options) } -// VisitRefOfDropKey will visit all parts of the AST -func VisitRefOfDropKey(in *DropKey, f Visit) error { - if in == nil { - return nil +// EqualsRefOfTablespaceOperation does deep equals between the two objects. +func EqualsRefOfTablespaceOperation(a, b *TablespaceOperation) bool { + if a == b { + return true } - if cont, err := f(in); err != nil || !cont { - return err + if a == nil || b == nil { + return false } - return nil + return a.Import == b.Import } -// rewriteRefOfDropKey is part of the Rewrite implementation -func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsRefOfTimestampFuncExpr does deep equals between the two objects. +func EqualsRefOfTimestampFuncExpr(a, b *TimestampFuncExpr) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.Name == b.Name && + a.Unit == b.Unit && + EqualsExpr(a.Expr1, b.Expr1) && + EqualsExpr(a.Expr2, b.Expr2) } -// EqualsRefOfDropTable does deep equals between the two objects. -func EqualsRefOfDropTable(a, b *DropTable) bool { +// EqualsRefOfTruncateTable does deep equals between the two objects. +func EqualsRefOfTruncateTable(a, b *TruncateTable) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Temp == b.Temp && - a.IfExists == b.IfExists && - EqualsTableNames(a.FromTables, b.FromTables) + return EqualsTableName(a.Table, b.Table) } -// CloneRefOfDropTable creates a deep clone of the input. -func CloneRefOfDropTable(n *DropTable) *DropTable { - if n == nil { - return nil +// EqualsRefOfUnaryExpr does deep equals between the two objects. +func EqualsRefOfUnaryExpr(a, b *UnaryExpr) bool { + if a == b { + return true } - out := *n - out.FromTables = CloneTableNames(n.FromTables) - return &out + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Expr, b.Expr) } -// VisitRefOfDropTable will visit all parts of the AST -func VisitRefOfDropTable(in *DropTable, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfUnion does deep equals between the two objects. +func EqualsRefOfUnion(a, b *Union) bool { + if a == b { + return true } - if err := VisitTableNames(in.FromTables, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsSelectStatement(a.FirstStatement, b.FirstStatement) && + EqualsSliceOfRefOfUnionSelect(a.UnionSelects, b.UnionSelects) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) && + a.Lock == b.Lock } -// rewriteRefOfDropTable is part of the Rewrite implementation -func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { - parent.(*DropTable).FromTables = newNode.(TableNames) - }); errF != nil { - return errF +// EqualsRefOfUnionSelect does deep equals between the two objects. +func EqualsRefOfUnionSelect(a, b *UnionSelect) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.Distinct == b.Distinct && + EqualsSelectStatement(a.Statement, b.Statement) } -// EqualsRefOfDropView does deep equals between the two objects. -func EqualsRefOfDropView(a, b *DropView) bool { +// EqualsRefOfUnlockTables does deep equals between the two objects. +func EqualsRefOfUnlockTables(a, b *UnlockTables) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.IfExists == b.IfExists && - EqualsTableNames(a.FromTables, b.FromTables) + return true } -// CloneRefOfDropView creates a deep clone of the input. -func CloneRefOfDropView(n *DropView) *DropView { - if n == nil { - return nil +// EqualsRefOfUpdate does deep equals between the two objects. +func EqualsRefOfUpdate(a, b *Update) bool { + if a == b { + return true } - out := *n - out.FromTables = CloneTableNames(n.FromTables) - return &out + if a == nil || b == nil { + return false + } + return EqualsComments(a.Comments, b.Comments) && + a.Ignore == b.Ignore && + EqualsTableExprs(a.TableExprs, b.TableExprs) && + EqualsUpdateExprs(a.Exprs, b.Exprs) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) } -// VisitRefOfDropView will visit all parts of the AST -func VisitRefOfDropView(in *DropView, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfUpdateExpr does deep equals between the two objects. +func EqualsRefOfUpdateExpr(a, b *UpdateExpr) bool { + if a == b { + return true } - if err := VisitTableNames(in.FromTables, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsRefOfColName(a.Name, b.Name) && + EqualsExpr(a.Expr, b.Expr) } -// rewriteRefOfDropView is part of the Rewrite implementation -func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { - parent.(*DropView).FromTables = newNode.(TableNames) - }); errF != nil { - return errF +// EqualsRefOfUse does deep equals between the two objects. +func EqualsRefOfUse(a, b *Use) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return EqualsTableIdent(a.DBName, b.DBName) } -// EqualsRefOfExistsExpr does deep equals between the two objects. -func EqualsRefOfExistsExpr(a, b *ExistsExpr) bool { +// EqualsRefOfVStream does deep equals between the two objects. +func EqualsRefOfVStream(a, b *VStream) bool { if a == b { return true } if a == nil || b == nil { return false } - return EqualsRefOfSubquery(a.Subquery, b.Subquery) + return EqualsComments(a.Comments, b.Comments) && + EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsRefOfLimit(a.Limit, b.Limit) } -// CloneRefOfExistsExpr creates a deep clone of the input. -func CloneRefOfExistsExpr(n *ExistsExpr) *ExistsExpr { - if n == nil { - return nil +// EqualsRefOfValidation does deep equals between the two objects. +func EqualsRefOfValidation(a, b *Validation) bool { + if a == b { + return true } - out := *n - out.Subquery = CloneRefOfSubquery(n.Subquery) - return &out + if a == nil || b == nil { + return false + } + return a.With == b.With } -// VisitRefOfExistsExpr will visit all parts of the AST -func VisitRefOfExistsExpr(in *ExistsExpr, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfValuesFuncExpr does deep equals between the two objects. +func EqualsRefOfValuesFuncExpr(a, b *ValuesFuncExpr) bool { + if a == b { + return true } - if err := VisitRefOfSubquery(in.Subquery, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return EqualsRefOfColName(a.Name, b.Name) } -// rewriteRefOfExistsExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { - parent.(*ExistsExpr).Subquery = newNode.(*Subquery) - }); errF != nil { - return errF +// EqualsRefOfVindexParam does deep equals between the two objects. +func EqualsRefOfVindexParam(a, b *VindexParam) bool { + if a == b { + return true } - if a.post != nil && !a.post(&cur) { - return errAbort + if a == nil || b == nil { + return false } - return nil + return a.Val == b.Val && + EqualsColIdent(a.Key, b.Key) } -// EqualsRefOfExplainStmt does deep equals between the two objects. -func EqualsRefOfExplainStmt(a, b *ExplainStmt) bool { +// EqualsRefOfVindexSpec does deep equals between the two objects. +func EqualsRefOfVindexSpec(a, b *VindexSpec) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.Type == b.Type && - EqualsStatement(a.Statement, b.Statement) + return EqualsColIdent(a.Name, b.Name) && + EqualsColIdent(a.Type, b.Type) && + EqualsSliceOfVindexParam(a.Params, b.Params) } -// CloneRefOfExplainStmt creates a deep clone of the input. -func CloneRefOfExplainStmt(n *ExplainStmt) *ExplainStmt { - if n == nil { - return nil +// EqualsRefOfWhen does deep equals between the two objects. +func EqualsRefOfWhen(a, b *When) bool { + if a == b { + return true } - out := *n - out.Statement = CloneStatement(n.Statement) - return &out + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Cond, b.Cond) && + EqualsExpr(a.Val, b.Val) } -// VisitRefOfExplainStmt will visit all parts of the AST -func VisitRefOfExplainStmt(in *ExplainStmt, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfWhere does deep equals between the two objects. +func EqualsRefOfWhere(a, b *Where) bool { + if a == b { + return true } - if err := VisitStatement(in.Statement, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil + return a.Type == b.Type && + EqualsExpr(a.Expr, b.Expr) } -// rewriteRefOfExplainStmt is part of the Rewrite implementation -func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsRefOfXorExpr does deep equals between the two objects. +func EqualsRefOfXorExpr(a, b *XorExpr) bool { + if a == b { + return true } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if a == nil || b == nil { + return false } - if a.pre != nil && !a.pre(&cur) { - return nil + return EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsSQLNode does deep equals between the two objects. +func EqualsSQLNode(inA, inB SQLNode) bool { + if inA == nil && inB == nil { + return true } - if errF := a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { - parent.(*ExplainStmt).Statement = newNode.(Statement) - }); errF != nil { - return errF + if inA == nil || inB == nil { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + switch a := inA.(type) { + case AccessMode: + b, ok := inB.(AccessMode) + if !ok { + return false + } + return a == b + case *AddColumns: + b, ok := inB.(*AddColumns) + if !ok { + return false + } + return EqualsRefOfAddColumns(a, b) + case *AddConstraintDefinition: + b, ok := inB.(*AddConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfAddConstraintDefinition(a, b) + case *AddIndexDefinition: + b, ok := inB.(*AddIndexDefinition) + if !ok { + return false + } + return EqualsRefOfAddIndexDefinition(a, b) + case AlgorithmValue: + b, ok := inB.(AlgorithmValue) + if !ok { + return false + } + return a == b + case *AliasedExpr: + b, ok := inB.(*AliasedExpr) + if !ok { + return false + } + return EqualsRefOfAliasedExpr(a, b) + case *AliasedTableExpr: + b, ok := inB.(*AliasedTableExpr) + if !ok { + return false + } + return EqualsRefOfAliasedTableExpr(a, b) + case *AlterCharset: + b, ok := inB.(*AlterCharset) + if !ok { + return false + } + return EqualsRefOfAlterCharset(a, b) + case *AlterColumn: + b, ok := inB.(*AlterColumn) + if !ok { + return false + } + return EqualsRefOfAlterColumn(a, b) + case *AlterDatabase: + b, ok := inB.(*AlterDatabase) + if !ok { + return false + } + return EqualsRefOfAlterDatabase(a, b) + case *AlterMigration: + b, ok := inB.(*AlterMigration) + if !ok { + return false + } + return EqualsRefOfAlterMigration(a, b) + case *AlterTable: + b, ok := inB.(*AlterTable) + if !ok { + return false + } + return EqualsRefOfAlterTable(a, b) + case *AlterView: + b, ok := inB.(*AlterView) + if !ok { + return false + } + return EqualsRefOfAlterView(a, b) + case *AlterVschema: + b, ok := inB.(*AlterVschema) + if !ok { + return false + } + return EqualsRefOfAlterVschema(a, b) + case *AndExpr: + b, ok := inB.(*AndExpr) + if !ok { + return false + } + return EqualsRefOfAndExpr(a, b) + case Argument: + b, ok := inB.(Argument) + if !ok { + return false + } + return a == b + case *AutoIncSpec: + b, ok := inB.(*AutoIncSpec) + if !ok { + return false + } + return EqualsRefOfAutoIncSpec(a, b) + case *Begin: + b, ok := inB.(*Begin) + if !ok { + return false + } + return EqualsRefOfBegin(a, b) + case *BinaryExpr: + b, ok := inB.(*BinaryExpr) + if !ok { + return false + } + return EqualsRefOfBinaryExpr(a, b) + case BoolVal: + b, ok := inB.(BoolVal) + if !ok { + return false + } + return a == b + case *CallProc: + b, ok := inB.(*CallProc) + if !ok { + return false + } + return EqualsRefOfCallProc(a, b) + case *CaseExpr: + b, ok := inB.(*CaseExpr) + if !ok { + return false + } + return EqualsRefOfCaseExpr(a, b) + case *ChangeColumn: + b, ok := inB.(*ChangeColumn) + if !ok { + return false + } + return EqualsRefOfChangeColumn(a, b) + case *CheckConstraintDefinition: + b, ok := inB.(*CheckConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfCheckConstraintDefinition(a, b) + case ColIdent: + b, ok := inB.(ColIdent) + if !ok { + return false + } + return EqualsColIdent(a, b) + case *ColName: + b, ok := inB.(*ColName) + if !ok { + return false + } + return EqualsRefOfColName(a, b) + case *CollateExpr: + b, ok := inB.(*CollateExpr) + if !ok { + return false + } + return EqualsRefOfCollateExpr(a, b) + case *ColumnDefinition: + b, ok := inB.(*ColumnDefinition) + if !ok { + return false + } + return EqualsRefOfColumnDefinition(a, b) + case *ColumnType: + b, ok := inB.(*ColumnType) + if !ok { + return false + } + return EqualsRefOfColumnType(a, b) + case Columns: + b, ok := inB.(Columns) + if !ok { + return false + } + return EqualsColumns(a, b) + case Comments: + b, ok := inB.(Comments) + if !ok { + return false + } + return EqualsComments(a, b) + case *Commit: + b, ok := inB.(*Commit) + if !ok { + return false + } + return EqualsRefOfCommit(a, b) + case *ComparisonExpr: + b, ok := inB.(*ComparisonExpr) + if !ok { + return false + } + return EqualsRefOfComparisonExpr(a, b) + case *ConstraintDefinition: + b, ok := inB.(*ConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfConstraintDefinition(a, b) + case *ConvertExpr: + b, ok := inB.(*ConvertExpr) + if !ok { + return false + } + return EqualsRefOfConvertExpr(a, b) + case *ConvertType: + b, ok := inB.(*ConvertType) + if !ok { + return false + } + return EqualsRefOfConvertType(a, b) + case *ConvertUsingExpr: + b, ok := inB.(*ConvertUsingExpr) + if !ok { + return false + } + return EqualsRefOfConvertUsingExpr(a, b) + case *CreateDatabase: + b, ok := inB.(*CreateDatabase) + if !ok { + return false + } + return EqualsRefOfCreateDatabase(a, b) + case *CreateTable: + b, ok := inB.(*CreateTable) + if !ok { + return false + } + return EqualsRefOfCreateTable(a, b) + case *CreateView: + b, ok := inB.(*CreateView) + if !ok { + return false + } + return EqualsRefOfCreateView(a, b) + case *CurTimeFuncExpr: + b, ok := inB.(*CurTimeFuncExpr) + if !ok { + return false + } + return EqualsRefOfCurTimeFuncExpr(a, b) + case *Default: + b, ok := inB.(*Default) + if !ok { + return false + } + return EqualsRefOfDefault(a, b) + case *Delete: + b, ok := inB.(*Delete) + if !ok { + return false + } + return EqualsRefOfDelete(a, b) + case *DerivedTable: + b, ok := inB.(*DerivedTable) + if !ok { + return false + } + return EqualsRefOfDerivedTable(a, b) + case *DropColumn: + b, ok := inB.(*DropColumn) + if !ok { + return false + } + return EqualsRefOfDropColumn(a, b) + case *DropDatabase: + b, ok := inB.(*DropDatabase) + if !ok { + return false + } + return EqualsRefOfDropDatabase(a, b) + case *DropKey: + b, ok := inB.(*DropKey) + if !ok { + return false + } + return EqualsRefOfDropKey(a, b) + case *DropTable: + b, ok := inB.(*DropTable) + if !ok { + return false + } + return EqualsRefOfDropTable(a, b) + case *DropView: + b, ok := inB.(*DropView) + if !ok { + return false + } + return EqualsRefOfDropView(a, b) + case *ExistsExpr: + b, ok := inB.(*ExistsExpr) + if !ok { + return false + } + return EqualsRefOfExistsExpr(a, b) + case *ExplainStmt: + b, ok := inB.(*ExplainStmt) + if !ok { + return false + } + return EqualsRefOfExplainStmt(a, b) + case *ExplainTab: + b, ok := inB.(*ExplainTab) + if !ok { + return false + } + return EqualsRefOfExplainTab(a, b) + case Exprs: + b, ok := inB.(Exprs) + if !ok { + return false + } + return EqualsExprs(a, b) + case *Flush: + b, ok := inB.(*Flush) + if !ok { + return false + } + return EqualsRefOfFlush(a, b) + case *Force: + b, ok := inB.(*Force) + if !ok { + return false + } + return EqualsRefOfForce(a, b) + case *ForeignKeyDefinition: + b, ok := inB.(*ForeignKeyDefinition) + if !ok { + return false + } + return EqualsRefOfForeignKeyDefinition(a, b) + case *FuncExpr: + b, ok := inB.(*FuncExpr) + if !ok { + return false + } + return EqualsRefOfFuncExpr(a, b) + case GroupBy: + b, ok := inB.(GroupBy) + if !ok { + return false + } + return EqualsGroupBy(a, b) + case *GroupConcatExpr: + b, ok := inB.(*GroupConcatExpr) + if !ok { + return false + } + return EqualsRefOfGroupConcatExpr(a, b) + case *IndexDefinition: + b, ok := inB.(*IndexDefinition) + if !ok { + return false + } + return EqualsRefOfIndexDefinition(a, b) + case *IndexHints: + b, ok := inB.(*IndexHints) + if !ok { + return false + } + return EqualsRefOfIndexHints(a, b) + case *IndexInfo: + b, ok := inB.(*IndexInfo) + if !ok { + return false + } + return EqualsRefOfIndexInfo(a, b) + case *Insert: + b, ok := inB.(*Insert) + if !ok { + return false + } + return EqualsRefOfInsert(a, b) + case *IntervalExpr: + b, ok := inB.(*IntervalExpr) + if !ok { + return false + } + return EqualsRefOfIntervalExpr(a, b) + case *IsExpr: + b, ok := inB.(*IsExpr) + if !ok { + return false + } + return EqualsRefOfIsExpr(a, b) + case IsolationLevel: + b, ok := inB.(IsolationLevel) + if !ok { + return false + } + return a == b + case JoinCondition: + b, ok := inB.(JoinCondition) + if !ok { + return false + } + return EqualsJoinCondition(a, b) + case *JoinTableExpr: + b, ok := inB.(*JoinTableExpr) + if !ok { + return false + } + return EqualsRefOfJoinTableExpr(a, b) + case *KeyState: + b, ok := inB.(*KeyState) + if !ok { + return false + } + return EqualsRefOfKeyState(a, b) + case *Limit: + b, ok := inB.(*Limit) + if !ok { + return false + } + return EqualsRefOfLimit(a, b) + case ListArg: + b, ok := inB.(ListArg) + if !ok { + return false + } + return EqualsListArg(a, b) + case *Literal: + b, ok := inB.(*Literal) + if !ok { + return false + } + return EqualsRefOfLiteral(a, b) + case *Load: + b, ok := inB.(*Load) + if !ok { + return false + } + return EqualsRefOfLoad(a, b) + case *LockOption: + b, ok := inB.(*LockOption) + if !ok { + return false + } + return EqualsRefOfLockOption(a, b) + case *LockTables: + b, ok := inB.(*LockTables) + if !ok { + return false + } + return EqualsRefOfLockTables(a, b) + case *MatchExpr: + b, ok := inB.(*MatchExpr) + if !ok { + return false + } + return EqualsRefOfMatchExpr(a, b) + case *ModifyColumn: + b, ok := inB.(*ModifyColumn) + if !ok { + return false + } + return EqualsRefOfModifyColumn(a, b) + case *Nextval: + b, ok := inB.(*Nextval) + if !ok { + return false + } + return EqualsRefOfNextval(a, b) + case *NotExpr: + b, ok := inB.(*NotExpr) + if !ok { + return false + } + return EqualsRefOfNotExpr(a, b) + case *NullVal: + b, ok := inB.(*NullVal) + if !ok { + return false + } + return EqualsRefOfNullVal(a, b) + case OnDup: + b, ok := inB.(OnDup) + if !ok { + return false + } + return EqualsOnDup(a, b) + case *OptLike: + b, ok := inB.(*OptLike) + if !ok { + return false + } + return EqualsRefOfOptLike(a, b) + case *OrExpr: + b, ok := inB.(*OrExpr) + if !ok { + return false + } + return EqualsRefOfOrExpr(a, b) + case *Order: + b, ok := inB.(*Order) + if !ok { + return false + } + return EqualsRefOfOrder(a, b) + case OrderBy: + b, ok := inB.(OrderBy) + if !ok { + return false + } + return EqualsOrderBy(a, b) + case *OrderByOption: + b, ok := inB.(*OrderByOption) + if !ok { + return false + } + return EqualsRefOfOrderByOption(a, b) + case *OtherAdmin: + b, ok := inB.(*OtherAdmin) + if !ok { + return false + } + return EqualsRefOfOtherAdmin(a, b) + case *OtherRead: + b, ok := inB.(*OtherRead) + if !ok { + return false + } + return EqualsRefOfOtherRead(a, b) + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *ParenTableExpr: + b, ok := inB.(*ParenTableExpr) + if !ok { + return false + } + return EqualsRefOfParenTableExpr(a, b) + case *PartitionDefinition: + b, ok := inB.(*PartitionDefinition) + if !ok { + return false + } + return EqualsRefOfPartitionDefinition(a, b) + case *PartitionSpec: + b, ok := inB.(*PartitionSpec) + if !ok { + return false + } + return EqualsRefOfPartitionSpec(a, b) + case Partitions: + b, ok := inB.(Partitions) + if !ok { + return false + } + return EqualsPartitions(a, b) + case *RangeCond: + b, ok := inB.(*RangeCond) + if !ok { + return false + } + return EqualsRefOfRangeCond(a, b) + case ReferenceAction: + b, ok := inB.(ReferenceAction) + if !ok { + return false + } + return a == b + case *Release: + b, ok := inB.(*Release) + if !ok { + return false + } + return EqualsRefOfRelease(a, b) + case *RenameIndex: + b, ok := inB.(*RenameIndex) + if !ok { + return false + } + return EqualsRefOfRenameIndex(a, b) + case *RenameTable: + b, ok := inB.(*RenameTable) + if !ok { + return false + } + return EqualsRefOfRenameTable(a, b) + case *RenameTableName: + b, ok := inB.(*RenameTableName) + if !ok { + return false + } + return EqualsRefOfRenameTableName(a, b) + case *RevertMigration: + b, ok := inB.(*RevertMigration) + if !ok { + return false + } + return EqualsRefOfRevertMigration(a, b) + case *Rollback: + b, ok := inB.(*Rollback) + if !ok { + return false + } + return EqualsRefOfRollback(a, b) + case *SRollback: + b, ok := inB.(*SRollback) + if !ok { + return false + } + return EqualsRefOfSRollback(a, b) + case *Savepoint: + b, ok := inB.(*Savepoint) + if !ok { + return false + } + return EqualsRefOfSavepoint(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case SelectExprs: + b, ok := inB.(SelectExprs) + if !ok { + return false + } + return EqualsSelectExprs(a, b) + case *SelectInto: + b, ok := inB.(*SelectInto) + if !ok { + return false + } + return EqualsRefOfSelectInto(a, b) + case *Set: + b, ok := inB.(*Set) + if !ok { + return false + } + return EqualsRefOfSet(a, b) + case *SetExpr: + b, ok := inB.(*SetExpr) + if !ok { + return false + } + return EqualsRefOfSetExpr(a, b) + case SetExprs: + b, ok := inB.(SetExprs) + if !ok { + return false + } + return EqualsSetExprs(a, b) + case *SetTransaction: + b, ok := inB.(*SetTransaction) + if !ok { + return false + } + return EqualsRefOfSetTransaction(a, b) + case *Show: + b, ok := inB.(*Show) + if !ok { + return false + } + return EqualsRefOfShow(a, b) + case *ShowBasic: + b, ok := inB.(*ShowBasic) + if !ok { + return false + } + return EqualsRefOfShowBasic(a, b) + case *ShowCreate: + b, ok := inB.(*ShowCreate) + if !ok { + return false + } + return EqualsRefOfShowCreate(a, b) + case *ShowFilter: + b, ok := inB.(*ShowFilter) + if !ok { + return false + } + return EqualsRefOfShowFilter(a, b) + case *ShowLegacy: + b, ok := inB.(*ShowLegacy) + if !ok { + return false + } + return EqualsRefOfShowLegacy(a, b) + case *StarExpr: + b, ok := inB.(*StarExpr) + if !ok { + return false + } + return EqualsRefOfStarExpr(a, b) + case *Stream: + b, ok := inB.(*Stream) + if !ok { + return false + } + return EqualsRefOfStream(a, b) + case *Subquery: + b, ok := inB.(*Subquery) + if !ok { + return false + } + return EqualsRefOfSubquery(a, b) + case *SubstrExpr: + b, ok := inB.(*SubstrExpr) + if !ok { + return false + } + return EqualsRefOfSubstrExpr(a, b) + case TableExprs: + b, ok := inB.(TableExprs) + if !ok { + return false + } + return EqualsTableExprs(a, b) + case TableIdent: + b, ok := inB.(TableIdent) + if !ok { + return false + } + return EqualsTableIdent(a, b) + case TableName: + b, ok := inB.(TableName) + if !ok { + return false + } + return EqualsTableName(a, b) + case TableNames: + b, ok := inB.(TableNames) + if !ok { + return false + } + return EqualsTableNames(a, b) + case TableOptions: + b, ok := inB.(TableOptions) + if !ok { + return false + } + return EqualsTableOptions(a, b) + case *TableSpec: + b, ok := inB.(*TableSpec) + if !ok { + return false + } + return EqualsRefOfTableSpec(a, b) + case *TablespaceOperation: + b, ok := inB.(*TablespaceOperation) + if !ok { + return false + } + return EqualsRefOfTablespaceOperation(a, b) + case *TimestampFuncExpr: + b, ok := inB.(*TimestampFuncExpr) + if !ok { + return false + } + return EqualsRefOfTimestampFuncExpr(a, b) + case *TruncateTable: + b, ok := inB.(*TruncateTable) + if !ok { + return false + } + return EqualsRefOfTruncateTable(a, b) + case *UnaryExpr: + b, ok := inB.(*UnaryExpr) + if !ok { + return false + } + return EqualsRefOfUnaryExpr(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + case *UnionSelect: + b, ok := inB.(*UnionSelect) + if !ok { + return false + } + return EqualsRefOfUnionSelect(a, b) + case *UnlockTables: + b, ok := inB.(*UnlockTables) + if !ok { + return false + } + return EqualsRefOfUnlockTables(a, b) + case *Update: + b, ok := inB.(*Update) + if !ok { + return false + } + return EqualsRefOfUpdate(a, b) + case *UpdateExpr: + b, ok := inB.(*UpdateExpr) + if !ok { + return false + } + return EqualsRefOfUpdateExpr(a, b) + case UpdateExprs: + b, ok := inB.(UpdateExprs) + if !ok { + return false + } + return EqualsUpdateExprs(a, b) + case *Use: + b, ok := inB.(*Use) + if !ok { + return false + } + return EqualsRefOfUse(a, b) + case *VStream: + b, ok := inB.(*VStream) + if !ok { + return false + } + return EqualsRefOfVStream(a, b) + case ValTuple: + b, ok := inB.(ValTuple) + if !ok { + return false + } + return EqualsValTuple(a, b) + case *Validation: + b, ok := inB.(*Validation) + if !ok { + return false + } + return EqualsRefOfValidation(a, b) + case Values: + b, ok := inB.(Values) + if !ok { + return false + } + return EqualsValues(a, b) + case *ValuesFuncExpr: + b, ok := inB.(*ValuesFuncExpr) + if !ok { + return false + } + return EqualsRefOfValuesFuncExpr(a, b) + case VindexParam: + b, ok := inB.(VindexParam) + if !ok { + return false + } + return EqualsVindexParam(a, b) + case *VindexSpec: + b, ok := inB.(*VindexSpec) + if !ok { + return false + } + return EqualsRefOfVindexSpec(a, b) + case *When: + b, ok := inB.(*When) + if !ok { + return false + } + return EqualsRefOfWhen(a, b) + case *Where: + b, ok := inB.(*Where) + if !ok { + return false + } + return EqualsRefOfWhere(a, b) + case *XorExpr: + b, ok := inB.(*XorExpr) + if !ok { + return false + } + return EqualsRefOfXorExpr(a, b) + default: + // this should never happen + return false } - return nil } -// EqualsRefOfExplainTab does deep equals between the two objects. -func EqualsRefOfExplainTab(a, b *ExplainTab) bool { - if a == b { +// EqualsSelectExpr does deep equals between the two objects. +func EqualsSelectExpr(inA, inB SelectExpr) bool { + if inA == nil && inB == nil { return true } - if a == nil || b == nil { + if inA == nil || inB == nil { return false } - return a.Wild == b.Wild && - EqualsTableName(a.Table, b.Table) -} - -// CloneRefOfExplainTab creates a deep clone of the input. -func CloneRefOfExplainTab(n *ExplainTab) *ExplainTab { - if n == nil { - return nil - } - out := *n - out.Table = CloneTableName(n.Table) - return &out -} - -// VisitRefOfExplainTab will visit all parts of the AST -func VisitRefOfExplainTab(in *ExplainTab, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - return nil -} - -// rewriteRefOfExplainTab is part of the Rewrite implementation -func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*ExplainTab).Table = newNode.(TableName) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort + switch a := inA.(type) { + case *AliasedExpr: + b, ok := inB.(*AliasedExpr) + if !ok { + return false + } + return EqualsRefOfAliasedExpr(a, b) + case *Nextval: + b, ok := inB.(*Nextval) + if !ok { + return false + } + return EqualsRefOfNextval(a, b) + case *StarExpr: + b, ok := inB.(*StarExpr) + if !ok { + return false + } + return EqualsRefOfStarExpr(a, b) + default: + // this should never happen + return false } - return nil } -// EqualsExprs does deep equals between the two objects. -func EqualsExprs(a, b Exprs) bool { +// EqualsSelectExprs does deep equals between the two objects. +func EqualsSelectExprs(a, b SelectExprs) bool { if len(a) != len(b) { return false } for i := 0; i < len(a); i++ { - if !EqualsExpr(a[i], b[i]) { + if !EqualsSelectExpr(a[i], b[i]) { return false } } return true } -// CloneExprs creates a deep clone of the input. -func CloneExprs(n Exprs) Exprs { - res := make(Exprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneExpr(x)) - } - return res -} - -// VisitExprs will visit all parts of the AST -func VisitExprs(in Exprs, f Visit) error { - if in == nil { - return nil +// EqualsSelectStatement does deep equals between the two objects. +func EqualsSelectStatement(inA, inB SelectStatement) bool { + if inA == nil && inB == nil { + return true } - if cont, err := f(in); err != nil || !cont { - return err + if inA == nil || inB == nil { + return false } - for _, el := range in { - if err := VisitExpr(el, f); err != nil { - return err + switch a := inA.(type) { + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false } + return EqualsRefOfUnion(a, b) + default: + // this should never happen + return false } - return nil } -// rewriteExprs is part of the Rewrite implementation -func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsSetExprs does deep equals between the two objects. +func EqualsSetExprs(a, b SetExprs) bool { + if len(a) != len(b) { + return false } - for i, el := range node { - if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { - parent.(Exprs)[i] = newNode.(Expr) - }); errF != nil { - return errF + for i := 0; i < len(a); i++ { + if !EqualsRefOfSetExpr(a[i], b[i]) { + return false } } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + return true } -// EqualsRefOfFlush does deep equals between the two objects. -func EqualsRefOfFlush(a, b *Flush) bool { - if a == b { +// EqualsShowInternal does deep equals between the two objects. +func EqualsShowInternal(inA, inB ShowInternal) bool { + if inA == nil && inB == nil { return true } - if a == nil || b == nil { + if inA == nil || inB == nil { return false } - return a.IsLocal == b.IsLocal && - a.WithLock == b.WithLock && - a.ForExport == b.ForExport && - EqualsSliceOfString(a.FlushOptions, b.FlushOptions) && - EqualsTableNames(a.TableNames, b.TableNames) -} - -// CloneRefOfFlush creates a deep clone of the input. -func CloneRefOfFlush(n *Flush) *Flush { - if n == nil { - return nil - } - out := *n - out.FlushOptions = CloneSliceOfString(n.FlushOptions) - out.TableNames = CloneTableNames(n.TableNames) - return &out -} - -// VisitRefOfFlush will visit all parts of the AST -func VisitRefOfFlush(in *Flush, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableNames(in.TableNames, f); err != nil { - return err + switch a := inA.(type) { + case *ShowBasic: + b, ok := inB.(*ShowBasic) + if !ok { + return false + } + return EqualsRefOfShowBasic(a, b) + case *ShowCreate: + b, ok := inB.(*ShowCreate) + if !ok { + return false + } + return EqualsRefOfShowCreate(a, b) + case *ShowLegacy: + b, ok := inB.(*ShowLegacy) + if !ok { + return false + } + return EqualsRefOfShowLegacy(a, b) + default: + // this should never happen + return false } - return nil } -// rewriteRefOfFlush is part of the Rewrite implementation -func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsSimpleTableExpr does deep equals between the two objects. +func EqualsSimpleTableExpr(inA, inB SimpleTableExpr) bool { + if inA == nil && inB == nil { + return true } - if errF := a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { - parent.(*Flush).TableNames = newNode.(TableNames) - }); errF != nil { - return errF + if inA == nil || inB == nil { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + switch a := inA.(type) { + case *DerivedTable: + b, ok := inB.(*DerivedTable) + if !ok { + return false + } + return EqualsRefOfDerivedTable(a, b) + case TableName: + b, ok := inB.(TableName) + if !ok { + return false + } + return EqualsTableName(a, b) + default: + // this should never happen + return false } - return nil } -// EqualsRefOfForce does deep equals between the two objects. -func EqualsRefOfForce(a, b *Force) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsSliceOfAlterOption does deep equals between the two objects. +func EqualsSliceOfAlterOption(a, b []AlterOption) bool { + if len(a) != len(b) { return false } + for i := 0; i < len(a); i++ { + if !EqualsAlterOption(a[i], b[i]) { + return false + } + } return true } -// CloneRefOfForce creates a deep clone of the input. -func CloneRefOfForce(n *Force) *Force { - if n == nil { - return nil +// EqualsSliceOfCharacteristic does deep equals between the two objects. +func EqualsSliceOfCharacteristic(a, b []Characteristic) bool { + if len(a) != len(b) { + return false } - out := *n - return &out + for i := 0; i < len(a); i++ { + if !EqualsCharacteristic(a[i], b[i]) { + return false + } + } + return true } -// VisitRefOfForce will visit all parts of the AST -func VisitRefOfForce(in *Force, f Visit) error { - if in == nil { - return nil +// EqualsSliceOfColIdent does deep equals between the two objects. +func EqualsSliceOfColIdent(a, b []ColIdent) bool { + if len(a) != len(b) { + return false } - if cont, err := f(in); err != nil || !cont { - return err + for i := 0; i < len(a); i++ { + if !EqualsColIdent(a[i], b[i]) { + return false + } } - return nil + return true } -// rewriteRefOfForce is part of the Rewrite implementation -func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsSliceOfCollateAndCharset does deep equals between the two objects. +func EqualsSliceOfCollateAndCharset(a, b []CollateAndCharset) bool { + if len(a) != len(b) { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + for i := 0; i < len(a); i++ { + if !EqualsCollateAndCharset(a[i], b[i]) { + return false + } } - return nil + return true } -// EqualsRefOfForeignKeyDefinition does deep equals between the two objects. -func EqualsRefOfForeignKeyDefinition(a, b *ForeignKeyDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsSliceOfRefOfColumnDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfColumnDefinition(a, b []*ColumnDefinition) bool { + if len(a) != len(b) { return false } - return EqualsColumns(a.Source, b.Source) && - EqualsTableName(a.ReferencedTable, b.ReferencedTable) && - EqualsColumns(a.ReferencedColumns, b.ReferencedColumns) && - a.OnDelete == b.OnDelete && - a.OnUpdate == b.OnUpdate -} - -// CloneRefOfForeignKeyDefinition creates a deep clone of the input. -func CloneRefOfForeignKeyDefinition(n *ForeignKeyDefinition) *ForeignKeyDefinition { - if n == nil { - return nil + for i := 0; i < len(a); i++ { + if !EqualsRefOfColumnDefinition(a[i], b[i]) { + return false + } } - out := *n - out.Source = CloneColumns(n.Source) - out.ReferencedTable = CloneTableName(n.ReferencedTable) - out.ReferencedColumns = CloneColumns(n.ReferencedColumns) - return &out + return true } -// VisitRefOfForeignKeyDefinition will visit all parts of the AST -func VisitRefOfForeignKeyDefinition(in *ForeignKeyDefinition, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColumns(in.Source, f); err != nil { - return err - } - if err := VisitTableName(in.ReferencedTable, f); err != nil { - return err - } - if err := VisitColumns(in.ReferencedColumns, f); err != nil { - return err - } - if err := VisitReferenceAction(in.OnDelete, f); err != nil { - return err +// EqualsSliceOfRefOfConstraintDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfConstraintDefinition(a, b []*ConstraintDefinition) bool { + if len(a) != len(b) { + return false } - if err := VisitReferenceAction(in.OnUpdate, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsRefOfConstraintDefinition(a[i], b[i]) { + return false + } } - return nil + return true } -// rewriteRefOfForeignKeyDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).Source = newNode.(Columns) - }); errF != nil { - return errF +// EqualsSliceOfRefOfIndexColumn does deep equals between the two objects. +func EqualsSliceOfRefOfIndexColumn(a, b []*IndexColumn) bool { + if len(a) != len(b) { + return false } - if errF := a.rewriteTableName(node, node.ReferencedTable, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) - }); errF != nil { - return errF + for i := 0; i < len(a); i++ { + if !EqualsRefOfIndexColumn(a[i], b[i]) { + return false + } } - if errF := a.rewriteColumns(node, node.ReferencedColumns, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) - }); errF != nil { - return errF + return true +} + +// EqualsSliceOfRefOfIndexDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfIndexDefinition(a, b []*IndexDefinition) bool { + if len(a) != len(b) { + return false } - if errF := a.rewriteReferenceAction(node, node.OnDelete, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) - }); errF != nil { - return errF + for i := 0; i < len(a); i++ { + if !EqualsRefOfIndexDefinition(a[i], b[i]) { + return false + } } - if errF := a.rewriteReferenceAction(node, node.OnUpdate, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) - }); errF != nil { - return errF + return true +} + +// EqualsSliceOfRefOfIndexOption does deep equals between the two objects. +func EqualsSliceOfRefOfIndexOption(a, b []*IndexOption) bool { + if len(a) != len(b) { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + for i := 0; i < len(a); i++ { + if !EqualsRefOfIndexOption(a[i], b[i]) { + return false + } } - return nil + return true } -// EqualsRefOfFuncExpr does deep equals between the two objects. -func EqualsRefOfFuncExpr(a, b *FuncExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsSliceOfRefOfPartitionDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfPartitionDefinition(a, b []*PartitionDefinition) bool { + if len(a) != len(b) { return false } - return a.Distinct == b.Distinct && - EqualsTableIdent(a.Qualifier, b.Qualifier) && - EqualsColIdent(a.Name, b.Name) && - EqualsSelectExprs(a.Exprs, b.Exprs) + for i := 0; i < len(a); i++ { + if !EqualsRefOfPartitionDefinition(a[i], b[i]) { + return false + } + } + return true } -// CloneRefOfFuncExpr creates a deep clone of the input. -func CloneRefOfFuncExpr(n *FuncExpr) *FuncExpr { - if n == nil { - return nil +// EqualsSliceOfRefOfRenameTablePair does deep equals between the two objects. +func EqualsSliceOfRefOfRenameTablePair(a, b []*RenameTablePair) bool { + if len(a) != len(b) { + return false } - out := *n - out.Qualifier = CloneTableIdent(n.Qualifier) - out.Name = CloneColIdent(n.Name) - out.Exprs = CloneSelectExprs(n.Exprs) - return &out + for i := 0; i < len(a); i++ { + if !EqualsRefOfRenameTablePair(a[i], b[i]) { + return false + } + } + return true } -// VisitRefOfFuncExpr will visit all parts of the AST -func VisitRefOfFuncExpr(in *FuncExpr, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsSliceOfRefOfUnionSelect does deep equals between the two objects. +func EqualsSliceOfRefOfUnionSelect(a, b []*UnionSelect) bool { + if len(a) != len(b) { + return false } - if err := VisitTableIdent(in.Qualifier, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsRefOfUnionSelect(a[i], b[i]) { + return false + } } - if err := VisitColIdent(in.Name, f); err != nil { - return err + return true +} + +// EqualsSliceOfRefOfWhen does deep equals between the two objects. +func EqualsSliceOfRefOfWhen(a, b []*When) bool { + if len(a) != len(b) { + return false } - if err := VisitSelectExprs(in.Exprs, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsRefOfWhen(a[i], b[i]) { + return false + } } - return nil + return true } -// rewriteRefOfFuncExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc) error { - if node == nil { - return nil +// EqualsSliceOfString does deep equals between the two objects. +func EqualsSliceOfString(a, b []string) bool { + if len(a) != len(b) { + return false } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } } - if a.pre != nil && !a.pre(&cur) { - return nil + return true +} + +// EqualsSliceOfVindexParam does deep equals between the two objects. +func EqualsSliceOfVindexParam(a, b []VindexParam) bool { + if len(a) != len(b) { + return false } - if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Qualifier = newNode.(TableIdent) - }); errF != nil { - return errF + for i := 0; i < len(a); i++ { + if !EqualsVindexParam(a[i], b[i]) { + return false + } } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Name = newNode.(ColIdent) - }); errF != nil { - return errF + return true +} + +// EqualsStatement does deep equals between the two objects. +func EqualsStatement(inA, inB Statement) bool { + if inA == nil && inB == nil { + return true } - if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Exprs = newNode.(SelectExprs) - }); errF != nil { - return errF + if inA == nil || inB == nil { + return false } - if a.post != nil && !a.post(&cur) { - return errAbort + switch a := inA.(type) { + case *AlterDatabase: + b, ok := inB.(*AlterDatabase) + if !ok { + return false + } + return EqualsRefOfAlterDatabase(a, b) + case *AlterMigration: + b, ok := inB.(*AlterMigration) + if !ok { + return false + } + return EqualsRefOfAlterMigration(a, b) + case *AlterTable: + b, ok := inB.(*AlterTable) + if !ok { + return false + } + return EqualsRefOfAlterTable(a, b) + case *AlterView: + b, ok := inB.(*AlterView) + if !ok { + return false + } + return EqualsRefOfAlterView(a, b) + case *AlterVschema: + b, ok := inB.(*AlterVschema) + if !ok { + return false + } + return EqualsRefOfAlterVschema(a, b) + case *Begin: + b, ok := inB.(*Begin) + if !ok { + return false + } + return EqualsRefOfBegin(a, b) + case *CallProc: + b, ok := inB.(*CallProc) + if !ok { + return false + } + return EqualsRefOfCallProc(a, b) + case *Commit: + b, ok := inB.(*Commit) + if !ok { + return false + } + return EqualsRefOfCommit(a, b) + case *CreateDatabase: + b, ok := inB.(*CreateDatabase) + if !ok { + return false + } + return EqualsRefOfCreateDatabase(a, b) + case *CreateTable: + b, ok := inB.(*CreateTable) + if !ok { + return false + } + return EqualsRefOfCreateTable(a, b) + case *CreateView: + b, ok := inB.(*CreateView) + if !ok { + return false + } + return EqualsRefOfCreateView(a, b) + case *Delete: + b, ok := inB.(*Delete) + if !ok { + return false + } + return EqualsRefOfDelete(a, b) + case *DropDatabase: + b, ok := inB.(*DropDatabase) + if !ok { + return false + } + return EqualsRefOfDropDatabase(a, b) + case *DropTable: + b, ok := inB.(*DropTable) + if !ok { + return false + } + return EqualsRefOfDropTable(a, b) + case *DropView: + b, ok := inB.(*DropView) + if !ok { + return false + } + return EqualsRefOfDropView(a, b) + case *ExplainStmt: + b, ok := inB.(*ExplainStmt) + if !ok { + return false + } + return EqualsRefOfExplainStmt(a, b) + case *ExplainTab: + b, ok := inB.(*ExplainTab) + if !ok { + return false + } + return EqualsRefOfExplainTab(a, b) + case *Flush: + b, ok := inB.(*Flush) + if !ok { + return false + } + return EqualsRefOfFlush(a, b) + case *Insert: + b, ok := inB.(*Insert) + if !ok { + return false + } + return EqualsRefOfInsert(a, b) + case *Load: + b, ok := inB.(*Load) + if !ok { + return false + } + return EqualsRefOfLoad(a, b) + case *LockTables: + b, ok := inB.(*LockTables) + if !ok { + return false + } + return EqualsRefOfLockTables(a, b) + case *OtherAdmin: + b, ok := inB.(*OtherAdmin) + if !ok { + return false + } + return EqualsRefOfOtherAdmin(a, b) + case *OtherRead: + b, ok := inB.(*OtherRead) + if !ok { + return false + } + return EqualsRefOfOtherRead(a, b) + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *Release: + b, ok := inB.(*Release) + if !ok { + return false + } + return EqualsRefOfRelease(a, b) + case *RenameTable: + b, ok := inB.(*RenameTable) + if !ok { + return false + } + return EqualsRefOfRenameTable(a, b) + case *RevertMigration: + b, ok := inB.(*RevertMigration) + if !ok { + return false + } + return EqualsRefOfRevertMigration(a, b) + case *Rollback: + b, ok := inB.(*Rollback) + if !ok { + return false + } + return EqualsRefOfRollback(a, b) + case *SRollback: + b, ok := inB.(*SRollback) + if !ok { + return false + } + return EqualsRefOfSRollback(a, b) + case *Savepoint: + b, ok := inB.(*Savepoint) + if !ok { + return false + } + return EqualsRefOfSavepoint(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case *Set: + b, ok := inB.(*Set) + if !ok { + return false + } + return EqualsRefOfSet(a, b) + case *SetTransaction: + b, ok := inB.(*SetTransaction) + if !ok { + return false + } + return EqualsRefOfSetTransaction(a, b) + case *Show: + b, ok := inB.(*Show) + if !ok { + return false + } + return EqualsRefOfShow(a, b) + case *Stream: + b, ok := inB.(*Stream) + if !ok { + return false + } + return EqualsRefOfStream(a, b) + case *TruncateTable: + b, ok := inB.(*TruncateTable) + if !ok { + return false + } + return EqualsRefOfTruncateTable(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + case *UnlockTables: + b, ok := inB.(*UnlockTables) + if !ok { + return false + } + return EqualsRefOfUnlockTables(a, b) + case *Update: + b, ok := inB.(*Update) + if !ok { + return false + } + return EqualsRefOfUpdate(a, b) + case *Use: + b, ok := inB.(*Use) + if !ok { + return false + } + return EqualsRefOfUse(a, b) + case *VStream: + b, ok := inB.(*VStream) + if !ok { + return false + } + return EqualsRefOfVStream(a, b) + default: + // this should never happen + return false } - return nil } -// EqualsGroupBy does deep equals between the two objects. -func EqualsGroupBy(a, b GroupBy) bool { +// EqualsTableAndLockTypes does deep equals between the two objects. +func EqualsTableAndLockTypes(a, b TableAndLockTypes) bool { if len(a) != len(b) { return false } for i := 0; i < len(a); i++ { - if !EqualsExpr(a[i], b[i]) { + if !EqualsRefOfTableAndLockType(a[i], b[i]) { return false } } return true } -// CloneGroupBy creates a deep clone of the input. -func CloneGroupBy(n GroupBy) GroupBy { - res := make(GroupBy, 0, len(n)) - for _, x := range n { - res = append(res, CloneExpr(x)) - } - return res -} - -// VisitGroupBy will visit all parts of the AST -func VisitGroupBy(in GroupBy, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in { - if err := VisitExpr(el, f); err != nil { - return err - } - } - return nil -} - -// rewriteGroupBy is part of the Rewrite implementation -func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - for i, el := range node { - if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { - parent.(GroupBy)[i] = newNode.(Expr) - }); errF != nil { - return errF - } - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfGroupConcatExpr does deep equals between the two objects. -func EqualsRefOfGroupConcatExpr(a, b *GroupConcatExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Distinct == b.Distinct && - a.Separator == b.Separator && - EqualsSelectExprs(a.Exprs, b.Exprs) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) -} - -// CloneRefOfGroupConcatExpr creates a deep clone of the input. -func CloneRefOfGroupConcatExpr(n *GroupConcatExpr) *GroupConcatExpr { - if n == nil { - return nil - } - out := *n - out.Exprs = CloneSelectExprs(n.Exprs) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out -} - -// VisitRefOfGroupConcatExpr will visit all parts of the AST -func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitSelectExprs(in.Exprs, f); err != nil { - return err - } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err - } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err - } - return nil -} - -// rewriteRefOfGroupConcatExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) - }); errF != nil { - return errF - } - if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Limit = newNode.(*Limit) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfIndexDefinition does deep equals between the two objects. -func EqualsRefOfIndexDefinition(a, b *IndexDefinition) bool { - if a == b { +// EqualsTableExpr does deep equals between the two objects. +func EqualsTableExpr(inA, inB TableExpr) bool { + if inA == nil && inB == nil { return true } - if a == nil || b == nil { + if inA == nil || inB == nil { return false } - return EqualsRefOfIndexInfo(a.Info, b.Info) && - EqualsSliceOfRefOfIndexColumn(a.Columns, b.Columns) && - EqualsSliceOfRefOfIndexOption(a.Options, b.Options) -} - -// CloneRefOfIndexDefinition creates a deep clone of the input. -func CloneRefOfIndexDefinition(n *IndexDefinition) *IndexDefinition { - if n == nil { - return nil - } - out := *n - out.Info = CloneRefOfIndexInfo(n.Info) - out.Columns = CloneSliceOfRefOfIndexColumn(n.Columns) - out.Options = CloneSliceOfRefOfIndexOption(n.Options) - return &out -} - -// VisitRefOfIndexDefinition will visit all parts of the AST -func VisitRefOfIndexDefinition(in *IndexDefinition, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfIndexInfo(in.Info, f); err != nil { - return err - } - return nil -} - -// rewriteRefOfIndexDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { - parent.(*IndexDefinition).Info = newNode.(*IndexInfo) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort + switch a := inA.(type) { + case *AliasedTableExpr: + b, ok := inB.(*AliasedTableExpr) + if !ok { + return false + } + return EqualsRefOfAliasedTableExpr(a, b) + case *JoinTableExpr: + b, ok := inB.(*JoinTableExpr) + if !ok { + return false + } + return EqualsRefOfJoinTableExpr(a, b) + case *ParenTableExpr: + b, ok := inB.(*ParenTableExpr) + if !ok { + return false + } + return EqualsRefOfParenTableExpr(a, b) + default: + // this should never happen + return false } - return nil } -// EqualsRefOfIndexHints does deep equals between the two objects. -func EqualsRefOfIndexHints(a, b *IndexHints) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsTableExprs does deep equals between the two objects. +func EqualsTableExprs(a, b TableExprs) bool { + if len(a) != len(b) { return false } - return a.Type == b.Type && - EqualsSliceOfColIdent(a.Indexes, b.Indexes) + for i := 0; i < len(a); i++ { + if !EqualsTableExpr(a[i], b[i]) { + return false + } + } + return true } -// CloneRefOfIndexHints creates a deep clone of the input. -func CloneRefOfIndexHints(n *IndexHints) *IndexHints { - if n == nil { - return nil - } - out := *n - out.Indexes = CloneSliceOfColIdent(n.Indexes) - return &out +// EqualsTableIdent does deep equals between the two objects. +func EqualsTableIdent(a, b TableIdent) bool { + return a.v == b.v } -// VisitRefOfIndexHints will visit all parts of the AST -func VisitRefOfIndexHints(in *IndexHints, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsTableName does deep equals between the two objects. +func EqualsTableName(a, b TableName) bool { + return EqualsTableIdent(a.Name, b.Name) && + EqualsTableIdent(a.Qualifier, b.Qualifier) +} + +// EqualsTableNames does deep equals between the two objects. +func EqualsTableNames(a, b TableNames) bool { + if len(a) != len(b) { + return false } - for _, el := range in.Indexes { - if err := VisitColIdent(el, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsTableName(a[i], b[i]) { + return false } } - return nil + return true } -// rewriteRefOfIndexHints is part of the Rewrite implementation -func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil +// EqualsTableOptions does deep equals between the two objects. +func EqualsTableOptions(a, b TableOptions) bool { + if len(a) != len(b) { + return false } - for i, el := range node.Indexes { - if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { - parent.(*IndexHints).Indexes[i] = newNode.(ColIdent) - }); errF != nil { - return errF + for i := 0; i < len(a); i++ { + if !EqualsRefOfTableOption(a[i], b[i]) { + return false } } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil + return true } -// EqualsRefOfIndexInfo does deep equals between the two objects. -func EqualsRefOfIndexInfo(a, b *IndexInfo) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsUpdateExprs does deep equals between the two objects. +func EqualsUpdateExprs(a, b UpdateExprs) bool { + if len(a) != len(b) { return false } - return a.Type == b.Type && - a.Primary == b.Primary && - a.Spatial == b.Spatial && - a.Fulltext == b.Fulltext && - a.Unique == b.Unique && - EqualsColIdent(a.Name, b.Name) && - EqualsColIdent(a.ConstraintName, b.ConstraintName) -} - -// CloneRefOfIndexInfo creates a deep clone of the input. -func CloneRefOfIndexInfo(n *IndexInfo) *IndexInfo { - if n == nil { - return nil + for i := 0; i < len(a); i++ { + if !EqualsRefOfUpdateExpr(a[i], b[i]) { + return false + } } - out := *n - out.Name = CloneColIdent(n.Name) - out.ConstraintName = CloneColIdent(n.ConstraintName) - return &out + return true } -// VisitRefOfIndexInfo will visit all parts of the AST -func VisitRefOfIndexInfo(in *IndexInfo, f Visit) error { - if in == nil { - return nil +// EqualsValTuple does deep equals between the two objects. +func EqualsValTuple(a, b ValTuple) bool { + if len(a) != len(b) { + return false } - if cont, err := f(in); err != nil || !cont { - return err + for i := 0; i < len(a); i++ { + if !EqualsExpr(a[i], b[i]) { + return false + } } - if err := VisitColIdent(in.Name, f); err != nil { - return err + return true +} + +// EqualsValues does deep equals between the two objects. +func EqualsValues(a, b Values) bool { + if len(a) != len(b) { + return false } - if err := VisitColIdent(in.ConstraintName, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsValTuple(a[i], b[i]) { + return false + } } - return nil + return true } -// rewriteRefOfIndexInfo is part of the Rewrite implementation -func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFunc) error { - if node == nil { +// EqualsVindexParam does deep equals between the two objects. +func EqualsVindexParam(a, b VindexParam) bool { + return a.Val == b.Val && + EqualsColIdent(a.Key, b.Key) +} +func VisitAccessMode(in AccessMode, f Visit) error { + _, err := f(in) + return err +} +func VisitAlgorithmValue(in AlgorithmValue, f Visit) error { + _, err := f(in) + return err +} +func VisitAlterOption(in AlterOption, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + switch in := in.(type) { + case *AddColumns: + return VisitRefOfAddColumns(in, f) + case *AddConstraintDefinition: + return VisitRefOfAddConstraintDefinition(in, f) + case *AddIndexDefinition: + return VisitRefOfAddIndexDefinition(in, f) + case AlgorithmValue: + return VisitAlgorithmValue(in, f) + case *AlterCharset: + return VisitRefOfAlterCharset(in, f) + case *AlterColumn: + return VisitRefOfAlterColumn(in, f) + case *ChangeColumn: + return VisitRefOfChangeColumn(in, f) + case *DropColumn: + return VisitRefOfDropColumn(in, f) + case *DropKey: + return VisitRefOfDropKey(in, f) + case *Force: + return VisitRefOfForce(in, f) + case *KeyState: + return VisitRefOfKeyState(in, f) + case *LockOption: + return VisitRefOfLockOption(in, f) + case *ModifyColumn: + return VisitRefOfModifyColumn(in, f) + case *OrderByOption: + return VisitRefOfOrderByOption(in, f) + case *RenameIndex: + return VisitRefOfRenameIndex(in, f) + case *RenameTableName: + return VisitRefOfRenameTableName(in, f) + case TableOptions: + return VisitTableOptions(in, f) + case *TablespaceOperation: + return VisitRefOfTablespaceOperation(in, f) + case *Validation: + return VisitRefOfValidation(in, f) + default: + // this should never happen return nil } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*IndexInfo).Name = newNode.(ColIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteColIdent(node, node.ConstraintName, func(newNode, parent SQLNode) { - parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil } - -// EqualsRefOfInsert does deep equals between the two objects. -func EqualsRefOfInsert(a, b *Insert) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Action == b.Action && - EqualsComments(a.Comments, b.Comments) && - a.Ignore == b.Ignore && - EqualsTableName(a.Table, b.Table) && - EqualsPartitions(a.Partitions, b.Partitions) && - EqualsColumns(a.Columns, b.Columns) && - EqualsInsertRows(a.Rows, b.Rows) && - EqualsOnDup(a.OnDup, b.OnDup) +func VisitArgument(in Argument, f Visit) error { + _, err := f(in) + return err +} +func VisitBoolVal(in BoolVal, f Visit) error { + _, err := f(in) + return err } - -// CloneRefOfInsert creates a deep clone of the input. -func CloneRefOfInsert(n *Insert) *Insert { - if n == nil { +func VisitCharacteristic(in Characteristic, f Visit) error { + if in == nil { return nil } - out := *n - out.Comments = CloneComments(n.Comments) - out.Table = CloneTableName(n.Table) - out.Partitions = ClonePartitions(n.Partitions) - out.Columns = CloneColumns(n.Columns) - out.Rows = CloneInsertRows(n.Rows) - out.OnDup = CloneOnDup(n.OnDup) - return &out -} - -// VisitRefOfInsert will visit all parts of the AST -func VisitRefOfInsert(in *Insert, f Visit) error { - if in == nil { + switch in := in.(type) { + case AccessMode: + return VisitAccessMode(in, f) + case IsolationLevel: + return VisitIsolationLevel(in, f) + default: + // this should never happen return nil } +} +func VisitColIdent(in ColIdent, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err + return nil +} +func VisitColTuple(in ColTuple, f Visit) error { + if in == nil { + return nil } - if err := VisitPartitions(in.Partitions, f); err != nil { - return err + switch in := in.(type) { + case ListArg: + return VisitListArg(in, f) + case *Subquery: + return VisitRefOfSubquery(in, f) + case ValTuple: + return VisitValTuple(in, f) + default: + // this should never happen + return nil } - if err := VisitColumns(in.Columns, f); err != nil { - return err +} +func VisitColumns(in Columns, f Visit) error { + if in == nil { + return nil } - if err := VisitInsertRows(in.Rows, f); err != nil { + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitOnDup(in.OnDup, f); err != nil { - return err + for _, el := range in { + if err := VisitColIdent(el, f); err != nil { + return err + } } return nil } - -// rewriteRefOfInsert is part of the Rewrite implementation -func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc) error { - if node == nil { +func VisitComments(in Comments, f Visit) error { + _, err := f(in) + return err +} +func VisitConstraintInfo(in ConstraintInfo, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + switch in := in.(type) { + case *CheckConstraintDefinition: + return VisitRefOfCheckConstraintDefinition(in, f) + case *ForeignKeyDefinition: + return VisitRefOfForeignKeyDefinition(in, f) + default: + // this should never happen return nil } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*Insert).Comments = newNode.(Comments) - }); errF != nil { - return errF - } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*Insert).Table = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { - parent.(*Insert).Partitions = newNode.(Partitions) - }); errF != nil { - return errF - } - if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { - parent.(*Insert).Columns = newNode.(Columns) - }); errF != nil { - return errF - } - if errF := a.rewriteInsertRows(node, node.Rows, func(newNode, parent SQLNode) { - parent.(*Insert).Rows = newNode.(InsertRows) - }); errF != nil { - return errF - } - if errF := a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { - parent.(*Insert).OnDup = newNode.(OnDup) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil } - -// EqualsRefOfIntervalExpr does deep equals between the two objects. -func EqualsRefOfIntervalExpr(a, b *IntervalExpr) bool { - if a == b { - return true +func VisitDBDDLStatement(in DBDDLStatement, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + switch in := in.(type) { + case *AlterDatabase: + return VisitRefOfAlterDatabase(in, f) + case *CreateDatabase: + return VisitRefOfCreateDatabase(in, f) + case *DropDatabase: + return VisitRefOfDropDatabase(in, f) + default: + // this should never happen + return nil } - return a.Unit == b.Unit && - EqualsExpr(a.Expr, b.Expr) } - -// CloneRefOfIntervalExpr creates a deep clone of the input. -func CloneRefOfIntervalExpr(n *IntervalExpr) *IntervalExpr { - if n == nil { +func VisitDDLStatement(in DDLStatement, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case *AlterTable: + return VisitRefOfAlterTable(in, f) + case *AlterView: + return VisitRefOfAlterView(in, f) + case *CreateTable: + return VisitRefOfCreateTable(in, f) + case *CreateView: + return VisitRefOfCreateView(in, f) + case *DropTable: + return VisitRefOfDropTable(in, f) + case *DropView: + return VisitRefOfDropView(in, f) + case *RenameTable: + return VisitRefOfRenameTable(in, f) + case *TruncateTable: + return VisitRefOfTruncateTable(in, f) + default: + // this should never happen return nil } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out } - -// VisitRefOfIntervalExpr will visit all parts of the AST -func VisitRefOfIntervalExpr(in *IntervalExpr, f Visit) error { +func VisitExplain(in Explain, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err + switch in := in.(type) { + case *ExplainStmt: + return VisitRefOfExplainStmt(in, f) + case *ExplainTab: + return VisitRefOfExplainTab(in, f) + default: + // this should never happen + return nil } - return nil } - -// rewriteRefOfIntervalExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer replacerFunc) error { - if node == nil { +func VisitExpr(in Expr, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case *AndExpr: + return VisitRefOfAndExpr(in, f) + case Argument: + return VisitArgument(in, f) + case *BinaryExpr: + return VisitRefOfBinaryExpr(in, f) + case BoolVal: + return VisitBoolVal(in, f) + case *CaseExpr: + return VisitRefOfCaseExpr(in, f) + case *ColName: + return VisitRefOfColName(in, f) + case *CollateExpr: + return VisitRefOfCollateExpr(in, f) + case *ComparisonExpr: + return VisitRefOfComparisonExpr(in, f) + case *ConvertExpr: + return VisitRefOfConvertExpr(in, f) + case *ConvertUsingExpr: + return VisitRefOfConvertUsingExpr(in, f) + case *CurTimeFuncExpr: + return VisitRefOfCurTimeFuncExpr(in, f) + case *Default: + return VisitRefOfDefault(in, f) + case *ExistsExpr: + return VisitRefOfExistsExpr(in, f) + case *FuncExpr: + return VisitRefOfFuncExpr(in, f) + case *GroupConcatExpr: + return VisitRefOfGroupConcatExpr(in, f) + case *IntervalExpr: + return VisitRefOfIntervalExpr(in, f) + case *IsExpr: + return VisitRefOfIsExpr(in, f) + case ListArg: + return VisitListArg(in, f) + case *Literal: + return VisitRefOfLiteral(in, f) + case *MatchExpr: + return VisitRefOfMatchExpr(in, f) + case *NotExpr: + return VisitRefOfNotExpr(in, f) + case *NullVal: + return VisitRefOfNullVal(in, f) + case *OrExpr: + return VisitRefOfOrExpr(in, f) + case *RangeCond: + return VisitRefOfRangeCond(in, f) + case *Subquery: + return VisitRefOfSubquery(in, f) + case *SubstrExpr: + return VisitRefOfSubstrExpr(in, f) + case *TimestampFuncExpr: + return VisitRefOfTimestampFuncExpr(in, f) + case *UnaryExpr: + return VisitRefOfUnaryExpr(in, f) + case ValTuple: + return VisitValTuple(in, f) + case *ValuesFuncExpr: + return VisitRefOfValuesFuncExpr(in, f) + case *XorExpr: + return VisitRefOfXorExpr(in, f) + default: + // this should never happen return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +} +func VisitExprs(in Exprs, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*IntervalExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfIsExpr does deep equals between the two objects. -func EqualsRefOfIsExpr(a, b *IsExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Operator == b.Operator && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfIsExpr creates a deep clone of the input. -func CloneRefOfIsExpr(n *IsExpr) *IsExpr { - if n == nil { - return nil - } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out -} - -// VisitRefOfIsExpr will visit all parts of the AST -func VisitRefOfIsExpr(in *IsExpr, f Visit) error { +func VisitGroupBy(in GroupBy, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { - return err + for _, el := range in { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } - -// rewriteRefOfIsExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc) error { - if node == nil { +func VisitInsertRows(in InsertRows, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + switch in := in.(type) { + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case *Union: + return VisitRefOfUnion(in, f) + case Values: + return VisitValues(in, f) + default: + // this should never happen return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*IsExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsJoinCondition does deep equals between the two objects. -func EqualsJoinCondition(a, b JoinCondition) bool { - return EqualsExpr(a.On, b.On) && - EqualsColumns(a.Using, b.Using) } - -// CloneJoinCondition creates a deep clone of the input. -func CloneJoinCondition(n JoinCondition) JoinCondition { - return *CloneRefOfJoinCondition(&n) +func VisitIsolationLevel(in IsolationLevel, f Visit) error { + _, err := f(in) + return err } - -// VisitJoinCondition will visit all parts of the AST func VisitJoinCondition(in JoinCondition, f Visit) error { if cont, err := f(in); err != nil || !cont { return err @@ -5987,304 +6881,166 @@ func VisitJoinCondition(in JoinCondition, f Visit) error { } return nil } - -// rewriteJoinCondition is part of the Rewrite implementation -func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc) error { - var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitListArg(in ListArg, f Visit) error { + _, err := f(in) + return err +} +func VisitOnDup(in OnDup, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'On' on 'JoinCondition'") - }); errF != nil { - return errF - } - if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Using' on 'JoinCondition'") - }); errF != nil { - return errF - } - if err != nil { + if cont, err := f(in); err != nil || !cont { return err } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in { + if err := VisitRefOfUpdateExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfJoinTableExpr does deep equals between the two objects. -func EqualsRefOfJoinTableExpr(a, b *JoinTableExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsTableExpr(a.LeftExpr, b.LeftExpr) && - a.Join == b.Join && - EqualsTableExpr(a.RightExpr, b.RightExpr) && - EqualsJoinCondition(a.Condition, b.Condition) -} - -// CloneRefOfJoinTableExpr creates a deep clone of the input. -func CloneRefOfJoinTableExpr(n *JoinTableExpr) *JoinTableExpr { - if n == nil { - return nil - } - out := *n - out.LeftExpr = CloneTableExpr(n.LeftExpr) - out.RightExpr = CloneTableExpr(n.RightExpr) - out.Condition = CloneJoinCondition(n.Condition) - return &out -} - -// VisitRefOfJoinTableExpr will visit all parts of the AST -func VisitRefOfJoinTableExpr(in *JoinTableExpr, f Visit) error { +func VisitOrderBy(in OrderBy, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableExpr(in.LeftExpr, f); err != nil { - return err - } - if err := VisitTableExpr(in.RightExpr, f); err != nil { - return err - } - if err := VisitJoinCondition(in.Condition, f); err != nil { - return err + for _, el := range in { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } return nil } - -// rewriteRefOfJoinTableExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitPartitions(in Partitions, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) - }); errF != nil { - return errF - } - if errF := a.rewriteTableExpr(node, node.RightExpr, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) - }); errF != nil { - return errF - } - if errF := a.rewriteJoinCondition(node, node.Condition, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in { + if err := VisitColIdent(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfKeyState does deep equals between the two objects. -func EqualsRefOfKeyState(a, b *KeyState) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Enable == b.Enable -} - -// CloneRefOfKeyState creates a deep clone of the input. -func CloneRefOfKeyState(n *KeyState) *KeyState { - if n == nil { - return nil - } - out := *n - return &out -} - -// VisitRefOfKeyState will visit all parts of the AST -func VisitRefOfKeyState(in *KeyState, f Visit) error { +func VisitRefOfAddColumns(in *AddColumns, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// rewriteRefOfKeyState is part of the Rewrite implementation -func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + for _, el := range in.Columns { + if err := VisitRefOfColumnDefinition(el, f); err != nil { + return err + } } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitRefOfColName(in.First, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfColName(in.After, f); err != nil { + return err } return nil } - -// EqualsRefOfLimit does deep equals between the two objects. -func EqualsRefOfLimit(a, b *Limit) bool { - if a == b { - return true +func VisitRefOfAddConstraintDefinition(in *AddConstraintDefinition, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsExpr(a.Offset, b.Offset) && - EqualsExpr(a.Rowcount, b.Rowcount) + if err := VisitRefOfConstraintDefinition(in.ConstraintDefinition, f); err != nil { + return err + } + return nil } - -// CloneRefOfLimit creates a deep clone of the input. -func CloneRefOfLimit(n *Limit) *Limit { - if n == nil { +func VisitRefOfAddIndexDefinition(in *AddIndexDefinition, f Visit) error { + if in == nil { return nil } - out := *n - out.Offset = CloneExpr(n.Offset) - out.Rowcount = CloneExpr(n.Rowcount) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitRefOfIndexDefinition(in.IndexDefinition, f); err != nil { + return err + } + return nil } - -// VisitRefOfLimit will visit all parts of the AST -func VisitRefOfLimit(in *Limit, f Visit) error { +func VisitRefOfAliasedExpr(in *AliasedExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Offset, f); err != nil { + if err := VisitExpr(in.Expr, f); err != nil { return err } - if err := VisitExpr(in.Rowcount, f); err != nil { + if err := VisitColIdent(in.As, f); err != nil { return err } return nil } - -// rewriteRefOfLimit is part of the Rewrite implementation -func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc) error { - if node == nil { +func VisitRefOfAliasedTableExpr(in *AliasedTableExpr, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitSimpleTableExpr(in.Expr, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { - parent.(*Limit).Offset = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitPartitions(in.Partitions, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.Rowcount, func(newNode, parent SQLNode) { - parent.(*Limit).Rowcount = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitTableIdent(in.As, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfIndexHints(in.Hints, f); err != nil { + return err } return nil } - -// EqualsListArg does deep equals between the two objects. -func EqualsListArg(a, b ListArg) bool { - if len(a) != len(b) { - return false +func VisitRefOfAlterCharset(in *AlterCharset, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneListArg creates a deep clone of the input. -func CloneListArg(n ListArg) ListArg { - res := make(ListArg, 0, len(n)) - copy(res, n) - return res -} - -// VisitListArg will visit all parts of the AST -func VisitListArg(in ListArg, f Visit) error { - _, err := f(in) - return err + return nil } - -// rewriteListArg is part of the Rewrite implementation -func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc) error { - if node == nil { +func VisitRefOfAlterColumn(in *AlterColumn, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitRefOfColName(in.Column, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.DefaultVal, f); err != nil { + return err } return nil } - -// EqualsRefOfLiteral does deep equals between the two objects. -func EqualsRefOfLiteral(a, b *Literal) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Val == b.Val && - a.Type == b.Type -} - -// CloneRefOfLiteral creates a deep clone of the input. -func CloneRefOfLiteral(n *Literal) *Literal { - if n == nil { +func VisitRefOfAlterDatabase(in *AlterDatabase, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfLiteral will visit all parts of the AST -func VisitRefOfLiteral(in *Literal, f Visit) error { +func VisitRefOfAlterMigration(in *AlterMigration, f Visit) error { if in == nil { return nil } @@ -6293,98 +7049,98 @@ func VisitRefOfLiteral(in *Literal, f Visit) error { } return nil } - -// rewriteRefOfLiteral is part of the Rewrite implementation -func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc) error { - if node == nil { +func VisitRefOfAlterTable(in *AlterTable, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitTableName(in.Table, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in.AlterOptions { + if err := VisitAlterOption(el, f); err != nil { + return err + } + } + if err := VisitRefOfPartitionSpec(in.PartitionSpec, f); err != nil { + return err } return nil } - -// EqualsRefOfLoad does deep equals between the two objects. -func EqualsRefOfLoad(a, b *Load) bool { - if a == b { - return true +func VisitRefOfAlterView(in *AlterView, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneRefOfLoad creates a deep clone of the input. -func CloneRefOfLoad(n *Load) *Load { - if n == nil { - return nil + if err := VisitTableName(in.ViewName, f); err != nil { + return err } - out := *n - return &out + if err := VisitColumns(in.Columns, f); err != nil { + return err + } + if err := VisitSelectStatement(in.Select, f); err != nil { + return err + } + return nil } - -// VisitRefOfLoad will visit all parts of the AST -func VisitRefOfLoad(in *Load, f Visit) error { +func VisitRefOfAlterVschema(in *AlterVschema, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// rewriteRefOfLoad is part of the Rewrite implementation -func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc) error { - if node == nil { - return nil + if err := VisitTableName(in.Table, f); err != nil { + return err } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if err := VisitRefOfVindexSpec(in.VindexSpec, f); err != nil { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + for _, el := range in.VindexCols { + if err := VisitColIdent(el, f); err != nil { + return err + } } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfAutoIncSpec(in.AutoIncSpec, f); err != nil { + return err } return nil } - -// EqualsRefOfLockOption does deep equals between the two objects. -func EqualsRefOfLockOption(a, b *LockOption) bool { - if a == b { - return true +func VisitRefOfAndExpr(in *AndExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Type == b.Type + if err := VisitExpr(in.Left, f); err != nil { + return err + } + if err := VisitExpr(in.Right, f); err != nil { + return err + } + return nil } - -// CloneRefOfLockOption creates a deep clone of the input. -func CloneRefOfLockOption(n *LockOption) *LockOption { - if n == nil { +func VisitRefOfAutoIncSpec(in *AutoIncSpec, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitColIdent(in.Column, f); err != nil { + return err + } + if err := VisitTableName(in.Sequence, f); err != nil { + return err + } + return nil } - -// VisitRefOfLockOption will visit all parts of the AST -func VisitRefOfLockOption(in *LockOption, f Visit) error { +func VisitRefOfBegin(in *Begin, f Visit) error { if in == nil { return nil } @@ -6393,249 +7149,192 @@ func VisitRefOfLockOption(in *LockOption, f Visit) error { } return nil } - -// rewriteRefOfLockOption is part of the Rewrite implementation -func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerFunc) error { - if node == nil { +func VisitRefOfBinaryExpr(in *BinaryExpr, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitExpr(in.Left, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Right, f); err != nil { + return err } return nil } - -// EqualsRefOfLockTables does deep equals between the two objects. -func EqualsRefOfLockTables(a, b *LockTables) bool { - if a == b { - return true +func VisitRefOfCallProc(in *CallProc, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsTableAndLockTypes(a.Tables, b.Tables) -} - -// CloneRefOfLockTables creates a deep clone of the input. -func CloneRefOfLockTables(n *LockTables) *LockTables { - if n == nil { - return nil + if err := VisitTableName(in.Name, f); err != nil { + return err } - out := *n - out.Tables = CloneTableAndLockTypes(n.Tables) - return &out + if err := VisitExprs(in.Params, f); err != nil { + return err + } + return nil } - -// VisitRefOfLockTables will visit all parts of the AST -func VisitRefOfLockTables(in *LockTables, f Visit) error { +func VisitRefOfCaseExpr(in *CaseExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + for _, el := range in.Whens { + if err := VisitRefOfWhen(el, f); err != nil { + return err + } + } + if err := VisitExpr(in.Else, f); err != nil { + return err + } return nil } - -// rewriteRefOfLockTables is part of the Rewrite implementation -func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerFunc) error { - if node == nil { +func VisitRefOfChangeColumn(in *ChangeColumn, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfColName(in.OldColumn, f); err != nil { + return err } - return nil -} - -// EqualsRefOfMatchExpr does deep equals between the two objects. -func EqualsRefOfMatchExpr(a, b *MatchExpr) bool { - if a == b { - return true + if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitRefOfColName(in.First, f); err != nil { + return err } - return EqualsSelectExprs(a.Columns, b.Columns) && - EqualsExpr(a.Expr, b.Expr) && - a.Option == b.Option -} - -// CloneRefOfMatchExpr creates a deep clone of the input. -func CloneRefOfMatchExpr(n *MatchExpr) *MatchExpr { - if n == nil { - return nil + if err := VisitRefOfColName(in.After, f); err != nil { + return err } - out := *n - out.Columns = CloneSelectExprs(n.Columns) - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfMatchExpr will visit all parts of the AST -func VisitRefOfMatchExpr(in *MatchExpr, f Visit) error { +func VisitRefOfCheckConstraintDefinition(in *CheckConstraintDefinition, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectExprs(in.Columns, f); err != nil { - return err - } if err := VisitExpr(in.Expr, f); err != nil { return err } return nil } - -// rewriteRefOfMatchExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFunc) error { - if node == nil { +func VisitRefOfColIdent(in *ColIdent, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + return nil +} +func VisitRefOfColName(in *ColName, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { - parent.(*MatchExpr).Columns = newNode.(SelectExprs) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*MatchExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitColIdent(in.Name, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableName(in.Qualifier, f); err != nil { + return err } return nil } - -// EqualsRefOfModifyColumn does deep equals between the two objects. -func EqualsRefOfModifyColumn(a, b *ModifyColumn) bool { - if a == b { - return true +func VisitRefOfCollateExpr(in *CollateExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && - EqualsRefOfColName(a.First, b.First) && - EqualsRefOfColName(a.After, b.After) -} - -// CloneRefOfModifyColumn creates a deep clone of the input. -func CloneRefOfModifyColumn(n *ModifyColumn) *ModifyColumn { - if n == nil { - return nil + if err := VisitExpr(in.Expr, f); err != nil { + return err } - out := *n - out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) - out.First = CloneRefOfColName(n.First) - out.After = CloneRefOfColName(n.After) - return &out + return nil } - -// VisitRefOfModifyColumn will visit all parts of the AST -func VisitRefOfModifyColumn(in *ModifyColumn, f Visit) error { +func VisitRefOfColumnDefinition(in *ColumnDefinition, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { + if err := VisitColIdent(in.Name, f); err != nil { return err } - if err := VisitRefOfColName(in.First, f); err != nil { + return nil +} +func VisitRefOfColumnType(in *ColumnType, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfColName(in.After, f); err != nil { + if err := VisitRefOfLiteral(in.Length, f); err != nil { + return err + } + if err := VisitRefOfLiteral(in.Scale, f); err != nil { return err } return nil } - -// rewriteRefOfModifyColumn is part of the Rewrite implementation -func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer replacerFunc) error { - if node == nil { +func VisitRefOfCommit(in *Commit, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + return nil +} +func VisitRefOfComparisonExpr(in *ComparisonExpr, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).First = newNode.(*ColName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).After = newNode.(*ColName) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Left, f); err != nil { + return err } - return nil -} - -// EqualsRefOfNextval does deep equals between the two objects. -func EqualsRefOfNextval(a, b *Nextval) bool { - if a == b { - return true + if err := VisitExpr(in.Right, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitExpr(in.Escape, f); err != nil { + return err } - return EqualsExpr(a.Expr, b.Expr) + return nil } - -// CloneRefOfNextval creates a deep clone of the input. -func CloneRefOfNextval(n *Nextval) *Nextval { - if n == nil { +func VisitRefOfConstraintDefinition(in *ConstraintDefinition, f Visit) error { + if in == nil { return nil } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitConstraintInfo(in.Details, f); err != nil { + return err + } + return nil } - -// VisitRefOfNextval will visit all parts of the AST -func VisitRefOfNextval(in *Nextval, f Visit) error { +func VisitRefOfConvertExpr(in *ConvertExpr, f Visit) error { if in == nil { return nil } @@ -6645,114 +7344,102 @@ func VisitRefOfNextval(in *Nextval, f Visit) error { if err := VisitExpr(in.Expr, f); err != nil { return err } + if err := VisitRefOfConvertType(in.Type, f); err != nil { + return err + } return nil } - -// rewriteRefOfNextval is part of the Rewrite implementation -func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc) error { - if node == nil { +func VisitRefOfConvertType(in *ConvertType, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*Nextval).Expr = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitRefOfLiteral(in.Length, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfLiteral(in.Scale, f); err != nil { + return err } return nil } - -// EqualsRefOfNotExpr does deep equals between the two objects. -func EqualsRefOfNotExpr(a, b *NotExpr) bool { - if a == b { - return true +func VisitRefOfConvertUsingExpr(in *ConvertUsingExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfNotExpr creates a deep clone of the input. -func CloneRefOfNotExpr(n *NotExpr) *NotExpr { - if n == nil { - return nil + if err := VisitExpr(in.Expr, f); err != nil { + return err } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfNotExpr will visit all parts of the AST -func VisitRefOfNotExpr(in *NotExpr, f Visit) error { +func VisitRefOfCreateDatabase(in *CreateDatabase, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } return nil } - -// rewriteRefOfNotExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc) error { - if node == nil { +func VisitRefOfCreateTable(in *CreateTable, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitTableName(in.Table, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*NotExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitRefOfTableSpec(in.TableSpec, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfOptLike(in.OptLike, f); err != nil { + return err } return nil } - -// EqualsRefOfNullVal does deep equals between the two objects. -func EqualsRefOfNullVal(a, b *NullVal) bool { - if a == b { - return true +func VisitRefOfCreateView(in *CreateView, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true + if err := VisitTableName(in.ViewName, f); err != nil { + return err + } + if err := VisitColumns(in.Columns, f); err != nil { + return err + } + if err := VisitSelectStatement(in.Select, f); err != nil { + return err + } + return nil } - -// CloneRefOfNullVal creates a deep clone of the input. -func CloneRefOfNullVal(n *NullVal) *NullVal { - if n == nil { +func VisitRefOfCurTimeFuncExpr(in *CurTimeFuncExpr, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitColIdent(in.Name, f); err != nil { + return err + } + if err := VisitExpr(in.Fsp, f); err != nil { + return err + } + return nil } - -// VisitRefOfNullVal will visit all parts of the AST -func VisitRefOfNullVal(in *NullVal, f Visit) error { +func VisitRefOfDefault(in *Default, f Visit) error { if in == nil { return nil } @@ -6761,423 +7448,348 @@ func VisitRefOfNullVal(in *NullVal, f Visit) error { } return nil } - -// rewriteRefOfNullVal is part of the Rewrite implementation -func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc) error { - if node == nil { +func VisitRefOfDelete(in *Delete, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitComments(in.Comments, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableNames(in.Targets, f); err != nil { + return err } - return nil -} - -// EqualsOnDup does deep equals between the two objects. -func EqualsOnDup(a, b OnDup) bool { - if len(a) != len(b) { - return false + if err := VisitTableExprs(in.TableExprs, f); err != nil { + return err } - for i := 0; i < len(a); i++ { - if !EqualsRefOfUpdateExpr(a[i], b[i]) { - return false - } + if err := VisitPartitions(in.Partitions, f); err != nil { + return err } - return true -} - -// CloneOnDup creates a deep clone of the input. -func CloneOnDup(n OnDup) OnDup { - res := make(OnDup, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfUpdateExpr(x)) + if err := VisitRefOfWhere(in.Where, f); err != nil { + return err } - return res + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err + } + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err + } + return nil } - -// VisitOnDup will visit all parts of the AST -func VisitOnDup(in OnDup, f Visit) error { +func VisitRefOfDerivedTable(in *DerivedTable, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitRefOfUpdateExpr(el, f); err != nil { - return err - } + if err := VisitSelectStatement(in.Select, f); err != nil { + return err } return nil } - -// rewriteOnDup is part of the Rewrite implementation -func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitRefOfDropColumn(in *DropColumn, f Visit) error { + if in == nil { return nil } - for i, el := range node { - if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { - parent.(OnDup)[i] = newNode.(*UpdateExpr) - }); errF != nil { - return errF - } + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfColName(in.Name, f); err != nil { + return err } return nil } - -// EqualsRefOfOptLike does deep equals between the two objects. -func EqualsRefOfOptLike(a, b *OptLike) bool { - if a == b { - return true +func VisitRefOfDropDatabase(in *DropDatabase, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsTableName(a.LikeTable, b.LikeTable) + if err := VisitComments(in.Comments, f); err != nil { + return err + } + return nil } - -// CloneRefOfOptLike creates a deep clone of the input. -func CloneRefOfOptLike(n *OptLike) *OptLike { - if n == nil { +func VisitRefOfDropKey(in *DropKey, f Visit) error { + if in == nil { return nil } - out := *n - out.LikeTable = CloneTableName(n.LikeTable) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfOptLike will visit all parts of the AST -func VisitRefOfOptLike(in *OptLike, f Visit) error { +func VisitRefOfDropTable(in *DropTable, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.LikeTable, f); err != nil { + if err := VisitTableNames(in.FromTables, f); err != nil { return err } return nil } - -// rewriteRefOfOptLike is part of the Rewrite implementation -func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc) error { - if node == nil { +func VisitRefOfDropView(in *DropView, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + if err := VisitTableNames(in.FromTables, f); err != nil { + return err + } + return nil +} +func VisitRefOfExistsExpr(in *ExistsExpr, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { - parent.(*OptLike).LikeTable = newNode.(TableName) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfSubquery(in.Subquery, f); err != nil { + return err } return nil } - -// EqualsRefOfOrExpr does deep equals between the two objects. -func EqualsRefOfOrExpr(a, b *OrExpr) bool { - if a == b { - return true +func VisitRefOfExplainStmt(in *ExplainStmt, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) -} - -// CloneRefOfOrExpr creates a deep clone of the input. -func CloneRefOfOrExpr(n *OrExpr) *OrExpr { - if n == nil { - return nil + if err := VisitStatement(in.Statement, f); err != nil { + return err } - out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) - return &out + return nil } - -// VisitRefOfOrExpr will visit all parts of the AST -func VisitRefOfOrExpr(in *OrExpr, f Visit) error { +func VisitRefOfExplainTab(in *ExplainTab, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Left, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { return err } - if err := VisitExpr(in.Right, f); err != nil { + return nil +} +func VisitRefOfFlush(in *Flush, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitTableNames(in.TableNames, f); err != nil { return err } return nil } - -// rewriteRefOfOrExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc) error { - if node == nil { +func VisitRefOfForce(in *Force, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + return nil +} +func VisitRefOfForeignKeyDefinition(in *ForeignKeyDefinition, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*OrExpr).Left = newNode.(Expr) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { - parent.(*OrExpr).Right = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitColumns(in.Source, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableName(in.ReferencedTable, f); err != nil { + return err } - return nil -} - -// EqualsRefOfOrder does deep equals between the two objects. -func EqualsRefOfOrder(a, b *Order) bool { - if a == b { - return true + if err := VisitColumns(in.ReferencedColumns, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitReferenceAction(in.OnDelete, f); err != nil { + return err } - return EqualsExpr(a.Expr, b.Expr) && - a.Direction == b.Direction -} - -// CloneRefOfOrder creates a deep clone of the input. -func CloneRefOfOrder(n *Order) *Order { - if n == nil { - return nil + if err := VisitReferenceAction(in.OnUpdate, f); err != nil { + return err } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfOrder will visit all parts of the AST -func VisitRefOfOrder(in *Order, f Visit) error { +func VisitRefOfFuncExpr(in *FuncExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { + if err := VisitTableIdent(in.Qualifier, f); err != nil { + return err + } + if err := VisitColIdent(in.Name, f); err != nil { + return err + } + if err := VisitSelectExprs(in.Exprs, f); err != nil { return err } return nil } - -// rewriteRefOfOrder is part of the Rewrite implementation -func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc) error { - if node == nil { +func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitSelectExprs(in.Exprs, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*Order).Expr = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err } return nil } - -// EqualsOrderBy does deep equals between the two objects. -func EqualsOrderBy(a, b OrderBy) bool { - if len(a) != len(b) { - return false +func VisitRefOfIndexDefinition(in *IndexDefinition, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfOrder(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneOrderBy creates a deep clone of the input. -func CloneOrderBy(n OrderBy) OrderBy { - res := make(OrderBy, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfOrder(x)) + if err := VisitRefOfIndexInfo(in.Info, f); err != nil { + return err } - return res + return nil } - -// VisitOrderBy will visit all parts of the AST -func VisitOrderBy(in OrderBy, f Visit) error { +func VisitRefOfIndexHints(in *IndexHints, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitRefOfOrder(el, f); err != nil { + for _, el := range in.Indexes { + if err := VisitColIdent(el, f); err != nil { return err } } return nil } - -// rewriteOrderBy is part of the Rewrite implementation -func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc) error { - if node == nil { +func VisitRefOfIndexInfo(in *IndexInfo, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + if err := VisitColIdent(in.Name, f); err != nil { + return err + } + if err := VisitColIdent(in.ConstraintName, f); err != nil { + return err + } + return nil +} +func VisitRefOfInsert(in *Insert, f Visit) error { + if in == nil { return nil } - for i, el := range node { - if errF := a.rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { - parent.(OrderBy)[i] = newNode.(*Order) - }); errF != nil { - return errF - } + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitComments(in.Comments, f); err != nil { + return err + } + if err := VisitTableName(in.Table, f); err != nil { + return err + } + if err := VisitPartitions(in.Partitions, f); err != nil { + return err + } + if err := VisitColumns(in.Columns, f); err != nil { + return err + } + if err := VisitInsertRows(in.Rows, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitOnDup(in.OnDup, f); err != nil { + return err } return nil } - -// EqualsRefOfOrderByOption does deep equals between the two objects. -func EqualsRefOfOrderByOption(a, b *OrderByOption) bool { - if a == b { - return true +func VisitRefOfIntervalExpr(in *IntervalExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsColumns(a.Cols, b.Cols) -} - -// CloneRefOfOrderByOption creates a deep clone of the input. -func CloneRefOfOrderByOption(n *OrderByOption) *OrderByOption { - if n == nil { - return nil + if err := VisitExpr(in.Expr, f); err != nil { + return err } - out := *n - out.Cols = CloneColumns(n.Cols) - return &out + return nil } - -// VisitRefOfOrderByOption will visit all parts of the AST -func VisitRefOfOrderByOption(in *OrderByOption, f Visit) error { +func VisitRefOfIsExpr(in *IsExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColumns(in.Cols, f); err != nil { + if err := VisitExpr(in.Expr, f); err != nil { return err } return nil } - -// rewriteRefOfOrderByOption is part of the Rewrite implementation -func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer replacerFunc) error { - if node == nil { +func VisitRefOfJoinCondition(in *JoinCondition, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { - parent.(*OrderByOption).Cols = newNode.(Columns) - }); errF != nil { - return errF + if err := VisitExpr(in.On, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitColumns(in.Using, f); err != nil { + return err } return nil } - -// EqualsRefOfOtherAdmin does deep equals between the two objects. -func EqualsRefOfOtherAdmin(a, b *OtherAdmin) bool { - if a == b { - return true +func VisitRefOfJoinTableExpr(in *JoinTableExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneRefOfOtherAdmin creates a deep clone of the input. -func CloneRefOfOtherAdmin(n *OtherAdmin) *OtherAdmin { - if n == nil { - return nil + if err := VisitTableExpr(in.LeftExpr, f); err != nil { + return err } - out := *n - return &out + if err := VisitTableExpr(in.RightExpr, f); err != nil { + return err + } + if err := VisitJoinCondition(in.Condition, f); err != nil { + return err + } + return nil } - -// VisitRefOfOtherAdmin will visit all parts of the AST -func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { +func VisitRefOfKeyState(in *KeyState, f Visit) error { if in == nil { return nil } @@ -7186,48 +7798,31 @@ func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { } return nil } - -// rewriteRefOfOtherAdmin is part of the Rewrite implementation -func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerFunc) error { - if node == nil { +func VisitRefOfLimit(in *Limit, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitExpr(in.Offset, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Rowcount, f); err != nil { + return err } return nil } - -// EqualsRefOfOtherRead does deep equals between the two objects. -func EqualsRefOfOtherRead(a, b *OtherRead) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true -} - -// CloneRefOfOtherRead creates a deep clone of the input. -func CloneRefOfOtherRead(n *OtherRead) *OtherRead { - if n == nil { +func VisitRefOfLiteral(in *Literal, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfOtherRead will visit all parts of the AST -func VisitRefOfOtherRead(in *OtherRead, f Visit) error { +func VisitRefOfLoad(in *Load, f Visit) error { if in == nil { return nil } @@ -7236,245 +7831,198 @@ func VisitRefOfOtherRead(in *OtherRead, f Visit) error { } return nil } - -// rewriteRefOfOtherRead is part of the Rewrite implementation -func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFunc) error { - if node == nil { +func VisitRefOfLockOption(in *LockOption, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + return nil +} +func VisitRefOfLockTables(in *LockTables, f Visit) error { + if in == nil { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort + if cont, err := f(in); err != nil || !cont { + return err } return nil } - -// EqualsRefOfParenSelect does deep equals between the two objects. -func EqualsRefOfParenSelect(a, b *ParenSelect) bool { - if a == b { - return true +func VisitRefOfMatchExpr(in *MatchExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSelectStatement(a.Select, b.Select) -} - -// CloneRefOfParenSelect creates a deep clone of the input. -func CloneRefOfParenSelect(n *ParenSelect) *ParenSelect { - if n == nil { - return nil + if err := VisitSelectExprs(in.Columns, f); err != nil { + return err } - out := *n - out.Select = CloneSelectStatement(n.Select) - return &out + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + return nil } - -// VisitRefOfParenSelect will visit all parts of the AST -func VisitRefOfParenSelect(in *ParenSelect, f Visit) error { +func VisitRefOfModifyColumn(in *ModifyColumn, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectStatement(in.Select, f); err != nil { + if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { + return err + } + if err := VisitRefOfColName(in.First, f); err != nil { + return err + } + if err := VisitRefOfColName(in.After, f); err != nil { return err } return nil } - -// rewriteRefOfParenSelect is part of the Rewrite implementation -func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replacerFunc) error { - if node == nil { +func VisitRefOfNextval(in *Nextval, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + return nil +} +func VisitRefOfNotExpr(in *NotExpr, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { - parent.(*ParenSelect).Select = newNode.(SelectStatement) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Expr, f); err != nil { + return err } return nil } - -// EqualsRefOfParenTableExpr does deep equals between the two objects. -func EqualsRefOfParenTableExpr(a, b *ParenTableExpr) bool { - if a == b { - return true +func VisitRefOfNullVal(in *NullVal, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsTableExprs(a.Exprs, b.Exprs) + return nil } - -// CloneRefOfParenTableExpr creates a deep clone of the input. -func CloneRefOfParenTableExpr(n *ParenTableExpr) *ParenTableExpr { - if n == nil { +func VisitRefOfOptLike(in *OptLike, f Visit) error { + if in == nil { return nil } - out := *n - out.Exprs = CloneTableExprs(n.Exprs) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitTableName(in.LikeTable, f); err != nil { + return err + } + return nil } - -// VisitRefOfParenTableExpr will visit all parts of the AST -func VisitRefOfParenTableExpr(in *ParenTableExpr, f Visit) error { +func VisitRefOfOrExpr(in *OrExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableExprs(in.Exprs, f); err != nil { + if err := VisitExpr(in.Left, f); err != nil { + return err + } + if err := VisitExpr(in.Right, f); err != nil { return err } return nil } - -// rewriteRefOfParenTableExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer replacerFunc) error { - if node == nil { +func VisitRefOfOrder(in *Order, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + return nil +} +func VisitRefOfOrderByOption(in *OrderByOption, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { - parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitColumns(in.Cols, f); err != nil { + return err } return nil } - -// EqualsRefOfPartitionDefinition does deep equals between the two objects. -func EqualsRefOfPartitionDefinition(a, b *PartitionDefinition) bool { - if a == b { - return true +func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Maxvalue == b.Maxvalue && - EqualsColIdent(a.Name, b.Name) && - EqualsExpr(a.Limit, b.Limit) + return nil } - -// CloneRefOfPartitionDefinition creates a deep clone of the input. -func CloneRefOfPartitionDefinition(n *PartitionDefinition) *PartitionDefinition { - if n == nil { +func VisitRefOfOtherRead(in *OtherRead, f Visit) error { + if in == nil { return nil } - out := *n - out.Name = CloneColIdent(n.Name) - out.Limit = CloneExpr(n.Limit) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfPartitionDefinition will visit all parts of the AST -func VisitRefOfPartitionDefinition(in *PartitionDefinition, f Visit) error { +func VisitRefOfParenSelect(in *ParenSelect, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitExpr(in.Limit, f); err != nil { + if err := VisitSelectStatement(in.Select, f); err != nil { return err } return nil } - -// rewriteRefOfPartitionDefinition is part of the Rewrite implementation -func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitRefOfParenTableExpr(in *ParenTableExpr, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Name = newNode.(ColIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Limit = newNode.(Expr) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableExprs(in.Exprs, f); err != nil { + return err } return nil } - -// EqualsRefOfPartitionSpec does deep equals between the two objects. -func EqualsRefOfPartitionSpec(a, b *PartitionSpec) bool { - if a == b { - return true +func VisitRefOfPartitionDefinition(in *PartitionDefinition, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.IsAll == b.IsAll && - a.WithoutValidation == b.WithoutValidation && - a.Action == b.Action && - EqualsPartitions(a.Names, b.Names) && - EqualsRefOfLiteral(a.Number, b.Number) && - EqualsTableName(a.TableName, b.TableName) && - EqualsSliceOfRefOfPartitionDefinition(a.Definitions, b.Definitions) -} - -// CloneRefOfPartitionSpec creates a deep clone of the input. -func CloneRefOfPartitionSpec(n *PartitionSpec) *PartitionSpec { - if n == nil { - return nil + if err := VisitColIdent(in.Name, f); err != nil { + return err } - out := *n - out.Names = ClonePartitions(n.Names) - out.Number = CloneRefOfLiteral(n.Number) - out.TableName = CloneTableName(n.TableName) - out.Definitions = CloneSliceOfRefOfPartitionDefinition(n.Definitions) - return &out + if err := VisitExpr(in.Limit, f); err != nil { + return err + } + return nil } - -// VisitRefOfPartitionSpec will visit all parts of the AST func VisitRefOfPartitionSpec(in *PartitionSpec, f Visit) error { if in == nil { return nil @@ -7498,434 +8046,336 @@ func VisitRefOfPartitionSpec(in *PartitionSpec, f Visit) error { } return nil } - -// rewriteRefOfPartitionSpec is part of the Rewrite implementation -func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { +func VisitRefOfRangeCond(in *RangeCond, f Visit) error { + if in == nil { return nil } - if errF := a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Names = newNode.(Partitions) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfLiteral(node, node.Number, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Number = newNode.(*Literal) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).TableName = newNode.(TableName) - }); errF != nil { - return errF + if err := VisitExpr(in.Left, f); err != nil { + return err } - for i, el := range node.Definitions { - if errF := a.rewriteRefOfPartitionDefinition(node, el, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Definitions[i] = newNode.(*PartitionDefinition) - }); errF != nil { - return errF - } + if err := VisitExpr(in.From, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.To, f); err != nil { + return err } return nil } - -// EqualsPartitions does deep equals between the two objects. -func EqualsPartitions(a, b Partitions) bool { - if len(a) != len(b) { - return false +func VisitRefOfRelease(in *Release, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsColIdent(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// ClonePartitions creates a deep clone of the input. -func ClonePartitions(n Partitions) Partitions { - res := make(Partitions, 0, len(n)) - for _, x := range n { - res = append(res, CloneColIdent(x)) + if err := VisitColIdent(in.Name, f); err != nil { + return err } - return res + return nil } - -// VisitPartitions will visit all parts of the AST -func VisitPartitions(in Partitions, f Visit) error { +func VisitRefOfRenameIndex(in *RenameIndex, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitColIdent(el, f); err != nil { - return err - } - } return nil } - -// rewritePartitions is part of the Rewrite implementation -func (a *application) rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc) error { - if node == nil { +func VisitRefOfRenameTable(in *RenameTable, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + return nil +} +func VisitRefOfRenameTableName(in *RenameTableName, f Visit) error { + if in == nil { return nil } - for i, el := range node { - if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { - parent.(Partitions)[i] = newNode.(ColIdent) - }); errF != nil { - return errF - } + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableName(in.Table, f); err != nil { + return err } return nil } - -// EqualsRefOfRangeCond does deep equals between the two objects. -func EqualsRefOfRangeCond(a, b *RangeCond) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { + if in == nil { + return nil } - return a.Operator == b.Operator && - EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.From, b.From) && - EqualsExpr(a.To, b.To) + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// CloneRefOfRangeCond creates a deep clone of the input. -func CloneRefOfRangeCond(n *RangeCond) *RangeCond { - if n == nil { +func VisitRefOfRollback(in *Rollback, f Visit) error { + if in == nil { return nil } - out := *n - out.Left = CloneExpr(n.Left) - out.From = CloneExpr(n.From) - out.To = CloneExpr(n.To) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfRangeCond will visit all parts of the AST -func VisitRefOfRangeCond(in *RangeCond, f Visit) error { +func VisitRefOfSRollback(in *SRollback, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Left, f); err != nil { + if err := VisitColIdent(in.Name, f); err != nil { return err } - if err := VisitExpr(in.From, f); err != nil { + return nil +} +func VisitRefOfSavepoint(in *Savepoint, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.To, f); err != nil { + if err := VisitColIdent(in.Name, f); err != nil { return err } return nil } - -// rewriteRefOfRangeCond is part of the Rewrite implementation -func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFunc) error { - if node == nil { +func VisitRefOfSelect(in *Select, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitComments(in.Comments, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*RangeCond).Left = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitSelectExprs(in.SelectExprs, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { - parent.(*RangeCond).From = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitTableExprs(in.From, f); err != nil { + return err } - if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { - parent.(*RangeCond).To = newNode.(Expr) - }); errF != nil { - return errF + if err := VisitRefOfWhere(in.Where, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitGroupBy(in.GroupBy, f); err != nil { + return err } - return nil -} - -// EqualsRefOfRelease does deep equals between the two objects. -func EqualsRefOfRelease(a, b *Release) bool { - if a == b { - return true + if err := VisitRefOfWhere(in.Having, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err } - return EqualsColIdent(a.Name, b.Name) + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err + } + if err := VisitRefOfSelectInto(in.Into, f); err != nil { + return err + } + return nil } - -// CloneRefOfRelease creates a deep clone of the input. -func CloneRefOfRelease(n *Release) *Release { - if n == nil { +func VisitRefOfSelectInto(in *SelectInto, f Visit) error { + if in == nil { return nil } - out := *n - out.Name = CloneColIdent(n.Name) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfRelease will visit all parts of the AST -func VisitRefOfRelease(in *Release, f Visit) error { +func VisitRefOfSet(in *Set, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { + return err + } + if err := VisitSetExprs(in.Exprs, f); err != nil { return err } return nil } - -// rewriteRefOfRelease is part of the Rewrite implementation -func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc) error { - if node == nil { +func VisitRefOfSetExpr(in *SetExpr, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*Release).Name = newNode.(ColIdent) - }); errF != nil { - return errF + if err := VisitColIdent(in.Name, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Expr, f); err != nil { + return err } return nil } - -// EqualsRefOfRenameIndex does deep equals between the two objects. -func EqualsRefOfRenameIndex(a, b *RenameIndex) bool { - if a == b { - return true +func VisitRefOfSetTransaction(in *SetTransaction, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.OldName == b.OldName && - a.NewName == b.NewName -} - -// CloneRefOfRenameIndex creates a deep clone of the input. -func CloneRefOfRenameIndex(n *RenameIndex) *RenameIndex { - if n == nil { - return nil + if err := VisitSQLNode(in.SQLNode, f); err != nil { + return err } - out := *n - return &out + if err := VisitComments(in.Comments, f); err != nil { + return err + } + for _, el := range in.Characteristics { + if err := VisitCharacteristic(el, f); err != nil { + return err + } + } + return nil } - -// VisitRefOfRenameIndex will visit all parts of the AST -func VisitRefOfRenameIndex(in *RenameIndex, f Visit) error { +func VisitRefOfShow(in *Show, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitShowInternal(in.Internal, f); err != nil { + return err + } return nil } - -// rewriteRefOfRenameIndex is part of the Rewrite implementation -func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replacerFunc) error { - if node == nil { +func VisitRefOfShowBasic(in *ShowBasic, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitTableName(in.Tbl, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfShowFilter(in.Filter, f); err != nil { + return err } return nil } - -// EqualsRefOfRenameTable does deep equals between the two objects. -func EqualsRefOfRenameTable(a, b *RenameTable) bool { - if a == b { - return true +func VisitRefOfShowCreate(in *ShowCreate, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSliceOfRefOfRenameTablePair(a.TablePairs, b.TablePairs) -} - -// CloneRefOfRenameTable creates a deep clone of the input. -func CloneRefOfRenameTable(n *RenameTable) *RenameTable { - if n == nil { - return nil + if err := VisitTableName(in.Op, f); err != nil { + return err } - out := *n - out.TablePairs = CloneSliceOfRefOfRenameTablePair(n.TablePairs) - return &out + return nil } - -// VisitRefOfRenameTable will visit all parts of the AST -func VisitRefOfRenameTable(in *RenameTable, f Visit) error { +func VisitRefOfShowFilter(in *ShowFilter, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitExpr(in.Filter, f); err != nil { + return err + } return nil } - -// rewriteRefOfRenameTable is part of the Rewrite implementation -func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replacerFunc) error { - if node == nil { +func VisitRefOfShowLegacy(in *ShowLegacy, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableName(in.OnTable, f); err != nil { + return err } - return nil -} - -// EqualsRefOfRenameTableName does deep equals between the two objects. -func EqualsRefOfRenameTableName(a, b *RenameTableName) bool { - if a == b { - return true + if err := VisitTableName(in.Table, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitExpr(in.ShowCollationFilterOpt, f); err != nil { + return err } - return EqualsTableName(a.Table, b.Table) + return nil } - -// CloneRefOfRenameTableName creates a deep clone of the input. -func CloneRefOfRenameTableName(n *RenameTableName) *RenameTableName { - if n == nil { +func VisitRefOfStarExpr(in *StarExpr, f Visit) error { + if in == nil { return nil } - out := *n - out.Table = CloneTableName(n.Table) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitTableName(in.TableName, f); err != nil { + return err + } + return nil } - -// VisitRefOfRenameTableName will visit all parts of the AST -func VisitRefOfRenameTableName(in *RenameTableName, f Visit) error { +func VisitRefOfStream(in *Stream, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitComments(in.Comments, f); err != nil { + return err + } + if err := VisitSelectExpr(in.SelectExpr, f); err != nil { + return err + } if err := VisitTableName(in.Table, f); err != nil { return err } return nil } - -// rewriteRefOfRenameTableName is part of the Rewrite implementation -func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer replacerFunc) error { - if node == nil { +func VisitRefOfSubquery(in *Subquery, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + if err := VisitSelectStatement(in.Select, f); err != nil { + return err + } + return nil +} +func VisitRefOfSubstrExpr(in *SubstrExpr, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*RenameTableName).Table = newNode.(TableName) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfColName(in.Name, f); err != nil { + return err } - return nil -} - -// EqualsRefOfRevertMigration does deep equals between the two objects. -func EqualsRefOfRevertMigration(a, b *RevertMigration) bool { - if a == b { - return true + if err := VisitRefOfLiteral(in.StrVal, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitExpr(in.From, f); err != nil { + return err } - return a.UUID == b.UUID -} - -// CloneRefOfRevertMigration creates a deep clone of the input. -func CloneRefOfRevertMigration(n *RevertMigration) *RevertMigration { - if n == nil { - return nil + if err := VisitExpr(in.To, f); err != nil { + return err } - out := *n - return &out + return nil } - -// VisitRefOfRevertMigration will visit all parts of the AST -func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { +func VisitRefOfTableIdent(in *TableIdent, f Visit) error { if in == nil { return nil } @@ -7934,48 +8384,49 @@ func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { } return nil } - -// rewriteRefOfRevertMigration is part of the Rewrite implementation -func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer replacerFunc) error { - if node == nil { +func VisitRefOfTableName(in *TableName, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitTableIdent(in.Name, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitTableIdent(in.Qualifier, f); err != nil { + return err } return nil } - -// EqualsRefOfRollback does deep equals between the two objects. -func EqualsRefOfRollback(a, b *Rollback) bool { - if a == b { - return true +func VisitRefOfTableSpec(in *TableSpec, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneRefOfRollback creates a deep clone of the input. -func CloneRefOfRollback(n *Rollback) *Rollback { - if n == nil { - return nil + for _, el := range in.Columns { + if err := VisitRefOfColumnDefinition(el, f); err != nil { + return err + } } - out := *n - return &out + for _, el := range in.Indexes { + if err := VisitRefOfIndexDefinition(el, f); err != nil { + return err + } + } + for _, el := range in.Constraints { + if err := VisitRefOfConstraintDefinition(el, f); err != nil { + return err + } + } + if err := VisitTableOptions(in.Options, f); err != nil { + return err + } + return nil } - -// VisitRefOfRollback will visit all parts of the AST -func VisitRefOfRollback(in *Rollback, f Visit) error { +func VisitRefOfTablespaceOperation(in *TablespaceOperation, f Visit) error { if in == nil { return nil } @@ -7984,571 +8435,614 @@ func VisitRefOfRollback(in *Rollback, f Visit) error { } return nil } - -// rewriteRefOfRollback is part of the Rewrite implementation -func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc) error { - if node == nil { +func VisitRefOfTimestampFuncExpr(in *TimestampFuncExpr, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitExpr(in.Expr1, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Expr2, f); err != nil { + return err } return nil } - -// EqualsRefOfSRollback does deep equals between the two objects. -func EqualsRefOfSRollback(a, b *SRollback) bool { - if a == b { - return true +func VisitRefOfTruncateTable(in *TruncateTable, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsColIdent(a.Name, b.Name) -} - -// CloneRefOfSRollback creates a deep clone of the input. -func CloneRefOfSRollback(n *SRollback) *SRollback { - if n == nil { - return nil + if err := VisitTableName(in.Table, f); err != nil { + return err } - out := *n - out.Name = CloneColIdent(n.Name) - return &out + return nil } - -// VisitRefOfSRollback will visit all parts of the AST -func VisitRefOfSRollback(in *SRollback, f Visit) error { +func VisitRefOfUnaryExpr(in *UnaryExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { + if err := VisitExpr(in.Expr, f); err != nil { return err } return nil } - -// rewriteRefOfSRollback is part of the Rewrite implementation -func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFunc) error { - if node == nil { +func VisitRefOfUnion(in *Union, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitSelectStatement(in.FirstStatement, f); err != nil { + return err } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*SRollback).Name = newNode.(ColIdent) - }); errF != nil { - return errF + for _, el := range in.UnionSelects { + if err := VisitRefOfUnionSelect(el, f); err != nil { + return err + } } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err + } + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err } return nil } - -// EqualsRefOfSavepoint does deep equals between the two objects. -func EqualsRefOfSavepoint(a, b *Savepoint) bool { - if a == b { - return true +func VisitRefOfUnionSelect(in *UnionSelect, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsColIdent(a.Name, b.Name) + if err := VisitSelectStatement(in.Statement, f); err != nil { + return err + } + return nil } - -// CloneRefOfSavepoint creates a deep clone of the input. -func CloneRefOfSavepoint(n *Savepoint) *Savepoint { - if n == nil { +func VisitRefOfUnlockTables(in *UnlockTables, f Visit) error { + if in == nil { return nil } - out := *n - out.Name = CloneColIdent(n.Name) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfSavepoint will visit all parts of the AST -func VisitRefOfSavepoint(in *Savepoint, f Visit) error { +func VisitRefOfUpdate(in *Update, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - return nil -} - -// rewriteRefOfSavepoint is part of the Rewrite implementation -func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFunc) error { - if node == nil { - return nil + if err := VisitTableExprs(in.TableExprs, f); err != nil { + return err } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if err := VisitUpdateExprs(in.Exprs, f); err != nil { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitRefOfWhere(in.Where, f); err != nil { + return err } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*Savepoint).Name = newNode.(ColIdent) - }); errF != nil { - return errF + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err } return nil } - -// EqualsRefOfSelect does deep equals between the two objects. -func EqualsRefOfSelect(a, b *Select) bool { - if a == b { - return true +func VisitRefOfUpdateExpr(in *UpdateExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Distinct == b.Distinct && - a.StraightJoinHint == b.StraightJoinHint && - a.SQLCalcFoundRows == b.SQLCalcFoundRows && - EqualsRefOfBool(a.Cache, b.Cache) && - EqualsComments(a.Comments, b.Comments) && - EqualsSelectExprs(a.SelectExprs, b.SelectExprs) && - EqualsTableExprs(a.From, b.From) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsGroupBy(a.GroupBy, b.GroupBy) && - EqualsRefOfWhere(a.Having, b.Having) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) && - a.Lock == b.Lock && - EqualsRefOfSelectInto(a.Into, b.Into) -} - -// CloneRefOfSelect creates a deep clone of the input. -func CloneRefOfSelect(n *Select) *Select { - if n == nil { - return nil + if err := VisitRefOfColName(in.Name, f); err != nil { + return err } - out := *n - out.Cache = CloneRefOfBool(n.Cache) - out.Comments = CloneComments(n.Comments) - out.SelectExprs = CloneSelectExprs(n.SelectExprs) - out.From = CloneTableExprs(n.From) - out.Where = CloneRefOfWhere(n.Where) - out.GroupBy = CloneGroupBy(n.GroupBy) - out.Having = CloneRefOfWhere(n.Having) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - out.Into = CloneRefOfSelectInto(n.Into) - return &out + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + return nil } - -// VisitRefOfSelect will visit all parts of the AST -func VisitRefOfSelect(in *Select, f Visit) error { +func VisitRefOfUse(in *Use, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { + if err := VisitTableIdent(in.DBName, f); err != nil { return err } - if err := VisitSelectExprs(in.SelectExprs, f); err != nil { - return err + return nil +} +func VisitRefOfVStream(in *VStream, f Visit) error { + if in == nil { + return nil } - if err := VisitTableExprs(in.From, f); err != nil { + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfWhere(in.Where, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - if err := VisitGroupBy(in.GroupBy, f); err != nil { + if err := VisitSelectExpr(in.SelectExpr, f); err != nil { return err } - if err := VisitRefOfWhere(in.Having, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { + if err := VisitRefOfWhere(in.Where, f); err != nil { return err } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err } - if err := VisitRefOfSelectInto(in.Into, f); err != nil { - return err - } return nil } - -// rewriteRefOfSelect is part of the Rewrite implementation -func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc) error { - if node == nil { +func VisitRefOfValidation(in *Validation, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { + return nil +} +func VisitRefOfValuesFuncExpr(in *ValuesFuncExpr, f Visit) error { + if in == nil { return nil } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*Select).Comments = newNode.(Comments) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteSelectExprs(node, node.SelectExprs, func(newNode, parent SQLNode) { - parent.(*Select).SelectExprs = newNode.(SelectExprs) - }); errF != nil { - return errF + if err := VisitRefOfColName(in.Name, f); err != nil { + return err } - if errF := a.rewriteTableExprs(node, node.From, func(newNode, parent SQLNode) { - parent.(*Select).From = newNode.(TableExprs) - }); errF != nil { - return errF + return nil +} +func VisitRefOfVindexParam(in *VindexParam, f Visit) error { + if in == nil { + return nil } - if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { - parent.(*Select).Where = newNode.(*Where) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteGroupBy(node, node.GroupBy, func(newNode, parent SQLNode) { - parent.(*Select).GroupBy = newNode.(GroupBy) - }); errF != nil { - return errF + if err := VisitColIdent(in.Key, f); err != nil { + return err } - if errF := a.rewriteRefOfWhere(node, node.Having, func(newNode, parent SQLNode) { - parent.(*Select).Having = newNode.(*Where) - }); errF != nil { - return errF + return nil +} +func VisitRefOfVindexSpec(in *VindexSpec, f Visit) error { + if in == nil { + return nil } - if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Select).OrderBy = newNode.(OrderBy) - }); errF != nil { - return errF + if cont, err := f(in); err != nil || !cont { + return err } - if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*Select).Limit = newNode.(*Limit) - }); errF != nil { - return errF + if err := VisitColIdent(in.Name, f); err != nil { + return err } - if errF := a.rewriteRefOfSelectInto(node, node.Into, func(newNode, parent SQLNode) { - parent.(*Select).Into = newNode.(*SelectInto) - }); errF != nil { - return errF + if err := VisitColIdent(in.Type, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in.Params { + if err := VisitVindexParam(el, f); err != nil { + return err + } } return nil } - -// EqualsSelectExprs does deep equals between the two objects. -func EqualsSelectExprs(a, b SelectExprs) bool { - if len(a) != len(b) { - return false +func VisitRefOfWhen(in *When, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsSelectExpr(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneSelectExprs creates a deep clone of the input. -func CloneSelectExprs(n SelectExprs) SelectExprs { - res := make(SelectExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneSelectExpr(x)) + if err := VisitExpr(in.Cond, f); err != nil { + return err } - return res + if err := VisitExpr(in.Val, f); err != nil { + return err + } + return nil } - -// VisitSelectExprs will visit all parts of the AST -func VisitSelectExprs(in SelectExprs, f Visit) error { +func VisitRefOfWhere(in *Where, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitSelectExpr(el, f); err != nil { - return err - } + if err := VisitExpr(in.Expr, f); err != nil { + return err } return nil } - -// rewriteSelectExprs is part of the Rewrite implementation -func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc) error { - if node == nil { +func VisitRefOfXorExpr(in *XorExpr, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil + if cont, err := f(in); err != nil || !cont { + return err } - for i, el := range node { - if errF := a.rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { - parent.(SelectExprs)[i] = newNode.(SelectExpr) - }); errF != nil { - return errF - } + if err := VisitExpr(in.Left, f); err != nil { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + if err := VisitExpr(in.Right, f); err != nil { + return err } return nil } - -// EqualsRefOfSelectInto does deep equals between the two objects. -func EqualsRefOfSelectInto(a, b *SelectInto) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.FileName == b.FileName && - a.Charset == b.Charset && - a.FormatOption == b.FormatOption && - a.ExportOption == b.ExportOption && - a.Manifest == b.Manifest && - a.Overwrite == b.Overwrite && - a.Type == b.Type +func VisitReferenceAction(in ReferenceAction, f Visit) error { + _, err := f(in) + return err } - -// CloneRefOfSelectInto creates a deep clone of the input. -func CloneRefOfSelectInto(n *SelectInto) *SelectInto { - if n == nil { +func VisitSQLNode(in SQLNode, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case AccessMode: + return VisitAccessMode(in, f) + case *AddColumns: + return VisitRefOfAddColumns(in, f) + case *AddConstraintDefinition: + return VisitRefOfAddConstraintDefinition(in, f) + case *AddIndexDefinition: + return VisitRefOfAddIndexDefinition(in, f) + case AlgorithmValue: + return VisitAlgorithmValue(in, f) + case *AliasedExpr: + return VisitRefOfAliasedExpr(in, f) + case *AliasedTableExpr: + return VisitRefOfAliasedTableExpr(in, f) + case *AlterCharset: + return VisitRefOfAlterCharset(in, f) + case *AlterColumn: + return VisitRefOfAlterColumn(in, f) + case *AlterDatabase: + return VisitRefOfAlterDatabase(in, f) + case *AlterMigration: + return VisitRefOfAlterMigration(in, f) + case *AlterTable: + return VisitRefOfAlterTable(in, f) + case *AlterView: + return VisitRefOfAlterView(in, f) + case *AlterVschema: + return VisitRefOfAlterVschema(in, f) + case *AndExpr: + return VisitRefOfAndExpr(in, f) + case Argument: + return VisitArgument(in, f) + case *AutoIncSpec: + return VisitRefOfAutoIncSpec(in, f) + case *Begin: + return VisitRefOfBegin(in, f) + case *BinaryExpr: + return VisitRefOfBinaryExpr(in, f) + case BoolVal: + return VisitBoolVal(in, f) + case *CallProc: + return VisitRefOfCallProc(in, f) + case *CaseExpr: + return VisitRefOfCaseExpr(in, f) + case *ChangeColumn: + return VisitRefOfChangeColumn(in, f) + case *CheckConstraintDefinition: + return VisitRefOfCheckConstraintDefinition(in, f) + case ColIdent: + return VisitColIdent(in, f) + case *ColName: + return VisitRefOfColName(in, f) + case *CollateExpr: + return VisitRefOfCollateExpr(in, f) + case *ColumnDefinition: + return VisitRefOfColumnDefinition(in, f) + case *ColumnType: + return VisitRefOfColumnType(in, f) + case Columns: + return VisitColumns(in, f) + case Comments: + return VisitComments(in, f) + case *Commit: + return VisitRefOfCommit(in, f) + case *ComparisonExpr: + return VisitRefOfComparisonExpr(in, f) + case *ConstraintDefinition: + return VisitRefOfConstraintDefinition(in, f) + case *ConvertExpr: + return VisitRefOfConvertExpr(in, f) + case *ConvertType: + return VisitRefOfConvertType(in, f) + case *ConvertUsingExpr: + return VisitRefOfConvertUsingExpr(in, f) + case *CreateDatabase: + return VisitRefOfCreateDatabase(in, f) + case *CreateTable: + return VisitRefOfCreateTable(in, f) + case *CreateView: + return VisitRefOfCreateView(in, f) + case *CurTimeFuncExpr: + return VisitRefOfCurTimeFuncExpr(in, f) + case *Default: + return VisitRefOfDefault(in, f) + case *Delete: + return VisitRefOfDelete(in, f) + case *DerivedTable: + return VisitRefOfDerivedTable(in, f) + case *DropColumn: + return VisitRefOfDropColumn(in, f) + case *DropDatabase: + return VisitRefOfDropDatabase(in, f) + case *DropKey: + return VisitRefOfDropKey(in, f) + case *DropTable: + return VisitRefOfDropTable(in, f) + case *DropView: + return VisitRefOfDropView(in, f) + case *ExistsExpr: + return VisitRefOfExistsExpr(in, f) + case *ExplainStmt: + return VisitRefOfExplainStmt(in, f) + case *ExplainTab: + return VisitRefOfExplainTab(in, f) + case Exprs: + return VisitExprs(in, f) + case *Flush: + return VisitRefOfFlush(in, f) + case *Force: + return VisitRefOfForce(in, f) + case *ForeignKeyDefinition: + return VisitRefOfForeignKeyDefinition(in, f) + case *FuncExpr: + return VisitRefOfFuncExpr(in, f) + case GroupBy: + return VisitGroupBy(in, f) + case *GroupConcatExpr: + return VisitRefOfGroupConcatExpr(in, f) + case *IndexDefinition: + return VisitRefOfIndexDefinition(in, f) + case *IndexHints: + return VisitRefOfIndexHints(in, f) + case *IndexInfo: + return VisitRefOfIndexInfo(in, f) + case *Insert: + return VisitRefOfInsert(in, f) + case *IntervalExpr: + return VisitRefOfIntervalExpr(in, f) + case *IsExpr: + return VisitRefOfIsExpr(in, f) + case IsolationLevel: + return VisitIsolationLevel(in, f) + case JoinCondition: + return VisitJoinCondition(in, f) + case *JoinTableExpr: + return VisitRefOfJoinTableExpr(in, f) + case *KeyState: + return VisitRefOfKeyState(in, f) + case *Limit: + return VisitRefOfLimit(in, f) + case ListArg: + return VisitListArg(in, f) + case *Literal: + return VisitRefOfLiteral(in, f) + case *Load: + return VisitRefOfLoad(in, f) + case *LockOption: + return VisitRefOfLockOption(in, f) + case *LockTables: + return VisitRefOfLockTables(in, f) + case *MatchExpr: + return VisitRefOfMatchExpr(in, f) + case *ModifyColumn: + return VisitRefOfModifyColumn(in, f) + case *Nextval: + return VisitRefOfNextval(in, f) + case *NotExpr: + return VisitRefOfNotExpr(in, f) + case *NullVal: + return VisitRefOfNullVal(in, f) + case OnDup: + return VisitOnDup(in, f) + case *OptLike: + return VisitRefOfOptLike(in, f) + case *OrExpr: + return VisitRefOfOrExpr(in, f) + case *Order: + return VisitRefOfOrder(in, f) + case OrderBy: + return VisitOrderBy(in, f) + case *OrderByOption: + return VisitRefOfOrderByOption(in, f) + case *OtherAdmin: + return VisitRefOfOtherAdmin(in, f) + case *OtherRead: + return VisitRefOfOtherRead(in, f) + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *ParenTableExpr: + return VisitRefOfParenTableExpr(in, f) + case *PartitionDefinition: + return VisitRefOfPartitionDefinition(in, f) + case *PartitionSpec: + return VisitRefOfPartitionSpec(in, f) + case Partitions: + return VisitPartitions(in, f) + case *RangeCond: + return VisitRefOfRangeCond(in, f) + case ReferenceAction: + return VisitReferenceAction(in, f) + case *Release: + return VisitRefOfRelease(in, f) + case *RenameIndex: + return VisitRefOfRenameIndex(in, f) + case *RenameTable: + return VisitRefOfRenameTable(in, f) + case *RenameTableName: + return VisitRefOfRenameTableName(in, f) + case *RevertMigration: + return VisitRefOfRevertMigration(in, f) + case *Rollback: + return VisitRefOfRollback(in, f) + case *SRollback: + return VisitRefOfSRollback(in, f) + case *Savepoint: + return VisitRefOfSavepoint(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case SelectExprs: + return VisitSelectExprs(in, f) + case *SelectInto: + return VisitRefOfSelectInto(in, f) + case *Set: + return VisitRefOfSet(in, f) + case *SetExpr: + return VisitRefOfSetExpr(in, f) + case SetExprs: + return VisitSetExprs(in, f) + case *SetTransaction: + return VisitRefOfSetTransaction(in, f) + case *Show: + return VisitRefOfShow(in, f) + case *ShowBasic: + return VisitRefOfShowBasic(in, f) + case *ShowCreate: + return VisitRefOfShowCreate(in, f) + case *ShowFilter: + return VisitRefOfShowFilter(in, f) + case *ShowLegacy: + return VisitRefOfShowLegacy(in, f) + case *StarExpr: + return VisitRefOfStarExpr(in, f) + case *Stream: + return VisitRefOfStream(in, f) + case *Subquery: + return VisitRefOfSubquery(in, f) + case *SubstrExpr: + return VisitRefOfSubstrExpr(in, f) + case TableExprs: + return VisitTableExprs(in, f) + case TableIdent: + return VisitTableIdent(in, f) + case TableName: + return VisitTableName(in, f) + case TableNames: + return VisitTableNames(in, f) + case TableOptions: + return VisitTableOptions(in, f) + case *TableSpec: + return VisitRefOfTableSpec(in, f) + case *TablespaceOperation: + return VisitRefOfTablespaceOperation(in, f) + case *TimestampFuncExpr: + return VisitRefOfTimestampFuncExpr(in, f) + case *TruncateTable: + return VisitRefOfTruncateTable(in, f) + case *UnaryExpr: + return VisitRefOfUnaryExpr(in, f) + case *Union: + return VisitRefOfUnion(in, f) + case *UnionSelect: + return VisitRefOfUnionSelect(in, f) + case *UnlockTables: + return VisitRefOfUnlockTables(in, f) + case *Update: + return VisitRefOfUpdate(in, f) + case *UpdateExpr: + return VisitRefOfUpdateExpr(in, f) + case UpdateExprs: + return VisitUpdateExprs(in, f) + case *Use: + return VisitRefOfUse(in, f) + case *VStream: + return VisitRefOfVStream(in, f) + case ValTuple: + return VisitValTuple(in, f) + case *Validation: + return VisitRefOfValidation(in, f) + case Values: + return VisitValues(in, f) + case *ValuesFuncExpr: + return VisitRefOfValuesFuncExpr(in, f) + case VindexParam: + return VisitVindexParam(in, f) + case *VindexSpec: + return VisitRefOfVindexSpec(in, f) + case *When: + return VisitRefOfWhen(in, f) + case *Where: + return VisitRefOfWhere(in, f) + case *XorExpr: + return VisitRefOfXorExpr(in, f) + default: + // this should never happen return nil } - out := *n - return &out } - -// VisitRefOfSelectInto will visit all parts of the AST -func VisitRefOfSelectInto(in *SelectInto, f Visit) error { +func VisitSelectExpr(in SelectExpr, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// rewriteRefOfSelectInto is part of the Rewrite implementation -func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfSet does deep equals between the two objects. -func EqualsRefOfSet(a, b *Set) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsComments(a.Comments, b.Comments) && - EqualsSetExprs(a.Exprs, b.Exprs) -} - -// CloneRefOfSet creates a deep clone of the input. -func CloneRefOfSet(n *Set) *Set { - if n == nil { + switch in := in.(type) { + case *AliasedExpr: + return VisitRefOfAliasedExpr(in, f) + case *Nextval: + return VisitRefOfNextval(in, f) + case *StarExpr: + return VisitRefOfStarExpr(in, f) + default: + // this should never happen return nil } - out := *n - out.Comments = CloneComments(n.Comments) - out.Exprs = CloneSetExprs(n.Exprs) - return &out } - -// VisitRefOfSet will visit all parts of the AST -func VisitRefOfSet(in *Set, f Visit) error { +func VisitSelectExprs(in SelectExprs, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitSetExprs(in.Exprs, f); err != nil { - return err - } - return nil -} - -// rewriteRefOfSet is part of the Rewrite implementation -func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*Set).Comments = newNode.(Comments) - }); errF != nil { - return errF - } - if errF := a.rewriteSetExprs(node, node.Exprs, func(newNode, parent SQLNode) { - parent.(*Set).Exprs = newNode.(SetExprs) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in { + if err := VisitSelectExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfSetExpr does deep equals between the two objects. -func EqualsRefOfSetExpr(a, b *SetExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Scope == b.Scope && - EqualsColIdent(a.Name, b.Name) && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfSetExpr creates a deep clone of the input. -func CloneRefOfSetExpr(n *SetExpr) *SetExpr { - if n == nil { - return nil - } - out := *n - out.Name = CloneColIdent(n.Name) - out.Expr = CloneExpr(n.Expr) - return &out -} - -// VisitRefOfSetExpr will visit all parts of the AST -func VisitRefOfSetExpr(in *SetExpr, f Visit) error { +func VisitSelectStatement(in SelectStatement, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - return nil -} - -// rewriteRefOfSetExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc) error { - if node == nil { - return nil - } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { - return nil - } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*SetExpr).Name = newNode.(ColIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*SetExpr).Expr = newNode.(Expr) - }); errF != nil { - return errF - } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsSetExprs does deep equals between the two objects. -func EqualsSetExprs(a, b SetExprs) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfSetExpr(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSetExprs creates a deep clone of the input. -func CloneSetExprs(n SetExprs) SetExprs { - res := make(SetExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfSetExpr(x)) + switch in := in.(type) { + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case *Union: + return VisitRefOfUnion(in, f) + default: + // this should never happen + return nil } - return res } - -// VisitSetExprs will visit all parts of the AST func VisitSetExprs(in SetExprs, f Visit) error { if in == nil { return nil @@ -8563,157 +9057,246 @@ func VisitSetExprs(in SetExprs, f Visit) error { } return nil } - -// rewriteSetExprs is part of the Rewrite implementation -func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc) error { - if node == nil { +func VisitShowInternal(in ShowInternal, f Visit) error { + if in == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + switch in := in.(type) { + case *ShowBasic: + return VisitRefOfShowBasic(in, f) + case *ShowCreate: + return VisitRefOfShowCreate(in, f) + case *ShowLegacy: + return VisitRefOfShowLegacy(in, f) + default: + // this should never happen return nil } - for i, el := range node { - if errF := a.rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { - parent.(SetExprs)[i] = newNode.(*SetExpr) - }); errF != nil { - return errF - } +} +func VisitSimpleTableExpr(in SimpleTableExpr, f Visit) error { + if in == nil { + return nil } - if a.post != nil && !a.post(&cur) { - return errAbort + switch in := in.(type) { + case *DerivedTable: + return VisitRefOfDerivedTable(in, f) + case TableName: + return VisitTableName(in, f) + default: + // this should never happen + return nil } - return nil } - -// EqualsRefOfSetTransaction does deep equals between the two objects. -func EqualsRefOfSetTransaction(a, b *SetTransaction) bool { - if a == b { - return true +func VisitStatement(in Statement, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + switch in := in.(type) { + case *AlterDatabase: + return VisitRefOfAlterDatabase(in, f) + case *AlterMigration: + return VisitRefOfAlterMigration(in, f) + case *AlterTable: + return VisitRefOfAlterTable(in, f) + case *AlterView: + return VisitRefOfAlterView(in, f) + case *AlterVschema: + return VisitRefOfAlterVschema(in, f) + case *Begin: + return VisitRefOfBegin(in, f) + case *CallProc: + return VisitRefOfCallProc(in, f) + case *Commit: + return VisitRefOfCommit(in, f) + case *CreateDatabase: + return VisitRefOfCreateDatabase(in, f) + case *CreateTable: + return VisitRefOfCreateTable(in, f) + case *CreateView: + return VisitRefOfCreateView(in, f) + case *Delete: + return VisitRefOfDelete(in, f) + case *DropDatabase: + return VisitRefOfDropDatabase(in, f) + case *DropTable: + return VisitRefOfDropTable(in, f) + case *DropView: + return VisitRefOfDropView(in, f) + case *ExplainStmt: + return VisitRefOfExplainStmt(in, f) + case *ExplainTab: + return VisitRefOfExplainTab(in, f) + case *Flush: + return VisitRefOfFlush(in, f) + case *Insert: + return VisitRefOfInsert(in, f) + case *Load: + return VisitRefOfLoad(in, f) + case *LockTables: + return VisitRefOfLockTables(in, f) + case *OtherAdmin: + return VisitRefOfOtherAdmin(in, f) + case *OtherRead: + return VisitRefOfOtherRead(in, f) + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *Release: + return VisitRefOfRelease(in, f) + case *RenameTable: + return VisitRefOfRenameTable(in, f) + case *RevertMigration: + return VisitRefOfRevertMigration(in, f) + case *Rollback: + return VisitRefOfRollback(in, f) + case *SRollback: + return VisitRefOfSRollback(in, f) + case *Savepoint: + return VisitRefOfSavepoint(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case *Set: + return VisitRefOfSet(in, f) + case *SetTransaction: + return VisitRefOfSetTransaction(in, f) + case *Show: + return VisitRefOfShow(in, f) + case *Stream: + return VisitRefOfStream(in, f) + case *TruncateTable: + return VisitRefOfTruncateTable(in, f) + case *Union: + return VisitRefOfUnion(in, f) + case *UnlockTables: + return VisitRefOfUnlockTables(in, f) + case *Update: + return VisitRefOfUpdate(in, f) + case *Use: + return VisitRefOfUse(in, f) + case *VStream: + return VisitRefOfVStream(in, f) + default: + // this should never happen + return nil } - return EqualsSQLNode(a.SQLNode, b.SQLNode) && - EqualsComments(a.Comments, b.Comments) && - a.Scope == b.Scope && - EqualsSliceOfCharacteristic(a.Characteristics, b.Characteristics) } - -// CloneRefOfSetTransaction creates a deep clone of the input. -func CloneRefOfSetTransaction(n *SetTransaction) *SetTransaction { - if n == nil { +func VisitTableExpr(in TableExpr, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedTableExpr: + return VisitRefOfAliasedTableExpr(in, f) + case *JoinTableExpr: + return VisitRefOfJoinTableExpr(in, f) + case *ParenTableExpr: + return VisitRefOfParenTableExpr(in, f) + default: + // this should never happen return nil } - out := *n - out.SQLNode = CloneSQLNode(n.SQLNode) - out.Comments = CloneComments(n.Comments) - out.Characteristics = CloneSliceOfCharacteristic(n.Characteristics) - return &out } - -// VisitRefOfSetTransaction will visit all parts of the AST -func VisitRefOfSetTransaction(in *SetTransaction, f Visit) error { +func VisitTableExprs(in TableExprs, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSQLNode(in.SQLNode, f); err != nil { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - for _, el := range in.Characteristics { - if err := VisitCharacteristic(el, f); err != nil { + for _, el := range in { + if err := VisitTableExpr(el, f); err != nil { return err } } return nil } - -// rewriteRefOfSetTransaction is part of the Rewrite implementation -func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer replacerFunc) error { - if node == nil { - return nil +func VisitTableIdent(in TableIdent, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, + return nil +} +func VisitTableName(in TableName, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err } - if a.pre != nil && !a.pre(&cur) { - return nil + if err := VisitTableIdent(in.Name, f); err != nil { + return err } - if errF := a.rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { - parent.(*SetTransaction).SQLNode = newNode.(SQLNode) - }); errF != nil { - return errF + if err := VisitTableIdent(in.Qualifier, f); err != nil { + return err } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*SetTransaction).Comments = newNode.(Comments) - }); errF != nil { - return errF + return nil +} +func VisitTableNames(in TableNames, f Visit) error { + if in == nil { + return nil } - for i, el := range node.Characteristics { - if errF := a.rewriteCharacteristic(node, el, func(newNode, parent SQLNode) { - parent.(*SetTransaction).Characteristics[i] = newNode.(Characteristic) - }); errF != nil { - return errF - } + if cont, err := f(in); err != nil || !cont { + return err } - if a.post != nil && !a.post(&cur) { - return errAbort + for _, el := range in { + if err := VisitTableName(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfShow does deep equals between the two objects. -func EqualsRefOfShow(a, b *Show) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsShowInternal(a.Internal, b.Internal) +func VisitTableOptions(in TableOptions, f Visit) error { + _, err := f(in) + return err } - -// CloneRefOfShow creates a deep clone of the input. -func CloneRefOfShow(n *Show) *Show { - if n == nil { +func VisitUpdateExprs(in UpdateExprs, f Visit) error { + if in == nil { return nil } - out := *n - out.Internal = CloneShowInternal(n.Internal) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + for _, el := range in { + if err := VisitRefOfUpdateExpr(el, f); err != nil { + return err + } + } + return nil } - -// VisitRefOfShow will visit all parts of the AST -func VisitRefOfShow(in *Show, f Visit) error { +func VisitValTuple(in ValTuple, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitShowInternal(in.Internal, f); err != nil { - return err + for _, el := range in { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } - -// rewriteRefOfShow is part of the Rewrite implementation -func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc) error { - if node == nil { +func VisitValues(in Values, f Visit) error { + if in == nil { return nil } + if cont, err := f(in); err != nil || !cont { + return err + } + for _, el := range in { + if err := VisitValTuple(el, f); err != nil { + return err + } + } + return nil +} +func VisitVindexParam(in VindexParam, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitColIdent(in.Key, f); err != nil { + return err + } + return nil +} +func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, @@ -8722,65 +9305,74 @@ func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer repl if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { - parent.(*Show).Internal = newNode.(ShowInternal) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfShowBasic does deep equals between the two objects. -func EqualsRefOfShowBasic(a, b *ShowBasic) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Full == b.Full && - a.DbName == b.DbName && - a.Command == b.Command && - EqualsTableName(a.Tbl, b.Tbl) && - EqualsRefOfShowFilter(a.Filter, b.Filter) -} - -// CloneRefOfShowBasic creates a deep clone of the input. -func CloneRefOfShowBasic(n *ShowBasic) *ShowBasic { - if n == nil { - return nil +func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - out := *n - out.Tbl = CloneTableName(n.Tbl) - out.Filter = CloneRefOfShowFilter(n.Filter) - return &out -} - -// VisitRefOfShowBasic will visit all parts of the AST -func VisitRefOfShowBasic(in *ShowBasic, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Tbl, f); err != nil { - return err - } - if err := VisitRefOfShowFilter(in.Filter, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfShowBasic is part of the Rewrite implementation -func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFunc) error { +func (a *application) rewriteAlterOption(parent SQLNode, node AlterOption, replacer replacerFunc) error { if node == nil { return nil } + switch node := node.(type) { + case *AddColumns: + return a.rewriteRefOfAddColumns(parent, node, replacer) + case *AddConstraintDefinition: + return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) + case *AddIndexDefinition: + return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) + case AlgorithmValue: + return a.rewriteAlgorithmValue(parent, node, replacer) + case *AlterCharset: + return a.rewriteRefOfAlterCharset(parent, node, replacer) + case *AlterColumn: + return a.rewriteRefOfAlterColumn(parent, node, replacer) + case *ChangeColumn: + return a.rewriteRefOfChangeColumn(parent, node, replacer) + case *DropColumn: + return a.rewriteRefOfDropColumn(parent, node, replacer) + case *DropKey: + return a.rewriteRefOfDropKey(parent, node, replacer) + case *Force: + return a.rewriteRefOfForce(parent, node, replacer) + case *KeyState: + return a.rewriteRefOfKeyState(parent, node, replacer) + case *LockOption: + return a.rewriteRefOfLockOption(parent, node, replacer) + case *ModifyColumn: + return a.rewriteRefOfModifyColumn(parent, node, replacer) + case *OrderByOption: + return a.rewriteRefOfOrderByOption(parent, node, replacer) + case *RenameIndex: + return a.rewriteRefOfRenameIndex(parent, node, replacer) + case *RenameTableName: + return a.rewriteRefOfRenameTableName(parent, node, replacer) + case TableOptions: + return a.rewriteTableOptions(parent, node, replacer) + case *TablespaceOperation: + return a.rewriteRefOfTablespaceOperation(parent, node, replacer) + case *Validation: + return a.rewriteRefOfValidation(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, @@ -8789,63 +9381,41 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { - parent.(*ShowBasic).Tbl = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfShowFilter(node, node.Filter, func(newNode, parent SQLNode) { - parent.(*ShowBasic).Filter = newNode.(*ShowFilter) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfShowCreate does deep equals between the two objects. -func EqualsRefOfShowCreate(a, b *ShowCreate) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Command == b.Command && - EqualsTableName(a.Op, b.Op) -} - -// CloneRefOfShowCreate creates a deep clone of the input. -func CloneRefOfShowCreate(n *ShowCreate) *ShowCreate { - if n == nil { - return nil +func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - out := *n - out.Op = CloneTableName(n.Op) - return &out -} - -// VisitRefOfShowCreate will visit all parts of the AST -func VisitRefOfShowCreate(in *ShowCreate, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Op, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfShowCreate is part of the Rewrite implementation -func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerFunc) error { +func (a *application) rewriteCharacteristic(parent SQLNode, node Characteristic, replacer replacerFunc) error { if node == nil { return nil } + switch node := node.(type) { + case AccessMode: + return a.rewriteAccessMode(parent, node, replacer) + case IsolationLevel: + return a.rewriteIsolationLevel(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc) error { + var err error cur := Cursor{ node: node, parent: parent, @@ -8854,55 +9424,55 @@ func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, r if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { - parent.(*ShowCreate).Op = newNode.(TableName) - }); errF != nil { - return errF + if err != nil { + return err } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfShowFilter does deep equals between the two objects. -func EqualsRefOfShowFilter(a, b *ShowFilter) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func (a *application) rewriteColTuple(parent SQLNode, node ColTuple, replacer replacerFunc) error { + if node == nil { + return nil } - return a.Like == b.Like && - EqualsExpr(a.Filter, b.Filter) -} - -// CloneRefOfShowFilter creates a deep clone of the input. -func CloneRefOfShowFilter(n *ShowFilter) *ShowFilter { - if n == nil { + switch node := node.(type) { + case ListArg: + return a.rewriteListArg(parent, node, replacer) + case *Subquery: + return a.rewriteRefOfSubquery(parent, node, replacer) + case ValTuple: + return a.rewriteValTuple(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.Filter = CloneExpr(n.Filter) - return &out } - -// VisitRefOfShowFilter will visit all parts of the AST -func VisitRefOfShowFilter(in *ShowFilter, f Visit) error { - if in == nil { +func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitExpr(in.Filter, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil + } + for i, el := range node { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(Columns)[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfShowFilter is part of the Rewrite implementation -func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerFunc) error { +func (a *application) rewriteComments(parent SQLNode, node Comments, replacer replacerFunc) error { if node == nil { return nil } @@ -8914,69 +9484,154 @@ func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, r if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { - parent.(*ShowFilter).Filter = newNode.(Expr) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfShowLegacy does deep equals between the two objects. -func EqualsRefOfShowLegacy(a, b *ShowLegacy) bool { - if a == b { - return true +func (a *application) rewriteConstraintInfo(parent SQLNode, node ConstraintInfo, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + switch node := node.(type) { + case *CheckConstraintDefinition: + return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) + case *ForeignKeyDefinition: + return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) + default: + // this should never happen + return nil } - return a.Extended == b.Extended && - a.Type == b.Type && - EqualsTableName(a.OnTable, b.OnTable) && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfShowTablesOpt(a.ShowTablesOpt, b.ShowTablesOpt) && - a.Scope == b.Scope && - EqualsExpr(a.ShowCollationFilterOpt, b.ShowCollationFilterOpt) } - -// CloneRefOfShowLegacy creates a deep clone of the input. -func CloneRefOfShowLegacy(n *ShowLegacy) *ShowLegacy { - if n == nil { +func (a *application) rewriteDBDDLStatement(parent SQLNode, node DBDDLStatement, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterDatabase: + return a.rewriteRefOfAlterDatabase(parent, node, replacer) + case *CreateDatabase: + return a.rewriteRefOfCreateDatabase(parent, node, replacer) + case *DropDatabase: + return a.rewriteRefOfDropDatabase(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.OnTable = CloneTableName(n.OnTable) - out.Table = CloneTableName(n.Table) - out.ShowTablesOpt = CloneRefOfShowTablesOpt(n.ShowTablesOpt) - out.ShowCollationFilterOpt = CloneExpr(n.ShowCollationFilterOpt) - return &out } - -// VisitRefOfShowLegacy will visit all parts of the AST -func VisitRefOfShowLegacy(in *ShowLegacy, f Visit) error { - if in == nil { +func (a *application) rewriteDDLStatement(parent SQLNode, node DDLStatement, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + switch node := node.(type) { + case *AlterTable: + return a.rewriteRefOfAlterTable(parent, node, replacer) + case *AlterView: + return a.rewriteRefOfAlterView(parent, node, replacer) + case *CreateTable: + return a.rewriteRefOfCreateTable(parent, node, replacer) + case *CreateView: + return a.rewriteRefOfCreateView(parent, node, replacer) + case *DropTable: + return a.rewriteRefOfDropTable(parent, node, replacer) + case *DropView: + return a.rewriteRefOfDropView(parent, node, replacer) + case *RenameTable: + return a.rewriteRefOfRenameTable(parent, node, replacer) + case *TruncateTable: + return a.rewriteRefOfTruncateTable(parent, node, replacer) + default: + // this should never happen + return nil } - if err := VisitTableName(in.OnTable, f); err != nil { - return err +} +func (a *application) rewriteExplain(parent SQLNode, node Explain, replacer replacerFunc) error { + if node == nil { + return nil } - if err := VisitTableName(in.Table, f); err != nil { - return err + switch node := node.(type) { + case *ExplainStmt: + return a.rewriteRefOfExplainStmt(parent, node, replacer) + case *ExplainTab: + return a.rewriteRefOfExplainTab(parent, node, replacer) + default: + // this should never happen + return nil } - if err := VisitExpr(in.ShowCollationFilterOpt, f); err != nil { - return err +} +func (a *application) rewriteExpr(parent SQLNode, node Expr, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AndExpr: + return a.rewriteRefOfAndExpr(parent, node, replacer) + case Argument: + return a.rewriteArgument(parent, node, replacer) + case *BinaryExpr: + return a.rewriteRefOfBinaryExpr(parent, node, replacer) + case BoolVal: + return a.rewriteBoolVal(parent, node, replacer) + case *CaseExpr: + return a.rewriteRefOfCaseExpr(parent, node, replacer) + case *ColName: + return a.rewriteRefOfColName(parent, node, replacer) + case *CollateExpr: + return a.rewriteRefOfCollateExpr(parent, node, replacer) + case *ComparisonExpr: + return a.rewriteRefOfComparisonExpr(parent, node, replacer) + case *ConvertExpr: + return a.rewriteRefOfConvertExpr(parent, node, replacer) + case *ConvertUsingExpr: + return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) + case *CurTimeFuncExpr: + return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) + case *Default: + return a.rewriteRefOfDefault(parent, node, replacer) + case *ExistsExpr: + return a.rewriteRefOfExistsExpr(parent, node, replacer) + case *FuncExpr: + return a.rewriteRefOfFuncExpr(parent, node, replacer) + case *GroupConcatExpr: + return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) + case *IntervalExpr: + return a.rewriteRefOfIntervalExpr(parent, node, replacer) + case *IsExpr: + return a.rewriteRefOfIsExpr(parent, node, replacer) + case ListArg: + return a.rewriteListArg(parent, node, replacer) + case *Literal: + return a.rewriteRefOfLiteral(parent, node, replacer) + case *MatchExpr: + return a.rewriteRefOfMatchExpr(parent, node, replacer) + case *NotExpr: + return a.rewriteRefOfNotExpr(parent, node, replacer) + case *NullVal: + return a.rewriteRefOfNullVal(parent, node, replacer) + case *OrExpr: + return a.rewriteRefOfOrExpr(parent, node, replacer) + case *RangeCond: + return a.rewriteRefOfRangeCond(parent, node, replacer) + case *Subquery: + return a.rewriteRefOfSubquery(parent, node, replacer) + case *SubstrExpr: + return a.rewriteRefOfSubstrExpr(parent, node, replacer) + case *TimestampFuncExpr: + return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) + case *UnaryExpr: + return a.rewriteRefOfUnaryExpr(parent, node, replacer) + case ValTuple: + return a.rewriteValTuple(parent, node, replacer) + case *ValuesFuncExpr: + return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) + case *XorExpr: + return a.rewriteRefOfXorExpr(parent, node, replacer) + default: + // this should never happen + return nil } - return nil } - -// rewriteRefOfShowLegacy is part of the Rewrite implementation -func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerFunc) error { +func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc) error { if node == nil { return nil } @@ -8988,67 +9643,61 @@ func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, r if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).OnTable = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).Table = newNode.(TableName) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.ShowCollationFilterOpt, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) - }); errF != nil { - return errF + for i, el := range node { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(Exprs)[i] = newNode.(Expr) + }); errF != nil { + return errF + } } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfStarExpr does deep equals between the two objects. -func EqualsRefOfStarExpr(a, b *StarExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsTableName(a.TableName, b.TableName) -} - -// CloneRefOfStarExpr creates a deep clone of the input. -func CloneRefOfStarExpr(n *StarExpr) *StarExpr { - if n == nil { +func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.TableName = CloneTableName(n.TableName) - return &out -} - -// VisitRefOfStarExpr will visit all parts of the AST -func VisitRefOfStarExpr(in *StarExpr, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + for i, el := range node { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(GroupBy)[i] = newNode.(Expr) + }); errF != nil { + return errF + } } - if err := VisitTableName(in.TableName, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfStarExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc) error { +func (a *application) rewriteInsertRows(parent SQLNode, node InsertRows, replacer replacerFunc) error { if node == nil { return nil } + switch node := node.(type) { + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + case Values: + return a.rewriteValues(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc) error { cur := Cursor{ node: node, parent: parent, @@ -9057,64 +9706,40 @@ func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, repla if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { - parent.(*StarExpr).TableName = newNode.(TableName) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfStream does deep equals between the two objects. -func EqualsRefOfStream(a, b *Stream) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsComments(a.Comments, b.Comments) && - EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && - EqualsTableName(a.Table, b.Table) -} - -// CloneRefOfStream creates a deep clone of the input. -func CloneRefOfStream(n *Stream) *Stream { - if n == nil { - return nil +func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - out := *n - out.Comments = CloneComments(n.Comments) - out.SelectExpr = CloneSelectExpr(n.SelectExpr) - out.Table = CloneTableName(n.Table) - return &out -} - -// VisitRefOfStream will visit all parts of the AST -func VisitRefOfStream(in *Stream, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'On' on 'JoinCondition'") + }); errF != nil { + return errF } - if err := VisitComments(in.Comments, f); err != nil { - return err + if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Using' on 'JoinCondition'") + }); errF != nil { + return errF } - if err := VisitSelectExpr(in.SelectExpr, f); err != nil { + if err != nil { return err } - if err := VisitTableName(in.Table, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfStream is part of the Rewrite implementation -func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc) error { +func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc) error { if node == nil { return nil } @@ -9126,64 +9751,36 @@ func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*Stream).Comments = newNode.(Comments) - }); errF != nil { - return errF - } - if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { - parent.(*Stream).SelectExpr = newNode.(SelectExpr) - }); errF != nil { - return errF - } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*Stream).Table = newNode.(TableName) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfSubquery does deep equals between the two objects. -func EqualsRefOfSubquery(a, b *Subquery) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsSelectStatement(a.Select, b.Select) -} - -// CloneRefOfSubquery creates a deep clone of the input. -func CloneRefOfSubquery(n *Subquery) *Subquery { - if n == nil { +func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Select = CloneSelectStatement(n.Select) - return &out -} - -// VisitRefOfSubquery will visit all parts of the AST -func VisitRefOfSubquery(in *Subquery, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + for i, el := range node { + if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + parent.(OnDup)[i] = newNode.(*UpdateExpr) + }); errF != nil { + return errF + } } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfSubquery is part of the Rewrite implementation -func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc) error { +func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc) error { if node == nil { return nil } @@ -9195,69 +9792,43 @@ func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, repla if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { - parent.(*Subquery).Select = newNode.(SelectStatement) - }); errF != nil { - return errF + for i, el := range node { + if errF := a.rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { + parent.(OrderBy)[i] = newNode.(*Order) + }); errF != nil { + return errF + } } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfSubstrExpr does deep equals between the two objects. -func EqualsRefOfSubstrExpr(a, b *SubstrExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfColName(a.Name, b.Name) && - EqualsRefOfLiteral(a.StrVal, b.StrVal) && - EqualsExpr(a.From, b.From) && - EqualsExpr(a.To, b.To) -} - -// CloneRefOfSubstrExpr creates a deep clone of the input. -func CloneRefOfSubstrExpr(n *SubstrExpr) *SubstrExpr { - if n == nil { - return nil - } - out := *n - out.Name = CloneRefOfColName(n.Name) - out.StrVal = CloneRefOfLiteral(n.StrVal) - out.From = CloneExpr(n.From) - out.To = CloneExpr(n.To) - return &out -} - -// VisitRefOfSubstrExpr will visit all parts of the AST -func VisitRefOfSubstrExpr(in *SubstrExpr, f Visit) error { - if in == nil { +func (a *application) rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitRefOfLiteral(in.StrVal, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitExpr(in.From, f); err != nil { - return err + for i, el := range node { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(Partitions)[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } } - if err := VisitExpr(in.To, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfSubstrExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerFunc) error { if node == nil { return nil } @@ -9269,72 +9840,29 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).Name = newNode.(*ColName) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfLiteral(node, node.StrVal, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).StrVal = newNode.(*Literal) - }); errF != nil { - return errF + for i, el := range node.Columns { + if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*AddColumns).Columns[i] = newNode.(*ColumnDefinition) + }); errF != nil { + return errF + } } - if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).From = newNode.(Expr) + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*AddColumns).First = newNode.(*ColName) }); errF != nil { return errF } - if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).To = newNode.(Expr) + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*AddColumns).After = newNode.(*ColName) }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsTableExprs does deep equals between the two objects. -func EqualsTableExprs(a, b TableExprs) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsTableExpr(a[i], b[i]) { - return false - } - } - return true -} - -// CloneTableExprs creates a deep clone of the input. -func CloneTableExprs(n TableExprs) TableExprs { - res := make(TableExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneTableExpr(x)) - } - return res -} - -// VisitTableExprs will visit all parts of the AST -func VisitTableExprs(in TableExprs, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in { - if err := VisitTableExpr(el, f); err != nil { - return err - } - } + if a.post != nil && !a.post(&cur) { + return errAbort + } return nil } - -// rewriteTableExprs is part of the Rewrite implementation -func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc) error { +func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -9346,40 +9874,20 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node { - if errF := a.rewriteTableExpr(node, el, func(newNode, parent SQLNode) { - parent.(TableExprs)[i] = newNode.(TableExpr) - }); errF != nil { - return errF - } + if errF := a.rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { + parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) + }); errF != nil { + return errF } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsTableIdent does deep equals between the two objects. -func EqualsTableIdent(a, b TableIdent) bool { - return a.v == b.v -} - -// CloneTableIdent creates a deep clone of the input. -func CloneTableIdent(n TableIdent) TableIdent { - return *CloneRefOfTableIdent(&n) -} - -// VisitTableIdent will visit all parts of the AST -func VisitTableIdent(in TableIdent, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err +func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, replacer replacerFunc) error { + if node == nil { + return nil } - return nil -} - -// rewriteTableIdent is part of the Rewrite implementation -func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc) error { - var err error cur := Cursor{ node: node, parent: parent, @@ -9388,43 +9896,20 @@ func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replace if a.pre != nil && !a.pre(&cur) { return nil } - if err != nil { - return err + if errF := a.rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { + parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) + }); errF != nil { + return errF } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsTableName does deep equals between the two objects. -func EqualsTableName(a, b TableName) bool { - return EqualsTableIdent(a.Name, b.Name) && - EqualsTableIdent(a.Qualifier, b.Qualifier) -} - -// CloneTableName creates a deep clone of the input. -func CloneTableName(n TableName) TableName { - return *CloneRefOfTableName(&n) -} - -// VisitTableName will visit all parts of the AST -func VisitTableName(in TableName, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableIdent(in.Name, f); err != nil { - return err - } - if err := VisitTableIdent(in.Qualifier, f); err != nil { - return err +func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replacerFunc) error { + if node == nil { + return nil } - return nil -} - -// rewriteTableName is part of the Rewrite implementation -func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc) error { - var err error cur := Cursor{ node: node, parent: parent, @@ -9433,65 +9918,59 @@ func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Name' on 'TableName'") + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Qualifier' on 'TableName'") + if errF := a.rewriteColIdent(node, node.As, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).As = newNode.(ColIdent) }); errF != nil { return errF } - if err != nil { - return err - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsTableNames does deep equals between the two objects. -func EqualsTableNames(a, b TableNames) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsTableName(a[i], b[i]) { - return false - } +func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replacer replacerFunc) error { + if node == nil { + return nil } - return true -} - -// CloneTableNames creates a deep clone of the input. -func CloneTableNames(n TableNames) TableNames { - res := make(TableNames, 0, len(n)) - for _, x := range n { - res = append(res, CloneTableName(x)) + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return res -} - -// VisitTableNames will visit all parts of the AST -func VisitTableNames(in TableNames, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) + }); errF != nil { + return errF } - for _, el := range in { - if err := VisitTableName(el, f); err != nil { - return err - } + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteTableIdent(node, node.As, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).As = newNode.(TableIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfIndexHints(node, node.Hints, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteTableNames is part of the Rewrite implementation -func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc) error { +func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer replacerFunc) error { if node == nil { return nil } @@ -9503,49 +9982,12 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node { - if errF := a.rewriteTableName(node, el, func(newNode, parent SQLNode) { - parent.(TableNames)[i] = newNode.(TableName) - }); errF != nil { - return errF - } - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsTableOptions does deep equals between the two objects. -func EqualsTableOptions(a, b TableOptions) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfTableOption(a[i], b[i]) { - return false - } - } - return true -} - -// CloneTableOptions creates a deep clone of the input. -func CloneTableOptions(n TableOptions) TableOptions { - res := make(TableOptions, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfTableOption(x)) - } - return res -} - -// VisitTableOptions will visit all parts of the AST -func VisitTableOptions(in TableOptions, f Visit) error { - _, err := f(in) - return err -} - -// rewriteTableOptions is part of the Rewrite implementation -func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFunc) error { +func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replacerFunc) error { if node == nil { return nil } @@ -9557,70 +9999,56 @@ func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, rep if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { + parent.(*AlterColumn).Column = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.DefaultVal, func(newNode, parent SQLNode) { + parent.(*AlterColumn).DefaultVal = newNode.(Expr) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfTableSpec does deep equals between the two objects. -func EqualsRefOfTableSpec(a, b *TableSpec) bool { - if a == b { - return true +func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && - EqualsSliceOfRefOfIndexDefinition(a.Indexes, b.Indexes) && - EqualsSliceOfRefOfConstraintDefinition(a.Constraints, b.Constraints) && - EqualsTableOptions(a.Options, b.Options) -} - -// CloneRefOfTableSpec creates a deep clone of the input. -func CloneRefOfTableSpec(n *TableSpec) *TableSpec { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) - out.Indexes = CloneSliceOfRefOfIndexDefinition(n.Indexes) - out.Constraints = CloneSliceOfRefOfConstraintDefinition(n.Constraints) - out.Options = CloneTableOptions(n.Options) - return &out + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitRefOfTableSpec will visit all parts of the AST -func VisitRefOfTableSpec(in *TableSpec, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in.Columns { - if err := VisitRefOfColumnDefinition(el, f); err != nil { - return err - } - } - for _, el := range in.Indexes { - if err := VisitRefOfIndexDefinition(el, f); err != nil { - return err - } - } - for _, el := range in.Constraints { - if err := VisitRefOfConstraintDefinition(el, f); err != nil { - return err - } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitTableOptions(in.Options, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfTableSpec is part of the Rewrite implementation -func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFunc) error { +func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerFunc) error { if node == nil { return nil } @@ -9632,29 +10060,20 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node.Columns { - if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { - parent.(*TableSpec).Columns[i] = newNode.(*ColumnDefinition) - }); errF != nil { - return errF - } - } - for i, el := range node.Indexes { - if errF := a.rewriteRefOfIndexDefinition(node, el, func(newNode, parent SQLNode) { - parent.(*TableSpec).Indexes[i] = newNode.(*IndexDefinition) - }); errF != nil { - return errF - } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*AlterTable).Table = newNode.(TableName) + }); errF != nil { + return errF } - for i, el := range node.Constraints { - if errF := a.rewriteRefOfConstraintDefinition(node, el, func(newNode, parent SQLNode) { - parent.(*TableSpec).Constraints[i] = newNode.(*ConstraintDefinition) + for i, el := range node.AlterOptions { + if errF := a.rewriteAlterOption(node, el, func(newNode, parent SQLNode) { + parent.(*AlterTable).AlterOptions[i] = newNode.(AlterOption) }); errF != nil { return errF } } - if errF := a.rewriteTableOptions(node, node.Options, func(newNode, parent SQLNode) { - parent.(*TableSpec).Options = newNode.(TableOptions) + if errF := a.rewriteRefOfPartitionSpec(node, node.PartitionSpec, func(newNode, parent SQLNode) { + parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) }); errF != nil { return errF } @@ -9663,40 +10082,39 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep } return nil } - -// EqualsRefOfTablespaceOperation does deep equals between the two objects. -func EqualsRefOfTablespaceOperation(a, b *TablespaceOperation) bool { - if a == b { - return true +func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return a.Import == b.Import -} - -// CloneRefOfTablespaceOperation creates a deep clone of the input. -func CloneRefOfTablespaceOperation(n *TablespaceOperation) *TablespaceOperation { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - return &out -} - -// VisitRefOfTablespaceOperation will visit all parts of the AST -func VisitRefOfTablespaceOperation(in *TablespaceOperation, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + parent.(*AlterView).ViewName = newNode.(TableName) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*AlterView).Columns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*AlterView).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfTablespaceOperation is part of the Rewrite implementation -func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, replacer replacerFunc) error { +func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer replacerFunc) error { if node == nil { return nil } @@ -9708,56 +10126,61 @@ func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *Tabl if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*AlterVschema).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfVindexSpec(node, node.VindexSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) + }); errF != nil { + return errF + } + for i, el := range node.VindexCols { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexCols[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } + } + if errF := a.rewriteRefOfAutoIncSpec(node, node.AutoIncSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfTimestampFuncExpr does deep equals between the two objects. -func EqualsRefOfTimestampFuncExpr(a, b *TimestampFuncExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Name == b.Name && - a.Unit == b.Unit && - EqualsExpr(a.Expr1, b.Expr1) && - EqualsExpr(a.Expr2, b.Expr2) -} - -// CloneRefOfTimestampFuncExpr creates a deep clone of the input. -func CloneRefOfTimestampFuncExpr(n *TimestampFuncExpr) *TimestampFuncExpr { - if n == nil { +func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Expr1 = CloneExpr(n.Expr1) - out.Expr2 = CloneExpr(n.Expr2) - return &out -} - -// VisitRefOfTimestampFuncExpr will visit all parts of the AST -func VisitRefOfTimestampFuncExpr(in *TimestampFuncExpr, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*AndExpr).Left = newNode.(Expr) + }); errF != nil { + return errF } - if err := VisitExpr(in.Expr1, f); err != nil { - return err + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*AndExpr).Right = newNode.(Expr) + }); errF != nil { + return errF } - if err := VisitExpr(in.Expr2, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfTimestampFuncExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replacerFunc) error { if node == nil { return nil } @@ -9769,13 +10192,13 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) + if errF := a.rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Column = newNode.(ColIdent) }); errF != nil { return errF } - if errF := a.rewriteExpr(node, node.Expr2, func(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) + if errF := a.rewriteTableName(node, node.Sequence, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Sequence = newNode.(TableName) }); errF != nil { return errF } @@ -9784,44 +10207,24 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest } return nil } - -// EqualsRefOfTruncateTable does deep equals between the two objects. -func EqualsRefOfTruncateTable(a, b *TruncateTable) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsTableName(a.Table, b.Table) -} - -// CloneRefOfTruncateTable creates a deep clone of the input. -func CloneRefOfTruncateTable(n *TruncateTable) *TruncateTable { - if n == nil { +func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Table = CloneTableName(n.Table) - return &out -} - -// VisitRefOfTruncateTable will visit all parts of the AST -func VisitRefOfTruncateTable(in *TruncateTable, f Visit) error { - if in == nil { - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitTableName(in.Table, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfTruncateTable is part of the Rewrite implementation -func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer replacerFunc) error { +func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -9833,8 +10236,13 @@ func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTa if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*TruncateTable).Table = newNode.(TableName) + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Right = newNode.(Expr) }); errF != nil { return errF } @@ -9843,45 +10251,34 @@ func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTa } return nil } - -// EqualsRefOfUnaryExpr does deep equals between the two objects. -func EqualsRefOfUnaryExpr(a, b *UnaryExpr) bool { - if a == b { - return true +func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return a.Operator == b.Operator && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfUnaryExpr creates a deep clone of the input. -func CloneRefOfUnaryExpr(n *UnaryExpr) *UnaryExpr { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out -} - -// VisitRefOfUnaryExpr will visit all parts of the AST -func VisitRefOfUnaryExpr(in *UnaryExpr, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*CallProc).Name = newNode.(TableName) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteExprs(node, node.Params, func(newNode, parent SQLNode) { + parent.(*CallProc).Params = newNode.(Exprs) + }); errF != nil { + return errF } - if err := VisitExpr(in.Expr, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUnaryExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -9894,71 +10291,28 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*UnaryExpr).Expr = newNode.(Expr) + parent.(*CaseExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { - return errAbort - } - return nil -} - -// EqualsRefOfUnion does deep equals between the two objects. -func EqualsRefOfUnion(a, b *Union) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsSelectStatement(a.FirstStatement, b.FirstStatement) && - EqualsSliceOfRefOfUnionSelect(a.UnionSelects, b.UnionSelects) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) && - a.Lock == b.Lock -} - -// CloneRefOfUnion creates a deep clone of the input. -func CloneRefOfUnion(n *Union) *Union { - if n == nil { - return nil - } - out := *n - out.FirstStatement = CloneSelectStatement(n.FirstStatement) - out.UnionSelects = CloneSliceOfRefOfUnionSelect(n.UnionSelects) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out -} - -// VisitRefOfUnion will visit all parts of the AST -func VisitRefOfUnion(in *Union, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitSelectStatement(in.FirstStatement, f); err != nil { - return err - } - for _, el := range in.UnionSelects { - if err := VisitRefOfUnionSelect(el, f); err != nil { - return err + for i, el := range node.Whens { + if errF := a.rewriteRefOfWhen(node, el, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Whens[i] = newNode.(*When) + }); errF != nil { + return errF } } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + if errF := a.rewriteExpr(node, node.Else, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Else = newNode.(Expr) + }); errF != nil { + return errF } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUnion is part of the Rewrite implementation -func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc) error { +func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer replacerFunc) error { if node == nil { return nil } @@ -9970,25 +10324,23 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { - parent.(*Union).FirstStatement = newNode.(SelectStatement) + if errF := a.rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).OldColumn = newNode.(*ColName) }); errF != nil { return errF } - for i, el := range node.UnionSelects { - if errF := a.rewriteRefOfUnionSelect(node, el, func(newNode, parent SQLNode) { - parent.(*Union).UnionSelects[i] = newNode.(*UnionSelect) - }); errF != nil { - return errF - } + if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) + }); errF != nil { + return errF } - if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Union).OrderBy = newNode.(OrderBy) + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).First = newNode.(*ColName) }); errF != nil { return errF } - if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*Union).Limit = newNode.(*Limit) + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).After = newNode.(*ColName) }); errF != nil { return errF } @@ -9997,45 +10349,29 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re } return nil } - -// EqualsRefOfUnionSelect does deep equals between the two objects. -func EqualsRefOfUnionSelect(a, b *UnionSelect) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Distinct == b.Distinct && - EqualsSelectStatement(a.Statement, b.Statement) -} - -// CloneRefOfUnionSelect creates a deep clone of the input. -func CloneRefOfUnionSelect(n *UnionSelect) *UnionSelect { - if n == nil { +func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraintDefinition, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Statement = CloneSelectStatement(n.Statement) - return &out -} - -// VisitRefOfUnionSelect will visit all parts of the AST -func VisitRefOfUnionSelect(in *UnionSelect, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) + }); errF != nil { + return errF } - if err := VisitSelectStatement(in.Statement, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUnionSelect is part of the Rewrite implementation -func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replacerFunc) error { +func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc) error { if node == nil { return nil } @@ -10047,50 +10383,39 @@ func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { - parent.(*UnionSelect).Statement = newNode.(SelectStatement) - }); errF != nil { - return errF - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfUnlockTables does deep equals between the two objects. -func EqualsRefOfUnlockTables(a, b *UnlockTables) bool { - if a == b { - return true +func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return true -} - -// CloneRefOfUnlockTables creates a deep clone of the input. -func CloneRefOfUnlockTables(n *UnlockTables) *UnlockTables { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - return &out -} - -// VisitRefOfUnlockTables will visit all parts of the AST -func VisitRefOfUnlockTables(in *UnlockTables, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ColName).Name = newNode.(ColIdent) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteTableName(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*ColName).Qualifier = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUnlockTables is part of the Rewrite implementation -func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer replacerFunc) error { +func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10102,75 +10427,39 @@ func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTable if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CollateExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfUpdate does deep equals between the two objects. -func EqualsRefOfUpdate(a, b *Update) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsComments(a.Comments, b.Comments) && - a.Ignore == b.Ignore && - EqualsTableExprs(a.TableExprs, b.TableExprs) && - EqualsUpdateExprs(a.Exprs, b.Exprs) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) -} - -// CloneRefOfUpdate creates a deep clone of the input. -func CloneRefOfUpdate(n *Update) *Update { - if n == nil { - return nil - } - out := *n - out.Comments = CloneComments(n.Comments) - out.TableExprs = CloneTableExprs(n.TableExprs) - out.Exprs = CloneUpdateExprs(n.Exprs) - out.Where = CloneRefOfWhere(n.Where) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out -} - -// VisitRefOfUpdate will visit all parts of the AST -func VisitRefOfUpdate(in *Update, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitTableExprs(in.TableExprs, f); err != nil { - return err - } - if err := VisitUpdateExprs(in.Exprs, f); err != nil { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitRefOfWhere(in.Where, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ColumnDefinition).Name = newNode.(ColIdent) + }); errF != nil { + return errF } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUpdate is part of the Rewrite implementation -func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc) error { +func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerFunc) error { if node == nil { return nil } @@ -10182,33 +10471,13 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*Update).Comments = newNode.(Comments) - }); errF != nil { - return errF - } - if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { - parent.(*Update).TableExprs = newNode.(TableExprs) - }); errF != nil { - return errF - } - if errF := a.rewriteUpdateExprs(node, node.Exprs, func(newNode, parent SQLNode) { - parent.(*Update).Exprs = newNode.(UpdateExprs) - }); errF != nil { - return errF - } - if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { - parent.(*Update).Where = newNode.(*Where) - }); errF != nil { - return errF - } - if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Update).OrderBy = newNode.(OrderBy) + if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + parent.(*ColumnType).Length = newNode.(*Literal) }); errF != nil { return errF } - if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*Update).Limit = newNode.(*Limit) + if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + parent.(*ColumnType).Scale = newNode.(*Literal) }); errF != nil { return errF } @@ -10217,49 +10486,24 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer } return nil } - -// EqualsRefOfUpdateExpr does deep equals between the two objects. -func EqualsRefOfUpdateExpr(a, b *UpdateExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfColName(a.Name, b.Name) && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfUpdateExpr creates a deep clone of the input. -func CloneRefOfUpdateExpr(n *UpdateExpr) *UpdateExpr { - if n == nil { - return nil - } - out := *n - out.Name = CloneRefOfColName(n.Name) - out.Expr = CloneExpr(n.Expr) - return &out -} - -// VisitRefOfUpdateExpr will visit all parts of the AST -func VisitRefOfUpdateExpr(in *UpdateExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitExpr(in.Expr, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUpdateExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10271,13 +10515,18 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { - parent.(*UpdateExpr).Name = newNode.(*ColName) + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Left = newNode.(Expr) }); errF != nil { return errF } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*UpdateExpr).Expr = newNode.(Expr) + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Escape, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Escape = newNode.(Expr) }); errF != nil { return errF } @@ -10286,47 +10535,29 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r } return nil } - -// EqualsUpdateExprs does deep equals between the two objects. -func EqualsUpdateExprs(a, b UpdateExprs) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfUpdateExpr(a[i], b[i]) { - return false - } +func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition, replacer replacerFunc) error { + if node == nil { + return nil } - return true -} - -// CloneUpdateExprs creates a deep clone of the input. -func CloneUpdateExprs(n UpdateExprs) UpdateExprs { - res := make(UpdateExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfUpdateExpr(x)) + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return res -} - -// VisitUpdateExprs will visit all parts of the AST -func VisitUpdateExprs(in UpdateExprs, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { + parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) + }); errF != nil { + return errF } - for _, el := range in { - if err := VisitRefOfUpdateExpr(el, f); err != nil { - return err - } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteUpdateExprs is part of the Rewrite implementation -func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc) error { +func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10338,56 +10569,49 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node { - if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { - parent.(UpdateExprs)[i] = newNode.(*UpdateExpr) - }); errF != nil { - return errF - } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfConvertType(node, node.Type, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Type = newNode.(*ConvertType) + }); errF != nil { + return errF } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfUse does deep equals between the two objects. -func EqualsRefOfUse(a, b *Use) bool { - if a == b { - return true +func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return EqualsTableIdent(a.DBName, b.DBName) -} - -// CloneRefOfUse creates a deep clone of the input. -func CloneRefOfUse(n *Use) *Use { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.DBName = CloneTableIdent(n.DBName) - return &out -} - -// VisitRefOfUse will visit all parts of the AST -func VisitRefOfUse(in *Use, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + parent.(*ConvertType).Length = newNode.(*Literal) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + parent.(*ConvertType).Scale = newNode.(*Literal) + }); errF != nil { + return errF } - if err := VisitTableIdent(in.DBName, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfUse is part of the Rewrite implementation -func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc) error { +func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10399,8 +10623,8 @@ func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replac if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { - parent.(*Use).DBName = newNode.(TableIdent) + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertUsingExpr).Expr = newNode.(Expr) }); errF != nil { return errF } @@ -10409,64 +10633,29 @@ func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replac } return nil } - -// EqualsRefOfVStream does deep equals between the two objects. -func EqualsRefOfVStream(a, b *VStream) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsComments(a.Comments, b.Comments) && - EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsRefOfLimit(a.Limit, b.Limit) -} - -// CloneRefOfVStream creates a deep clone of the input. -func CloneRefOfVStream(n *VStream) *VStream { - if n == nil { - return nil - } - out := *n - out.Comments = CloneComments(n.Comments) - out.SelectExpr = CloneSelectExpr(n.SelectExpr) - out.Table = CloneTableName(n.Table) - out.Where = CloneRefOfWhere(n.Where) - out.Limit = CloneRefOfLimit(n.Limit) - return &out -} - -// VisitRefOfVStream will visit all parts of the AST -func VisitRefOfVStream(in *VStream, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitSelectExpr(in.SelectExpr, f); err != nil { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitTableName(in.Table, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitRefOfWhere(in.Where, f); err != nil { - return err + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*CreateDatabase).Comments = newNode.(Comments) + }); errF != nil { + return errF } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfVStream is part of the Rewrite implementation -func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc) error { +func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replacerFunc) error { if node == nil { return nil } @@ -10478,28 +10667,18 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { - parent.(*VStream).Comments = newNode.(Comments) - }); errF != nil { - return errF - } - if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { - parent.(*VStream).SelectExpr = newNode.(SelectExpr) - }); errF != nil { - return errF - } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { - parent.(*VStream).Table = newNode.(TableName) + parent.(*CreateTable).Table = newNode.(TableName) }); errF != nil { return errF } - if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { - parent.(*VStream).Where = newNode.(*Where) + if errF := a.rewriteRefOfTableSpec(node, node.TableSpec, func(newNode, parent SQLNode) { + parent.(*CreateTable).TableSpec = newNode.(*TableSpec) }); errF != nil { return errF } - if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { - parent.(*VStream).Limit = newNode.(*Limit) + if errF := a.rewriteRefOfOptLike(node, node.OptLike, func(newNode, parent SQLNode) { + parent.(*CreateTable).OptLike = newNode.(*OptLike) }); errF != nil { return errF } @@ -10508,47 +10687,39 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace } return nil } - -// EqualsValTuple does deep equals between the two objects. -func EqualsValTuple(a, b ValTuple) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsExpr(a[i], b[i]) { - return false - } +func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerFunc) error { + if node == nil { + return nil } - return true -} - -// CloneValTuple creates a deep clone of the input. -func CloneValTuple(n ValTuple) ValTuple { - res := make(ValTuple, 0, len(n)) - for _, x := range n { - res = append(res, CloneExpr(x)) + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return res -} - -// VisitValTuple will visit all parts of the AST -func VisitValTuple(in ValTuple, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + parent.(*CreateView).ViewName = newNode.(TableName) + }); errF != nil { + return errF } - for _, el := range in { - if err := VisitExpr(el, f); err != nil { - return err - } + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*CreateView).Columns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*CreateView).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteValTuple is part of the Rewrite implementation -func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc) error { +func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer replacerFunc) error { if node == nil { return nil } @@ -10560,52 +10731,39 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node { - if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { - parent.(ValTuple)[i] = newNode.(Expr) - }); errF != nil { - return errF - } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Fsp, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) + }); errF != nil { + return errF } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfValidation does deep equals between the two objects. -func EqualsRefOfValidation(a, b *Validation) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.With == b.With -} - -// CloneRefOfValidation creates a deep clone of the input. -func CloneRefOfValidation(n *Validation) *Validation { - if n == nil { +func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - return &out -} - -// VisitRefOfValidation will visit all parts of the AST -func VisitRefOfValidation(in *Validation, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfValidation is part of the Rewrite implementation -func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerFunc) error { +func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc) error { if node == nil { return nil } @@ -10617,52 +10775,47 @@ func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, r if a.pre != nil && !a.pre(&cur) { return nil } - if a.post != nil && !a.post(&cur) { - return errAbort + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Delete).Comments = newNode.(Comments) + }); errF != nil { + return errF } - return nil -} - -// EqualsValues does deep equals between the two objects. -func EqualsValues(a, b Values) bool { - if len(a) != len(b) { - return false + if errF := a.rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { + parent.(*Delete).Targets = newNode.(TableNames) + }); errF != nil { + return errF } - for i := 0; i < len(a); i++ { - if !EqualsValTuple(a[i], b[i]) { - return false - } + if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Delete).TableExprs = newNode.(TableExprs) + }); errF != nil { + return errF } - return true -} - -// CloneValues creates a deep clone of the input. -func CloneValues(n Values) Values { - res := make(Values, 0, len(n)) - for _, x := range n { - res = append(res, CloneValTuple(x)) + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*Delete).Partitions = newNode.(Partitions) + }); errF != nil { + return errF } - return res -} - -// VisitValues will visit all parts of the AST -func VisitValues(in Values, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Delete).Where = newNode.(*Where) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Delete).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF } - for _, el := range in { - if err := VisitValTuple(el, f); err != nil { - return err - } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Delete).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteValues is part of the Rewrite implementation -func (a *application) rewriteValues(parent SQLNode, node Values, replacer replacerFunc) error { +func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer replacerFunc) error { if node == nil { return nil } @@ -10674,56 +10827,39 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac if a.pre != nil && !a.pre(&cur) { return nil } - for i, el := range node { - if errF := a.rewriteValTuple(node, el, func(newNode, parent SQLNode) { - parent.(Values)[i] = newNode.(ValTuple) - }); errF != nil { - return errF - } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*DerivedTable).Select = newNode.(SelectStatement) + }); errF != nil { + return errF } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfValuesFuncExpr does deep equals between the two objects. -func EqualsRefOfValuesFuncExpr(a, b *ValuesFuncExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfColName(a.Name, b.Name) -} - -// CloneRefOfValuesFuncExpr creates a deep clone of the input. -func CloneRefOfValuesFuncExpr(n *ValuesFuncExpr) *ValuesFuncExpr { - if n == nil { +func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Name = CloneRefOfColName(n.Name) - return &out -} - -// VisitRefOfValuesFuncExpr will visit all parts of the AST -func VisitRefOfValuesFuncExpr(in *ValuesFuncExpr, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*DropColumn).Name = newNode.(*ColName) + }); errF != nil { + return errF } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil -} - -// rewriteRefOfValuesFuncExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer replacerFunc) error { +} +func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer replacerFunc) error { if node == nil { return nil } @@ -10735,8 +10871,8 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { - parent.(*ValuesFuncExpr).Name = newNode.(*ColName) + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*DropDatabase).Comments = newNode.(Comments) }); errF != nil { return errF } @@ -10745,32 +10881,10 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun } return nil } - -// EqualsVindexParam does deep equals between the two objects. -func EqualsVindexParam(a, b VindexParam) bool { - return a.Val == b.Val && - EqualsColIdent(a.Key, b.Key) -} - -// CloneVindexParam creates a deep clone of the input. -func CloneVindexParam(n VindexParam) VindexParam { - return *CloneRefOfVindexParam(&n) -} - -// VisitVindexParam will visit all parts of the AST -func VisitVindexParam(in VindexParam, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Key, f); err != nil { - return err +func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc) error { + if node == nil { + return nil } - return nil -} - -// rewriteVindexParam is part of the Rewrite implementation -func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc) error { - var err error cur := Cursor{ node: node, parent: parent, @@ -10779,69 +10893,34 @@ func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, repla if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { - err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Key' on 'VindexParam'") - }); errF != nil { - return errF - } - if err != nil { - return err - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfVindexSpec does deep equals between the two objects. -func EqualsRefOfVindexSpec(a, b *VindexSpec) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsColIdent(a.Name, b.Name) && - EqualsColIdent(a.Type, b.Type) && - EqualsSliceOfVindexParam(a.Params, b.Params) -} - -// CloneRefOfVindexSpec creates a deep clone of the input. -func CloneRefOfVindexSpec(n *VindexSpec) *VindexSpec { - if n == nil { - return nil - } - out := *n - out.Name = CloneColIdent(n.Name) - out.Type = CloneColIdent(n.Type) - out.Params = CloneSliceOfVindexParam(n.Params) - return &out -} - -// VisitRefOfVindexSpec will visit all parts of the AST -func VisitRefOfVindexSpec(in *VindexSpec, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitColIdent(in.Name, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitColIdent(in.Type, f); err != nil { - return err + if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + parent.(*DropTable).FromTables = newNode.(TableNames) + }); errF != nil { + return errF } - for _, el := range in.Params { - if err := VisitVindexParam(el, f); err != nil { - return err - } + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfVindexSpec is part of the Rewrite implementation -func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerFunc) error { +func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc) error { if node == nil { return nil } @@ -10853,71 +10932,39 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Name = newNode.(ColIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteColIdent(node, node.Type, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Type = newNode.(ColIdent) + if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + parent.(*DropView).FromTables = newNode.(TableNames) }); errF != nil { return errF } - for i, el := range node.Params { - if errF := a.rewriteVindexParam(node, el, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Params[i] = newNode.(VindexParam) - }); errF != nil { - return errF - } - } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfWhen does deep equals between the two objects. -func EqualsRefOfWhen(a, b *When) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsExpr(a.Cond, b.Cond) && - EqualsExpr(a.Val, b.Val) -} - -// CloneRefOfWhen creates a deep clone of the input. -func CloneRefOfWhen(n *When) *When { - if n == nil { +func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Cond = CloneExpr(n.Cond) - out.Val = CloneExpr(n.Val) - return &out -} - -// VisitRefOfWhen will visit all parts of the AST -func VisitRefOfWhen(in *When, f Visit) error { - if in == nil { - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitExpr(in.Cond, f); err != nil { - return err + if errF := a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { + parent.(*ExistsExpr).Subquery = newNode.(*Subquery) + }); errF != nil { + return errF } - if err := VisitExpr(in.Val, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfWhen is part of the Rewrite implementation -func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc) error { +func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replacerFunc) error { if node == nil { return nil } @@ -10929,13 +10976,8 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { - parent.(*When).Cond = newNode.(Expr) - }); errF != nil { - return errF - } - if errF := a.rewriteExpr(node, node.Val, func(newNode, parent SQLNode) { - parent.(*When).Val = newNode.(Expr) + if errF := a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { + parent.(*ExplainStmt).Statement = newNode.(Statement) }); errF != nil { return errF } @@ -10944,45 +10986,29 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl } return nil } - -// EqualsRefOfWhere does deep equals between the two objects. -func EqualsRefOfWhere(a, b *Where) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Type == b.Type && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfWhere creates a deep clone of the input. -func CloneRefOfWhere(n *Where) *Where { - if n == nil { +func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out -} - -// VisitRefOfWhere will visit all parts of the AST -func VisitRefOfWhere(in *Where, f Visit) error { - if in == nil { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*ExplainTab).Table = newNode.(TableName) + }); errF != nil { + return errF } - if err := VisitExpr(in.Expr, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfWhere is part of the Rewrite implementation -func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc) error { +func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc) error { if node == nil { return nil } @@ -10994,8 +11020,8 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { - parent.(*Where).Expr = newNode.(Expr) + if errF := a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { + parent.(*Flush).TableNames = newNode.(TableNames) }); errF != nil { return errF } @@ -11004,49 +11030,24 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re } return nil } - -// EqualsRefOfXorExpr does deep equals between the two objects. -func EqualsRefOfXorExpr(a, b *XorExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) -} - -// CloneRefOfXorExpr creates a deep clone of the input. -func CloneRefOfXorExpr(n *XorExpr) *XorExpr { - if n == nil { - return nil - } - out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) - return &out -} - -// VisitRefOfXorExpr will visit all parts of the AST -func VisitRefOfXorExpr(in *XorExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if err := VisitExpr(in.Left, f); err != nil { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitExpr(in.Right, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfXorExpr is part of the Rewrite implementation -func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition, replacer replacerFunc) error { if node == nil { return nil } @@ -11058,2365 +11059,1603 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*XorExpr).Left = newNode.(Expr) + if errF := a.rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).Source = newNode.(Columns) }); errF != nil { return errF } - if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { - parent.(*XorExpr).Right = newNode.(Expr) + if errF := a.rewriteTableName(node, node.ReferencedTable, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { - return errAbort + if errF := a.rewriteColumns(node, node.ReferencedColumns, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) + }); errF != nil { + return errF } - return nil -} - -// EqualsAlterOption does deep equals between the two objects. -func EqualsAlterOption(inA, inB AlterOption) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteReferenceAction(node, node.OnDelete, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteReferenceAction(node, node.OnUpdate, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *AddColumns: - b, ok := inB.(*AddColumns) - if !ok { - return false - } - return EqualsRefOfAddColumns(a, b) - case *AddConstraintDefinition: - b, ok := inB.(*AddConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfAddConstraintDefinition(a, b) - case *AddIndexDefinition: - b, ok := inB.(*AddIndexDefinition) - if !ok { - return false - } - return EqualsRefOfAddIndexDefinition(a, b) - case AlgorithmValue: - b, ok := inB.(AlgorithmValue) - if !ok { - return false - } - return a == b - case *AlterCharset: - b, ok := inB.(*AlterCharset) - if !ok { - return false - } - return EqualsRefOfAlterCharset(a, b) - case *AlterColumn: - b, ok := inB.(*AlterColumn) - if !ok { - return false - } - return EqualsRefOfAlterColumn(a, b) - case *ChangeColumn: - b, ok := inB.(*ChangeColumn) - if !ok { - return false - } - return EqualsRefOfChangeColumn(a, b) - case *DropColumn: - b, ok := inB.(*DropColumn) - if !ok { - return false - } - return EqualsRefOfDropColumn(a, b) - case *DropKey: - b, ok := inB.(*DropKey) - if !ok { - return false - } - return EqualsRefOfDropKey(a, b) - case *Force: - b, ok := inB.(*Force) - if !ok { - return false - } - return EqualsRefOfForce(a, b) - case *KeyState: - b, ok := inB.(*KeyState) - if !ok { - return false - } - return EqualsRefOfKeyState(a, b) - case *LockOption: - b, ok := inB.(*LockOption) - if !ok { - return false - } - return EqualsRefOfLockOption(a, b) - case *ModifyColumn: - b, ok := inB.(*ModifyColumn) - if !ok { - return false - } - return EqualsRefOfModifyColumn(a, b) - case *OrderByOption: - b, ok := inB.(*OrderByOption) - if !ok { - return false - } - return EqualsRefOfOrderByOption(a, b) - case *RenameIndex: - b, ok := inB.(*RenameIndex) - if !ok { - return false - } - return EqualsRefOfRenameIndex(a, b) - case *RenameTableName: - b, ok := inB.(*RenameTableName) - if !ok { - return false - } - return EqualsRefOfRenameTableName(a, b) - case TableOptions: - b, ok := inB.(TableOptions) - if !ok { - return false - } - return EqualsTableOptions(a, b) - case *TablespaceOperation: - b, ok := inB.(*TablespaceOperation) - if !ok { - return false - } - return EqualsRefOfTablespaceOperation(a, b) - case *Validation: - b, ok := inB.(*Validation) - if !ok { - return false - } - return EqualsRefOfValidation(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneAlterOption creates a deep clone of the input. -func CloneAlterOption(in AlterOption) AlterOption { - if in == nil { +func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AddColumns: - return CloneRefOfAddColumns(in) - case *AddConstraintDefinition: - return CloneRefOfAddConstraintDefinition(in) - case *AddIndexDefinition: - return CloneRefOfAddIndexDefinition(in) - case AlgorithmValue: - return in - case *AlterCharset: - return CloneRefOfAlterCharset(in) - case *AlterColumn: - return CloneRefOfAlterColumn(in) - case *ChangeColumn: - return CloneRefOfChangeColumn(in) - case *DropColumn: - return CloneRefOfDropColumn(in) - case *DropKey: - return CloneRefOfDropKey(in) - case *Force: - return CloneRefOfForce(in) - case *KeyState: - return CloneRefOfKeyState(in) - case *LockOption: - return CloneRefOfLockOption(in) - case *ModifyColumn: - return CloneRefOfModifyColumn(in) - case *OrderByOption: - return CloneRefOfOrderByOption(in) - case *RenameIndex: - return CloneRefOfRenameIndex(in) - case *RenameTableName: - return CloneRefOfRenameTableName(in) - case TableOptions: - return CloneTableOptions(in) - case *TablespaceOperation: - return CloneRefOfTablespaceOperation(in) - case *Validation: - return CloneRefOfValidation(in) - default: - // this should never happen - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } -} - -// VisitAlterOption will visit all parts of the AST -func VisitAlterOption(in AlterOption, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - switch in := in.(type) { - case *AddColumns: - return VisitRefOfAddColumns(in, f) - case *AddConstraintDefinition: - return VisitRefOfAddConstraintDefinition(in, f) - case *AddIndexDefinition: - return VisitRefOfAddIndexDefinition(in, f) - case AlgorithmValue: - return VisitAlgorithmValue(in, f) - case *AlterCharset: - return VisitRefOfAlterCharset(in, f) - case *AlterColumn: - return VisitRefOfAlterColumn(in, f) - case *ChangeColumn: - return VisitRefOfChangeColumn(in, f) - case *DropColumn: - return VisitRefOfDropColumn(in, f) - case *DropKey: - return VisitRefOfDropKey(in, f) - case *Force: - return VisitRefOfForce(in, f) - case *KeyState: - return VisitRefOfKeyState(in, f) - case *LockOption: - return VisitRefOfLockOption(in, f) - case *ModifyColumn: - return VisitRefOfModifyColumn(in, f) - case *OrderByOption: - return VisitRefOfOrderByOption(in, f) - case *RenameIndex: - return VisitRefOfRenameIndex(in, f) - case *RenameTableName: - return VisitRefOfRenameTableName(in, f) - case TableOptions: - return VisitTableOptions(in, f) - case *TablespaceOperation: - return VisitRefOfTablespaceOperation(in, f) - case *Validation: - return VisitRefOfValidation(in, f) - default: - // this should never happen - return nil + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Qualifier = newNode.(TableIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Exprs = newNode.(SelectExprs) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// rewriteAlterOption is part of the Rewrite implementation -func (a *application) rewriteAlterOption(parent SQLNode, node AlterOption, replacer replacerFunc) error { +func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AddColumns: - return a.rewriteRefOfAddColumns(parent, node, replacer) - case *AddConstraintDefinition: - return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) - case *AddIndexDefinition: - return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) - case AlgorithmValue: - return a.rewriteAlgorithmValue(parent, node, replacer) - case *AlterCharset: - return a.rewriteRefOfAlterCharset(parent, node, replacer) - case *AlterColumn: - return a.rewriteRefOfAlterColumn(parent, node, replacer) - case *ChangeColumn: - return a.rewriteRefOfChangeColumn(parent, node, replacer) - case *DropColumn: - return a.rewriteRefOfDropColumn(parent, node, replacer) - case *DropKey: - return a.rewriteRefOfDropKey(parent, node, replacer) - case *Force: - return a.rewriteRefOfForce(parent, node, replacer) - case *KeyState: - return a.rewriteRefOfKeyState(parent, node, replacer) - case *LockOption: - return a.rewriteRefOfLockOption(parent, node, replacer) - case *ModifyColumn: - return a.rewriteRefOfModifyColumn(parent, node, replacer) - case *OrderByOption: - return a.rewriteRefOfOrderByOption(parent, node, replacer) - case *RenameIndex: - return a.rewriteRefOfRenameIndex(parent, node, replacer) - case *RenameTableName: - return a.rewriteRefOfRenameTableName(parent, node, replacer) - case TableOptions: - return a.rewriteTableOptions(parent, node, replacer) - case *TablespaceOperation: - return a.rewriteRefOfTablespaceOperation(parent, node, replacer) - case *Validation: - return a.rewriteRefOfValidation(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } -} - -// EqualsCharacteristic does deep equals between the two objects. -func EqualsCharacteristic(inA, inB Characteristic) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF } - switch a := inA.(type) { - case AccessMode: - b, ok := inB.(AccessMode) - if !ok { - return false - } - return a == b - case IsolationLevel: - b, ok := inB.(IsolationLevel) - if !ok { - return false - } - return a == b - default: - // this should never happen - return false + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Limit = newNode.(*Limit) + }); errF != nil { + return errF } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// CloneCharacteristic creates a deep clone of the input. -func CloneCharacteristic(in Characteristic) Characteristic { - if in == nil { +func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case AccessMode: - return in - case IsolationLevel: - return in - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { + parent.(*IndexDefinition).Info = newNode.(*IndexInfo) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitCharacteristic will visit all parts of the AST -func VisitCharacteristic(in Characteristic, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case AccessMode: - return VisitAccessMode(in, f) - case IsolationLevel: - return VisitIsolationLevel(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + for i, el := range node.Indexes { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(*IndexHints).Indexes[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteCharacteristic is part of the Rewrite implementation -func (a *application) rewriteCharacteristic(parent SQLNode, node Characteristic, replacer replacerFunc) error { +func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case AccessMode: - return a.rewriteAccessMode(parent, node, replacer) - case IsolationLevel: - return a.rewriteIsolationLevel(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*IndexInfo).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteColIdent(node, node.ConstraintName, func(newNode, parent SQLNode) { + parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsColTuple does deep equals between the two objects. -func EqualsColTuple(inA, inB ColTuple) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - switch a := inA.(type) { - case ListArg: - b, ok := inB.(ListArg) - if !ok { - return false - } - return EqualsListArg(a, b) - case *Subquery: - b, ok := inB.(*Subquery) - if !ok { - return false - } - return EqualsRefOfSubquery(a, b) - case ValTuple: - b, ok := inB.(ValTuple) - if !ok { - return false - } - return EqualsValTuple(a, b) - default: - // this should never happen - return false + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Insert).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*Insert).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*Insert).Partitions = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*Insert).Columns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteInsertRows(node, node.Rows, func(newNode, parent SQLNode) { + parent.(*Insert).Rows = newNode.(InsertRows) + }); errF != nil { + return errF + } + if errF := a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneColTuple creates a deep clone of the input. -func CloneColTuple(in ColTuple) ColTuple { - if in == nil { +func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case ListArg: - return CloneListArg(in) - case *Subquery: - return CloneRefOfSubquery(in) - case ValTuple: - return CloneValTuple(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*IntervalExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitColTuple will visit all parts of the AST -func VisitColTuple(in ColTuple, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case ListArg: - return VisitListArg(in, f) - case *Subquery: - return VisitRefOfSubquery(in, f) - case ValTuple: - return VisitValTuple(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*IsExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteColTuple is part of the Rewrite implementation -func (a *application) rewriteColTuple(parent SQLNode, node ColTuple, replacer replacerFunc) error { +func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case ListArg: - return a.rewriteListArg(parent, node, replacer) - case *Subquery: - return a.rewriteRefOfSubquery(parent, node, replacer) - case ValTuple: - return a.rewriteValTuple(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } -} - -// EqualsConstraintInfo does deep equals between the two objects. -func EqualsConstraintInfo(inA, inB ConstraintInfo) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + parent.(*JoinCondition).On = newNode.(Expr) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + parent.(*JoinCondition).Using = newNode.(Columns) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *CheckConstraintDefinition: - b, ok := inB.(*CheckConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfCheckConstraintDefinition(a, b) - case *ForeignKeyDefinition: - b, ok := inB.(*ForeignKeyDefinition) - if !ok { - return false - } - return EqualsRefOfForeignKeyDefinition(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneConstraintInfo creates a deep clone of the input. -func CloneConstraintInfo(in ConstraintInfo) ConstraintInfo { - if in == nil { +func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *CheckConstraintDefinition: - return CloneRefOfCheckConstraintDefinition(in) - case *ForeignKeyDefinition: - return CloneRefOfForeignKeyDefinition(in) - default: - // this should never happen - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } -} - -// VisitConstraintInfo will visit all parts of the AST -func VisitConstraintInfo(in ConstraintInfo, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - switch in := in.(type) { - case *CheckConstraintDefinition: - return VisitRefOfCheckConstraintDefinition(in, f) - case *ForeignKeyDefinition: - return VisitRefOfForeignKeyDefinition(in, f) - default: - // this should never happen - return nil + if errF := a.rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExpr(node, node.RightExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteJoinCondition(node, node.Condition, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) + }); errF != nil { + return errF } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteConstraintInfo is part of the Rewrite implementation -func (a *application) rewriteConstraintInfo(parent SQLNode, node ConstraintInfo, replacer replacerFunc) error { +func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *CheckConstraintDefinition: - return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) - case *ForeignKeyDefinition: - return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) - default: - // this should never happen - return nil - } -} - -// EqualsDBDDLStatement does deep equals between the two objects. -func EqualsDBDDLStatement(inA, inB DBDDLStatement) bool { - if inA == nil && inB == nil { - return true + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if inA == nil || inB == nil { - return false + if a.pre != nil && !a.pre(&cur) { + return nil } - switch a := inA.(type) { - case *AlterDatabase: - b, ok := inB.(*AlterDatabase) - if !ok { - return false - } - return EqualsRefOfAlterDatabase(a, b) - case *CreateDatabase: - b, ok := inB.(*CreateDatabase) - if !ok { - return false - } - return EqualsRefOfCreateDatabase(a, b) - case *DropDatabase: - b, ok := inB.(*DropDatabase) - if !ok { - return false - } - return EqualsRefOfDropDatabase(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneDBDDLStatement creates a deep clone of the input. -func CloneDBDDLStatement(in DBDDLStatement) DBDDLStatement { - if in == nil { +func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return CloneRefOfAlterDatabase(in) - case *CreateDatabase: - return CloneRefOfCreateDatabase(in) - case *DropDatabase: - return CloneRefOfDropDatabase(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { + parent.(*Limit).Offset = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Rowcount, func(newNode, parent SQLNode) { + parent.(*Limit).Rowcount = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitDBDDLStatement will visit all parts of the AST -func VisitDBDDLStatement(in DBDDLStatement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return VisitRefOfAlterDatabase(in, f) - case *CreateDatabase: - return VisitRefOfCreateDatabase(in, f) - case *DropDatabase: - return VisitRefOfDropDatabase(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteDBDDLStatement is part of the Rewrite implementation -func (a *application) rewriteDBDDLStatement(parent SQLNode, node DBDDLStatement, replacer replacerFunc) error { +func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AlterDatabase: - return a.rewriteRefOfAlterDatabase(parent, node, replacer) - case *CreateDatabase: - return a.rewriteRefOfCreateDatabase(parent, node, replacer) - case *DropDatabase: - return a.rewriteRefOfDropDatabase(parent, node, replacer) - default: - // this should never happen - return nil - } -} - -// EqualsDDLStatement does deep equals between the two objects. -func EqualsDDLStatement(inA, inB DDLStatement) bool { - if inA == nil && inB == nil { - return true + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if inA == nil || inB == nil { - return false + if a.pre != nil && !a.pre(&cur) { + return nil } - switch a := inA.(type) { - case *AlterTable: - b, ok := inB.(*AlterTable) - if !ok { - return false - } - return EqualsRefOfAlterTable(a, b) - case *AlterView: - b, ok := inB.(*AlterView) - if !ok { - return false - } - return EqualsRefOfAlterView(a, b) - case *CreateTable: - b, ok := inB.(*CreateTable) - if !ok { - return false - } - return EqualsRefOfCreateTable(a, b) - case *CreateView: - b, ok := inB.(*CreateView) - if !ok { - return false - } - return EqualsRefOfCreateView(a, b) - case *DropTable: - b, ok := inB.(*DropTable) - if !ok { - return false - } - return EqualsRefOfDropTable(a, b) - case *DropView: - b, ok := inB.(*DropView) - if !ok { - return false - } - return EqualsRefOfDropView(a, b) - case *RenameTable: - b, ok := inB.(*RenameTable) - if !ok { - return false - } - return EqualsRefOfRenameTable(a, b) - case *TruncateTable: - b, ok := inB.(*TruncateTable) - if !ok { - return false - } - return EqualsRefOfTruncateTable(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneDDLStatement creates a deep clone of the input. -func CloneDDLStatement(in DDLStatement) DDLStatement { - if in == nil { +func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterTable: - return CloneRefOfAlterTable(in) - case *AlterView: - return CloneRefOfAlterView(in) - case *CreateTable: - return CloneRefOfCreateTable(in) - case *CreateView: - return CloneRefOfCreateView(in) - case *DropTable: - return CloneRefOfDropTable(in) - case *DropView: - return CloneRefOfDropView(in) - case *RenameTable: - return CloneRefOfRenameTable(in) - case *TruncateTable: - return CloneRefOfTruncateTable(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitDDLStatement will visit all parts of the AST -func VisitDDLStatement(in DDLStatement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterTable: - return VisitRefOfAlterTable(in, f) - case *AlterView: - return VisitRefOfAlterView(in, f) - case *CreateTable: - return VisitRefOfCreateTable(in, f) - case *CreateView: - return VisitRefOfCreateView(in, f) - case *DropTable: - return VisitRefOfDropTable(in, f) - case *DropView: - return VisitRefOfDropView(in, f) - case *RenameTable: - return VisitRefOfRenameTable(in, f) - case *TruncateTable: - return VisitRefOfTruncateTable(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteDDLStatement is part of the Rewrite implementation -func (a *application) rewriteDDLStatement(parent SQLNode, node DDLStatement, replacer replacerFunc) error { +func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AlterTable: - return a.rewriteRefOfAlterTable(parent, node, replacer) - case *AlterView: - return a.rewriteRefOfAlterView(parent, node, replacer) - case *CreateTable: - return a.rewriteRefOfCreateTable(parent, node, replacer) - case *CreateView: - return a.rewriteRefOfCreateView(parent, node, replacer) - case *DropTable: - return a.rewriteRefOfDropTable(parent, node, replacer) - case *DropView: - return a.rewriteRefOfDropView(parent, node, replacer) - case *RenameTable: - return a.rewriteRefOfRenameTable(parent, node, replacer) - case *TruncateTable: - return a.rewriteRefOfTruncateTable(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } -} - -// EqualsExplain does deep equals between the two objects. -func EqualsExplain(inA, inB Explain) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Columns = newNode.(SelectExprs) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *ExplainStmt: - b, ok := inB.(*ExplainStmt) - if !ok { - return false - } - return EqualsRefOfExplainStmt(a, b) - case *ExplainTab: - b, ok := inB.(*ExplainTab) - if !ok { - return false - } - return EqualsRefOfExplainTab(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneExplain creates a deep clone of the input. -func CloneExplain(in Explain) Explain { - if in == nil { - return nil - } - switch in := in.(type) { - case *ExplainStmt: - return CloneRefOfExplainStmt(in) - case *ExplainTab: - return CloneRefOfExplainTab(in) - default: - // this should never happen +func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer replacerFunc) error { + if node == nil { return nil } -} - -// VisitExplain will visit all parts of the AST -func VisitExplain(in Explain, f Visit) error { - if in == nil { - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - switch in := in.(type) { - case *ExplainStmt: - return VisitRefOfExplainStmt(in, f) - case *ExplainTab: - return VisitRefOfExplainTab(in, f) - default: - // this should never happen + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).First = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).After = newNode.(*ColName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteExplain is part of the Rewrite implementation -func (a *application) rewriteExplain(parent SQLNode, node Explain, replacer replacerFunc) error { +func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *ExplainStmt: - return a.rewriteRefOfExplainStmt(parent, node, replacer) - case *ExplainTab: - return a.rewriteRefOfExplainTab(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } -} - -// EqualsExpr does deep equals between the two objects. -func EqualsExpr(inA, inB Expr) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Nextval).Expr = newNode.(Expr) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if a.post != nil && !a.post(&cur) { + return errAbort } - switch a := inA.(type) { - case *AndExpr: - b, ok := inB.(*AndExpr) - if !ok { - return false - } - return EqualsRefOfAndExpr(a, b) - case Argument: - b, ok := inB.(Argument) - if !ok { - return false - } - return a == b - case *BinaryExpr: - b, ok := inB.(*BinaryExpr) - if !ok { - return false - } - return EqualsRefOfBinaryExpr(a, b) - case BoolVal: - b, ok := inB.(BoolVal) - if !ok { - return false - } - return a == b - case *CaseExpr: - b, ok := inB.(*CaseExpr) - if !ok { - return false - } - return EqualsRefOfCaseExpr(a, b) - case *ColName: - b, ok := inB.(*ColName) - if !ok { - return false - } - return EqualsRefOfColName(a, b) - case *CollateExpr: - b, ok := inB.(*CollateExpr) - if !ok { - return false - } - return EqualsRefOfCollateExpr(a, b) - case *ComparisonExpr: - b, ok := inB.(*ComparisonExpr) - if !ok { - return false - } - return EqualsRefOfComparisonExpr(a, b) - case *ConvertExpr: - b, ok := inB.(*ConvertExpr) - if !ok { - return false - } - return EqualsRefOfConvertExpr(a, b) - case *ConvertUsingExpr: - b, ok := inB.(*ConvertUsingExpr) - if !ok { - return false - } - return EqualsRefOfConvertUsingExpr(a, b) - case *CurTimeFuncExpr: - b, ok := inB.(*CurTimeFuncExpr) - if !ok { - return false - } - return EqualsRefOfCurTimeFuncExpr(a, b) - case *Default: - b, ok := inB.(*Default) - if !ok { - return false - } - return EqualsRefOfDefault(a, b) - case *ExistsExpr: - b, ok := inB.(*ExistsExpr) - if !ok { - return false - } - return EqualsRefOfExistsExpr(a, b) - case *FuncExpr: - b, ok := inB.(*FuncExpr) - if !ok { - return false - } - return EqualsRefOfFuncExpr(a, b) - case *GroupConcatExpr: - b, ok := inB.(*GroupConcatExpr) - if !ok { - return false - } - return EqualsRefOfGroupConcatExpr(a, b) - case *IntervalExpr: - b, ok := inB.(*IntervalExpr) - if !ok { - return false - } - return EqualsRefOfIntervalExpr(a, b) - case *IsExpr: - b, ok := inB.(*IsExpr) - if !ok { - return false - } - return EqualsRefOfIsExpr(a, b) - case ListArg: - b, ok := inB.(ListArg) - if !ok { - return false - } - return EqualsListArg(a, b) - case *Literal: - b, ok := inB.(*Literal) - if !ok { - return false - } - return EqualsRefOfLiteral(a, b) - case *MatchExpr: - b, ok := inB.(*MatchExpr) - if !ok { - return false - } - return EqualsRefOfMatchExpr(a, b) - case *NotExpr: - b, ok := inB.(*NotExpr) - if !ok { - return false - } - return EqualsRefOfNotExpr(a, b) - case *NullVal: - b, ok := inB.(*NullVal) - if !ok { - return false - } - return EqualsRefOfNullVal(a, b) - case *OrExpr: - b, ok := inB.(*OrExpr) - if !ok { - return false - } - return EqualsRefOfOrExpr(a, b) - case *RangeCond: - b, ok := inB.(*RangeCond) - if !ok { - return false - } - return EqualsRefOfRangeCond(a, b) - case *Subquery: - b, ok := inB.(*Subquery) - if !ok { - return false - } - return EqualsRefOfSubquery(a, b) - case *SubstrExpr: - b, ok := inB.(*SubstrExpr) - if !ok { - return false - } - return EqualsRefOfSubstrExpr(a, b) - case *TimestampFuncExpr: - b, ok := inB.(*TimestampFuncExpr) - if !ok { - return false - } - return EqualsRefOfTimestampFuncExpr(a, b) - case *UnaryExpr: - b, ok := inB.(*UnaryExpr) - if !ok { - return false - } - return EqualsRefOfUnaryExpr(a, b) - case ValTuple: - b, ok := inB.(ValTuple) - if !ok { - return false - } - return EqualsValTuple(a, b) - case *ValuesFuncExpr: - b, ok := inB.(*ValuesFuncExpr) - if !ok { - return false - } - return EqualsRefOfValuesFuncExpr(a, b) - case *XorExpr: - b, ok := inB.(*XorExpr) - if !ok { - return false - } - return EqualsRefOfXorExpr(a, b) - default: - // this should never happen - return false + return nil +} +func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*NotExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneExpr creates a deep clone of the input. -func CloneExpr(in Expr) Expr { - if in == nil { +func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AndExpr: - return CloneRefOfAndExpr(in) - case Argument: - return in - case *BinaryExpr: - return CloneRefOfBinaryExpr(in) - case BoolVal: - return in - case *CaseExpr: - return CloneRefOfCaseExpr(in) - case *ColName: - return CloneRefOfColName(in) - case *CollateExpr: - return CloneRefOfCollateExpr(in) - case *ComparisonExpr: - return CloneRefOfComparisonExpr(in) - case *ConvertExpr: - return CloneRefOfConvertExpr(in) - case *ConvertUsingExpr: - return CloneRefOfConvertUsingExpr(in) - case *CurTimeFuncExpr: - return CloneRefOfCurTimeFuncExpr(in) - case *Default: - return CloneRefOfDefault(in) - case *ExistsExpr: - return CloneRefOfExistsExpr(in) - case *FuncExpr: - return CloneRefOfFuncExpr(in) - case *GroupConcatExpr: - return CloneRefOfGroupConcatExpr(in) - case *IntervalExpr: - return CloneRefOfIntervalExpr(in) - case *IsExpr: - return CloneRefOfIsExpr(in) - case ListArg: - return CloneListArg(in) - case *Literal: - return CloneRefOfLiteral(in) - case *MatchExpr: - return CloneRefOfMatchExpr(in) - case *NotExpr: - return CloneRefOfNotExpr(in) - case *NullVal: - return CloneRefOfNullVal(in) - case *OrExpr: - return CloneRefOfOrExpr(in) - case *RangeCond: - return CloneRefOfRangeCond(in) - case *Subquery: - return CloneRefOfSubquery(in) - case *SubstrExpr: - return CloneRefOfSubstrExpr(in) - case *TimestampFuncExpr: - return CloneRefOfTimestampFuncExpr(in) - case *UnaryExpr: - return CloneRefOfUnaryExpr(in) - case ValTuple: - return CloneValTuple(in) - case *ValuesFuncExpr: - return CloneRefOfValuesFuncExpr(in) - case *XorExpr: - return CloneRefOfXorExpr(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitExpr will visit all parts of the AST -func VisitExpr(in Expr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AndExpr: - return VisitRefOfAndExpr(in, f) - case Argument: - return VisitArgument(in, f) - case *BinaryExpr: - return VisitRefOfBinaryExpr(in, f) - case BoolVal: - return VisitBoolVal(in, f) - case *CaseExpr: - return VisitRefOfCaseExpr(in, f) - case *ColName: - return VisitRefOfColName(in, f) - case *CollateExpr: - return VisitRefOfCollateExpr(in, f) - case *ComparisonExpr: - return VisitRefOfComparisonExpr(in, f) - case *ConvertExpr: - return VisitRefOfConvertExpr(in, f) - case *ConvertUsingExpr: - return VisitRefOfConvertUsingExpr(in, f) - case *CurTimeFuncExpr: - return VisitRefOfCurTimeFuncExpr(in, f) - case *Default: - return VisitRefOfDefault(in, f) - case *ExistsExpr: - return VisitRefOfExistsExpr(in, f) - case *FuncExpr: - return VisitRefOfFuncExpr(in, f) - case *GroupConcatExpr: - return VisitRefOfGroupConcatExpr(in, f) - case *IntervalExpr: - return VisitRefOfIntervalExpr(in, f) - case *IsExpr: - return VisitRefOfIsExpr(in, f) - case ListArg: - return VisitListArg(in, f) - case *Literal: - return VisitRefOfLiteral(in, f) - case *MatchExpr: - return VisitRefOfMatchExpr(in, f) - case *NotExpr: - return VisitRefOfNotExpr(in, f) - case *NullVal: - return VisitRefOfNullVal(in, f) - case *OrExpr: - return VisitRefOfOrExpr(in, f) - case *RangeCond: - return VisitRefOfRangeCond(in, f) - case *Subquery: - return VisitRefOfSubquery(in, f) - case *SubstrExpr: - return VisitRefOfSubstrExpr(in, f) - case *TimestampFuncExpr: - return VisitRefOfTimestampFuncExpr(in, f) - case *UnaryExpr: - return VisitRefOfUnaryExpr(in, f) - case ValTuple: - return VisitValTuple(in, f) - case *ValuesFuncExpr: - return VisitRefOfValuesFuncExpr(in, f) - case *XorExpr: - return VisitRefOfXorExpr(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { + parent.(*OptLike).LikeTable = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*OrExpr).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*OrExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Order).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { + parent.(*OrderByOption).Cols = newNode.(Columns) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteExpr is part of the Rewrite implementation -func (a *application) rewriteExpr(parent SQLNode, node Expr, replacer replacerFunc) error { +func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AndExpr: - return a.rewriteRefOfAndExpr(parent, node, replacer) - case Argument: - return a.rewriteArgument(parent, node, replacer) - case *BinaryExpr: - return a.rewriteRefOfBinaryExpr(parent, node, replacer) - case BoolVal: - return a.rewriteBoolVal(parent, node, replacer) - case *CaseExpr: - return a.rewriteRefOfCaseExpr(parent, node, replacer) - case *ColName: - return a.rewriteRefOfColName(parent, node, replacer) - case *CollateExpr: - return a.rewriteRefOfCollateExpr(parent, node, replacer) - case *ComparisonExpr: - return a.rewriteRefOfComparisonExpr(parent, node, replacer) - case *ConvertExpr: - return a.rewriteRefOfConvertExpr(parent, node, replacer) - case *ConvertUsingExpr: - return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) - case *CurTimeFuncExpr: - return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) - case *Default: - return a.rewriteRefOfDefault(parent, node, replacer) - case *ExistsExpr: - return a.rewriteRefOfExistsExpr(parent, node, replacer) - case *FuncExpr: - return a.rewriteRefOfFuncExpr(parent, node, replacer) - case *GroupConcatExpr: - return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) - case *IntervalExpr: - return a.rewriteRefOfIntervalExpr(parent, node, replacer) - case *IsExpr: - return a.rewriteRefOfIsExpr(parent, node, replacer) - case ListArg: - return a.rewriteListArg(parent, node, replacer) - case *Literal: - return a.rewriteRefOfLiteral(parent, node, replacer) - case *MatchExpr: - return a.rewriteRefOfMatchExpr(parent, node, replacer) - case *NotExpr: - return a.rewriteRefOfNotExpr(parent, node, replacer) - case *NullVal: - return a.rewriteRefOfNullVal(parent, node, replacer) - case *OrExpr: - return a.rewriteRefOfOrExpr(parent, node, replacer) - case *RangeCond: - return a.rewriteRefOfRangeCond(parent, node, replacer) - case *Subquery: - return a.rewriteRefOfSubquery(parent, node, replacer) - case *SubstrExpr: - return a.rewriteRefOfSubstrExpr(parent, node, replacer) - case *TimestampFuncExpr: - return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) - case *UnaryExpr: - return a.rewriteRefOfUnaryExpr(parent, node, replacer) - case ValTuple: - return a.rewriteValTuple(parent, node, replacer) - case *ValuesFuncExpr: - return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) - case *XorExpr: - return a.rewriteRefOfXorExpr(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsInsertRows does deep equals between the two objects. -func EqualsInsertRows(inA, inB InsertRows) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - switch a := inA.(type) { - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false - } - return EqualsRefOfParenSelect(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false - } - return EqualsRefOfSelect(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false - } - return EqualsRefOfUnion(a, b) - case Values: - b, ok := inB.(Values) - if !ok { - return false - } - return EqualsValues(a, b) - default: - // this should never happen - return false + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*ParenSelect).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneInsertRows creates a deep clone of the input. -func CloneInsertRows(in InsertRows) InsertRows { - if in == nil { +func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return CloneRefOfParenSelect(in) - case *Select: - return CloneRefOfSelect(in) - case *Union: - return CloneRefOfUnion(in) - case Values: - return CloneValues(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitInsertRows will visit all parts of the AST -func VisitInsertRows(in InsertRows, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case *Union: - return VisitRefOfUnion(in, f) - case Values: - return VisitValues(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Limit = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer replacerFunc) error { + if node == nil { return nil } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Names = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLiteral(node, node.Number, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Number = newNode.(*Literal) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).TableName = newNode.(TableName) + }); errF != nil { + return errF + } + for i, el := range node.Definitions { + if errF := a.rewriteRefOfPartitionDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Definitions[i] = newNode.(*PartitionDefinition) + }); errF != nil { + return errF + } + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteInsertRows is part of the Rewrite implementation -func (a *application) rewriteInsertRows(parent SQLNode, node InsertRows, replacer replacerFunc) error { +func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *ParenSelect: - return a.rewriteRefOfParenSelect(parent, node, replacer) - case *Select: - return a.rewriteRefOfSelect(parent, node, replacer) - case *Union: - return a.rewriteRefOfUnion(parent, node, replacer) - case Values: - return a.rewriteValues(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } -} - -// EqualsSelectExpr does deep equals between the two objects. -func EqualsSelectExpr(inA, inB SelectExpr) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*RangeCond).Left = newNode.(Expr) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + parent.(*RangeCond).From = newNode.(Expr) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *AliasedExpr: - b, ok := inB.(*AliasedExpr) - if !ok { - return false - } - return EqualsRefOfAliasedExpr(a, b) - case *Nextval: - b, ok := inB.(*Nextval) - if !ok { - return false - } - return EqualsRefOfNextval(a, b) - case *StarExpr: - b, ok := inB.(*StarExpr) - if !ok { - return false - } - return EqualsRefOfStarExpr(a, b) - default: - // this should never happen - return false + if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + parent.(*RangeCond).To = newNode.(Expr) + }); errF != nil { + return errF } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// CloneSelectExpr creates a deep clone of the input. -func CloneSelectExpr(in SelectExpr) SelectExpr { - if in == nil { +func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedExpr: - return CloneRefOfAliasedExpr(in) - case *Nextval: - return CloneRefOfNextval(in) - case *StarExpr: - return CloneRefOfStarExpr(in) - default: - // this should never happen - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } -} - -// VisitSelectExpr will visit all parts of the AST -func VisitSelectExpr(in SelectExpr, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - switch in := in.(type) { - case *AliasedExpr: - return VisitRefOfAliasedExpr(in, f) - case *Nextval: - return VisitRefOfNextval(in, f) - case *StarExpr: - return VisitRefOfStarExpr(in, f) - default: - // this should never happen - return nil + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*Release).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// rewriteSelectExpr is part of the Rewrite implementation -func (a *application) rewriteSelectExpr(parent SQLNode, node SelectExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AliasedExpr: - return a.rewriteRefOfAliasedExpr(parent, node, replacer) - case *Nextval: - return a.rewriteRefOfNextval(parent, node, replacer) - case *StarExpr: - return a.rewriteRefOfStarExpr(parent, node, replacer) - default: - // this should never happen - return nil - } -} - -// EqualsSelectStatement does deep equals between the two objects. -func EqualsSelectStatement(inA, inB SelectStatement) bool { - if inA == nil && inB == nil { - return true + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if inA == nil || inB == nil { - return false + if a.pre != nil && !a.pre(&cur) { + return nil } - switch a := inA.(type) { - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false - } - return EqualsRefOfParenSelect(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false - } - return EqualsRefOfSelect(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false - } - return EqualsRefOfUnion(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneSelectStatement creates a deep clone of the input. -func CloneSelectStatement(in SelectStatement) SelectStatement { - if in == nil { +func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return CloneRefOfParenSelect(in) - case *Select: - return CloneRefOfSelect(in) - case *Union: - return CloneRefOfUnion(in) - default: - // this should never happen - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } -} - -// VisitSelectStatement will visit all parts of the AST -func VisitSelectStatement(in SelectStatement, f Visit) error { - if in == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - switch in := in.(type) { - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case *Union: - return VisitRefOfUnion(in, f) - default: - // this should never happen - return nil + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// rewriteSelectStatement is part of the Rewrite implementation -func (a *application) rewriteSelectStatement(parent SQLNode, node SelectStatement, replacer replacerFunc) error { +func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *ParenSelect: - return a.rewriteRefOfParenSelect(parent, node, replacer) - case *Select: - return a.rewriteRefOfSelect(parent, node, replacer) - case *Union: - return a.rewriteRefOfUnion(parent, node, replacer) - default: - // this should never happen - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } -} - -// EqualsShowInternal does deep equals between the two objects. -func EqualsShowInternal(inA, inB ShowInternal) bool { - if inA == nil && inB == nil { - return true + if a.pre != nil && !a.pre(&cur) { + return nil } - if inA == nil || inB == nil { - return false + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*RenameTableName).Table = newNode.(TableName) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *ShowBasic: - b, ok := inB.(*ShowBasic) - if !ok { - return false - } - return EqualsRefOfShowBasic(a, b) - case *ShowCreate: - b, ok := inB.(*ShowCreate) - if !ok { - return false - } - return EqualsRefOfShowCreate(a, b) - case *ShowLegacy: - b, ok := inB.(*ShowLegacy) - if !ok { - return false - } - return EqualsRefOfShowLegacy(a, b) - default: - // this should never happen - return false + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneShowInternal creates a deep clone of the input. -func CloneShowInternal(in ShowInternal) ShowInternal { - if in == nil { +func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ShowBasic: - return CloneRefOfShowBasic(in) - case *ShowCreate: - return CloneRefOfShowCreate(in) - case *ShowLegacy: - return CloneRefOfShowLegacy(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitShowInternal will visit all parts of the AST -func VisitShowInternal(in ShowInternal, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ShowBasic: - return VisitRefOfShowBasic(in, f) - case *ShowCreate: - return VisitRefOfShowCreate(in, f) - case *ShowLegacy: - return VisitRefOfShowLegacy(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteShowInternal is part of the Rewrite implementation -func (a *application) rewriteShowInternal(parent SQLNode, node ShowInternal, replacer replacerFunc) error { +func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *ShowBasic: - return a.rewriteRefOfShowBasic(parent, node, replacer) - case *ShowCreate: - return a.rewriteRefOfShowCreate(parent, node, replacer) - case *ShowLegacy: - return a.rewriteRefOfShowLegacy(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SRollback).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsSimpleTableExpr does deep equals between the two objects. -func EqualsSimpleTableExpr(inA, inB SimpleTableExpr) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - switch a := inA.(type) { - case *DerivedTable: - b, ok := inB.(*DerivedTable) - if !ok { - return false - } - return EqualsRefOfDerivedTable(a, b) - case TableName: - b, ok := inB.(TableName) - if !ok { - return false - } - return EqualsTableName(a, b) - default: - // this should never happen - return false + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*Savepoint).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort } + return nil } - -// CloneSimpleTableExpr creates a deep clone of the input. -func CloneSimpleTableExpr(in SimpleTableExpr) SimpleTableExpr { - if in == nil { +func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *DerivedTable: - return CloneRefOfDerivedTable(in) - case TableName: - return CloneTableName(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Select).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExprs(node, node.SelectExprs, func(newNode, parent SQLNode) { + parent.(*Select).SelectExprs = newNode.(SelectExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExprs(node, node.From, func(newNode, parent SQLNode) { + parent.(*Select).From = newNode.(TableExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Select).Where = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteGroupBy(node, node.GroupBy, func(newNode, parent SQLNode) { + parent.(*Select).GroupBy = newNode.(GroupBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Having, func(newNode, parent SQLNode) { + parent.(*Select).Having = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Select).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Select).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfSelectInto(node, node.Into, func(newNode, parent SQLNode) { + parent.(*Select).Into = newNode.(*SelectInto) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitSimpleTableExpr will visit all parts of the AST -func VisitSimpleTableExpr(in SimpleTableExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *DerivedTable: - return VisitRefOfDerivedTable(in, f) - case TableName: - return VisitTableName(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteSimpleTableExpr is part of the Rewrite implementation -func (a *application) rewriteSimpleTableExpr(parent SQLNode, node SimpleTableExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *DerivedTable: - return a.rewriteRefOfDerivedTable(parent, node, replacer) - case TableName: - return a.rewriteTableName(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Set).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSetExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*Set).Exprs = newNode.(SetExprs) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsStatement does deep equals between the two objects. -func EqualsStatement(inA, inB Statement) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SetExpr).Name = newNode.(ColIdent) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*SetExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *AlterDatabase: - b, ok := inB.(*AlterDatabase) - if !ok { - return false - } - return EqualsRefOfAlterDatabase(a, b) - case *AlterMigration: - b, ok := inB.(*AlterMigration) - if !ok { - return false - } - return EqualsRefOfAlterMigration(a, b) - case *AlterTable: - b, ok := inB.(*AlterTable) - if !ok { - return false - } - return EqualsRefOfAlterTable(a, b) - case *AlterView: - b, ok := inB.(*AlterView) - if !ok { - return false - } - return EqualsRefOfAlterView(a, b) - case *AlterVschema: - b, ok := inB.(*AlterVschema) - if !ok { - return false - } - return EqualsRefOfAlterVschema(a, b) - case *Begin: - b, ok := inB.(*Begin) - if !ok { - return false - } - return EqualsRefOfBegin(a, b) - case *CallProc: - b, ok := inB.(*CallProc) - if !ok { - return false - } - return EqualsRefOfCallProc(a, b) - case *Commit: - b, ok := inB.(*Commit) - if !ok { - return false - } - return EqualsRefOfCommit(a, b) - case *CreateDatabase: - b, ok := inB.(*CreateDatabase) - if !ok { - return false - } - return EqualsRefOfCreateDatabase(a, b) - case *CreateTable: - b, ok := inB.(*CreateTable) - if !ok { - return false - } - return EqualsRefOfCreateTable(a, b) - case *CreateView: - b, ok := inB.(*CreateView) - if !ok { - return false - } - return EqualsRefOfCreateView(a, b) - case *Delete: - b, ok := inB.(*Delete) - if !ok { - return false - } - return EqualsRefOfDelete(a, b) - case *DropDatabase: - b, ok := inB.(*DropDatabase) - if !ok { - return false - } - return EqualsRefOfDropDatabase(a, b) - case *DropTable: - b, ok := inB.(*DropTable) - if !ok { - return false - } - return EqualsRefOfDropTable(a, b) - case *DropView: - b, ok := inB.(*DropView) - if !ok { - return false - } - return EqualsRefOfDropView(a, b) - case *ExplainStmt: - b, ok := inB.(*ExplainStmt) - if !ok { - return false - } - return EqualsRefOfExplainStmt(a, b) - case *ExplainTab: - b, ok := inB.(*ExplainTab) - if !ok { - return false - } - return EqualsRefOfExplainTab(a, b) - case *Flush: - b, ok := inB.(*Flush) - if !ok { - return false - } - return EqualsRefOfFlush(a, b) - case *Insert: - b, ok := inB.(*Insert) - if !ok { - return false - } - return EqualsRefOfInsert(a, b) - case *Load: - b, ok := inB.(*Load) - if !ok { - return false - } - return EqualsRefOfLoad(a, b) - case *LockTables: - b, ok := inB.(*LockTables) - if !ok { - return false - } - return EqualsRefOfLockTables(a, b) - case *OtherAdmin: - b, ok := inB.(*OtherAdmin) - if !ok { - return false - } - return EqualsRefOfOtherAdmin(a, b) - case *OtherRead: - b, ok := inB.(*OtherRead) - if !ok { - return false - } - return EqualsRefOfOtherRead(a, b) - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false - } - return EqualsRefOfParenSelect(a, b) - case *Release: - b, ok := inB.(*Release) - if !ok { - return false - } - return EqualsRefOfRelease(a, b) - case *RenameTable: - b, ok := inB.(*RenameTable) - if !ok { - return false - } - return EqualsRefOfRenameTable(a, b) - case *RevertMigration: - b, ok := inB.(*RevertMigration) - if !ok { - return false - } - return EqualsRefOfRevertMigration(a, b) - case *Rollback: - b, ok := inB.(*Rollback) - if !ok { - return false - } - return EqualsRefOfRollback(a, b) - case *SRollback: - b, ok := inB.(*SRollback) - if !ok { - return false - } - return EqualsRefOfSRollback(a, b) - case *Savepoint: - b, ok := inB.(*Savepoint) - if !ok { - return false - } - return EqualsRefOfSavepoint(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false - } - return EqualsRefOfSelect(a, b) - case *Set: - b, ok := inB.(*Set) - if !ok { - return false - } - return EqualsRefOfSet(a, b) - case *SetTransaction: - b, ok := inB.(*SetTransaction) - if !ok { - return false - } - return EqualsRefOfSetTransaction(a, b) - case *Show: - b, ok := inB.(*Show) - if !ok { - return false - } - return EqualsRefOfShow(a, b) - case *Stream: - b, ok := inB.(*Stream) - if !ok { - return false - } - return EqualsRefOfStream(a, b) - case *TruncateTable: - b, ok := inB.(*TruncateTable) - if !ok { - return false - } - return EqualsRefOfTruncateTable(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false - } - return EqualsRefOfUnion(a, b) - case *UnlockTables: - b, ok := inB.(*UnlockTables) - if !ok { - return false - } - return EqualsRefOfUnlockTables(a, b) - case *Update: - b, ok := inB.(*Update) - if !ok { - return false - } - return EqualsRefOfUpdate(a, b) - case *Use: - b, ok := inB.(*Use) - if !ok { - return false - } - return EqualsRefOfUse(a, b) - case *VStream: - b, ok := inB.(*VStream) - if !ok { - return false + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { + parent.(*SetTransaction).SQLNode = newNode.(SQLNode) + }); errF != nil { + return errF + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + for i, el := range node.Characteristics { + if errF := a.rewriteCharacteristic(node, el, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Characteristics[i] = newNode.(Characteristic) + }); errF != nil { + return errF } - return EqualsRefOfVStream(a, b) - default: - // this should never happen - return false } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// CloneStatement creates a deep clone of the input. -func CloneStatement(in Statement) Statement { - if in == nil { +func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return CloneRefOfAlterDatabase(in) - case *AlterMigration: - return CloneRefOfAlterMigration(in) - case *AlterTable: - return CloneRefOfAlterTable(in) - case *AlterView: - return CloneRefOfAlterView(in) - case *AlterVschema: - return CloneRefOfAlterVschema(in) - case *Begin: - return CloneRefOfBegin(in) - case *CallProc: - return CloneRefOfCallProc(in) - case *Commit: - return CloneRefOfCommit(in) - case *CreateDatabase: - return CloneRefOfCreateDatabase(in) - case *CreateTable: - return CloneRefOfCreateTable(in) - case *CreateView: - return CloneRefOfCreateView(in) - case *Delete: - return CloneRefOfDelete(in) - case *DropDatabase: - return CloneRefOfDropDatabase(in) - case *DropTable: - return CloneRefOfDropTable(in) - case *DropView: - return CloneRefOfDropView(in) - case *ExplainStmt: - return CloneRefOfExplainStmt(in) - case *ExplainTab: - return CloneRefOfExplainTab(in) - case *Flush: - return CloneRefOfFlush(in) - case *Insert: - return CloneRefOfInsert(in) - case *Load: - return CloneRefOfLoad(in) - case *LockTables: - return CloneRefOfLockTables(in) - case *OtherAdmin: - return CloneRefOfOtherAdmin(in) - case *OtherRead: - return CloneRefOfOtherRead(in) - case *ParenSelect: - return CloneRefOfParenSelect(in) - case *Release: - return CloneRefOfRelease(in) - case *RenameTable: - return CloneRefOfRenameTable(in) - case *RevertMigration: - return CloneRefOfRevertMigration(in) - case *Rollback: - return CloneRefOfRollback(in) - case *SRollback: - return CloneRefOfSRollback(in) - case *Savepoint: - return CloneRefOfSavepoint(in) - case *Select: - return CloneRefOfSelect(in) - case *Set: - return CloneRefOfSet(in) - case *SetTransaction: - return CloneRefOfSetTransaction(in) - case *Show: - return CloneRefOfShow(in) - case *Stream: - return CloneRefOfStream(in) - case *TruncateTable: - return CloneRefOfTruncateTable(in) - case *Union: - return CloneRefOfUnion(in) - case *UnlockTables: - return CloneRefOfUnlockTables(in) - case *Update: - return CloneRefOfUpdate(in) - case *Use: - return CloneRefOfUse(in) - case *VStream: - return CloneRefOfVStream(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { + parent.(*Show).Internal = newNode.(ShowInternal) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Tbl = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfShowFilter(node, node.Filter, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Filter = newNode.(*ShowFilter) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { + parent.(*ShowCreate).Op = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { + parent.(*ShowFilter).Filter = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).OnTable = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.ShowCollationFilterOpt, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + parent.(*StarExpr).TableName = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitStatement will visit all parts of the AST -func VisitStatement(in Statement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return VisitRefOfAlterDatabase(in, f) - case *AlterMigration: - return VisitRefOfAlterMigration(in, f) - case *AlterTable: - return VisitRefOfAlterTable(in, f) - case *AlterView: - return VisitRefOfAlterView(in, f) - case *AlterVschema: - return VisitRefOfAlterVschema(in, f) - case *Begin: - return VisitRefOfBegin(in, f) - case *CallProc: - return VisitRefOfCallProc(in, f) - case *Commit: - return VisitRefOfCommit(in, f) - case *CreateDatabase: - return VisitRefOfCreateDatabase(in, f) - case *CreateTable: - return VisitRefOfCreateTable(in, f) - case *CreateView: - return VisitRefOfCreateView(in, f) - case *Delete: - return VisitRefOfDelete(in, f) - case *DropDatabase: - return VisitRefOfDropDatabase(in, f) - case *DropTable: - return VisitRefOfDropTable(in, f) - case *DropView: - return VisitRefOfDropView(in, f) - case *ExplainStmt: - return VisitRefOfExplainStmt(in, f) - case *ExplainTab: - return VisitRefOfExplainTab(in, f) - case *Flush: - return VisitRefOfFlush(in, f) - case *Insert: - return VisitRefOfInsert(in, f) - case *Load: - return VisitRefOfLoad(in, f) - case *LockTables: - return VisitRefOfLockTables(in, f) - case *OtherAdmin: - return VisitRefOfOtherAdmin(in, f) - case *OtherRead: - return VisitRefOfOtherRead(in, f) - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *Release: - return VisitRefOfRelease(in, f) - case *RenameTable: - return VisitRefOfRenameTable(in, f) - case *RevertMigration: - return VisitRefOfRevertMigration(in, f) - case *Rollback: - return VisitRefOfRollback(in, f) - case *SRollback: - return VisitRefOfSRollback(in, f) - case *Savepoint: - return VisitRefOfSavepoint(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case *Set: - return VisitRefOfSet(in, f) - case *SetTransaction: - return VisitRefOfSetTransaction(in, f) - case *Show: - return VisitRefOfShow(in, f) - case *Stream: - return VisitRefOfStream(in, f) - case *TruncateTable: - return VisitRefOfTruncateTable(in, f) - case *Union: - return VisitRefOfUnion(in, f) - case *UnlockTables: - return VisitRefOfUnlockTables(in, f) - case *Update: - return VisitRefOfUpdate(in, f) - case *Use: - return VisitRefOfUse(in, f) - case *VStream: - return VisitRefOfVStream(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Stream).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + parent.(*Stream).SelectExpr = newNode.(SelectExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*Stream).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*Subquery).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteStatement is part of the Rewrite implementation -func (a *application) rewriteStatement(parent SQLNode, node Statement, replacer replacerFunc) error { +func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AlterDatabase: - return a.rewriteRefOfAlterDatabase(parent, node, replacer) - case *AlterMigration: - return a.rewriteRefOfAlterMigration(parent, node, replacer) - case *AlterTable: - return a.rewriteRefOfAlterTable(parent, node, replacer) - case *AlterView: - return a.rewriteRefOfAlterView(parent, node, replacer) - case *AlterVschema: - return a.rewriteRefOfAlterVschema(parent, node, replacer) - case *Begin: - return a.rewriteRefOfBegin(parent, node, replacer) - case *CallProc: - return a.rewriteRefOfCallProc(parent, node, replacer) - case *Commit: - return a.rewriteRefOfCommit(parent, node, replacer) - case *CreateDatabase: - return a.rewriteRefOfCreateDatabase(parent, node, replacer) - case *CreateTable: - return a.rewriteRefOfCreateTable(parent, node, replacer) - case *CreateView: - return a.rewriteRefOfCreateView(parent, node, replacer) - case *Delete: - return a.rewriteRefOfDelete(parent, node, replacer) - case *DropDatabase: - return a.rewriteRefOfDropDatabase(parent, node, replacer) - case *DropTable: - return a.rewriteRefOfDropTable(parent, node, replacer) - case *DropView: - return a.rewriteRefOfDropView(parent, node, replacer) - case *ExplainStmt: - return a.rewriteRefOfExplainStmt(parent, node, replacer) - case *ExplainTab: - return a.rewriteRefOfExplainTab(parent, node, replacer) - case *Flush: - return a.rewriteRefOfFlush(parent, node, replacer) - case *Insert: - return a.rewriteRefOfInsert(parent, node, replacer) - case *Load: - return a.rewriteRefOfLoad(parent, node, replacer) - case *LockTables: - return a.rewriteRefOfLockTables(parent, node, replacer) - case *OtherAdmin: - return a.rewriteRefOfOtherAdmin(parent, node, replacer) - case *OtherRead: - return a.rewriteRefOfOtherRead(parent, node, replacer) - case *ParenSelect: - return a.rewriteRefOfParenSelect(parent, node, replacer) - case *Release: - return a.rewriteRefOfRelease(parent, node, replacer) - case *RenameTable: - return a.rewriteRefOfRenameTable(parent, node, replacer) - case *RevertMigration: - return a.rewriteRefOfRevertMigration(parent, node, replacer) - case *Rollback: - return a.rewriteRefOfRollback(parent, node, replacer) - case *SRollback: - return a.rewriteRefOfSRollback(parent, node, replacer) - case *Savepoint: - return a.rewriteRefOfSavepoint(parent, node, replacer) - case *Select: - return a.rewriteRefOfSelect(parent, node, replacer) - case *Set: - return a.rewriteRefOfSet(parent, node, replacer) - case *SetTransaction: - return a.rewriteRefOfSetTransaction(parent, node, replacer) - case *Show: - return a.rewriteRefOfShow(parent, node, replacer) - case *Stream: - return a.rewriteRefOfStream(parent, node, replacer) - case *TruncateTable: - return a.rewriteRefOfTruncateTable(parent, node, replacer) - case *Union: - return a.rewriteRefOfUnion(parent, node, replacer) - case *UnlockTables: - return a.rewriteRefOfUnlockTables(parent, node, replacer) - case *Update: - return a.rewriteRefOfUpdate(parent, node, replacer) - case *Use: - return a.rewriteRefOfUse(parent, node, replacer) - case *VStream: - return a.rewriteRefOfVStream(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).Name = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLiteral(node, node.StrVal, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).StrVal = newNode.(*Literal) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).From = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).To = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsTableExpr does deep equals between the two objects. -func EqualsTableExpr(inA, inB TableExpr) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - switch a := inA.(type) { - case *AliasedTableExpr: - b, ok := inB.(*AliasedTableExpr) - if !ok { - return false + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*TableName).Name = newNode.(TableIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*TableName).Qualifier = newNode.(TableIdent) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil +} +func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + for i, el := range node.Columns { + if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Columns[i] = newNode.(*ColumnDefinition) + }); errF != nil { + return errF } - return EqualsRefOfAliasedTableExpr(a, b) - case *JoinTableExpr: - b, ok := inB.(*JoinTableExpr) - if !ok { - return false + } + for i, el := range node.Indexes { + if errF := a.rewriteRefOfIndexDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Indexes[i] = newNode.(*IndexDefinition) + }); errF != nil { + return errF } - return EqualsRefOfJoinTableExpr(a, b) - case *ParenTableExpr: - b, ok := inB.(*ParenTableExpr) - if !ok { - return false + } + for i, el := range node.Constraints { + if errF := a.rewriteRefOfConstraintDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Constraints[i] = newNode.(*ConstraintDefinition) + }); errF != nil { + return errF } - return EqualsRefOfParenTableExpr(a, b) - default: - // this should never happen - return false } + if errF := a.rewriteTableOptions(node, node.Options, func(newNode, parent SQLNode) { + parent.(*TableSpec).Options = newNode.(TableOptions) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// CloneTableExpr creates a deep clone of the input. -func CloneTableExpr(in TableExpr) TableExpr { - if in == nil { +func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedTableExpr: - return CloneRefOfAliasedTableExpr(in) - case *JoinTableExpr: - return CloneRefOfJoinTableExpr(in) - case *ParenTableExpr: - return CloneRefOfParenTableExpr(in) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitTableExpr will visit all parts of the AST -func VisitTableExpr(in TableExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedTableExpr: - return VisitRefOfAliasedTableExpr(in, f) - case *JoinTableExpr: - return VisitRefOfJoinTableExpr(in, f) - case *ParenTableExpr: - return VisitRefOfParenTableExpr(in, f) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Expr2, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteTableExpr is part of the Rewrite implementation -func (a *application) rewriteTableExpr(parent SQLNode, node TableExpr, replacer replacerFunc) error { +func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer replacerFunc) error { if node == nil { return nil } - switch node := node.(type) { - case *AliasedTableExpr: - return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) - case *JoinTableExpr: - return a.rewriteRefOfJoinTableExpr(parent, node, replacer) - case *ParenTableExpr: - return a.rewriteRefOfParenTableExpr(parent, node, replacer) - default: - // this should never happen + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*TruncateTable).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitAccessMode will visit all parts of the AST -func VisitAccessMode(in AccessMode, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*UnaryExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// rewriteAccessMode is part of the Rewrite implementation -func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc) error { +func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc) error { + if node == nil { + return nil + } cur := Cursor{ node: node, parent: parent, replacer: replacer, } - if a.pre != nil && !a.pre(&cur) { - return nil + if a.pre != nil && !a.pre(&cur) { + return nil + } + if errF := a.rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { + parent.(*Union).FirstStatement = newNode.(SelectStatement) + }); errF != nil { + return errF + } + for i, el := range node.UnionSelects { + if errF := a.rewriteRefOfUnionSelect(node, el, func(newNode, parent SQLNode) { + parent.(*Union).UnionSelects[i] = newNode.(*UnionSelect) + }); errF != nil { + return errF + } + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Union).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Union).Limit = newNode.(*Limit) + }); errF != nil { + return errF } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// VisitAlgorithmValue will visit all parts of the AST -func VisitAlgorithmValue(in AlgorithmValue, f Visit) error { - _, err := f(in) - return err -} - -// rewriteAlgorithmValue is part of the Rewrite implementation -func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc) error { +func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replacerFunc) error { + if node == nil { + return nil + } cur := Cursor{ node: node, parent: parent, @@ -13425,20 +12664,20 @@ func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { + parent.(*UnionSelect).Statement = newNode.(SelectStatement) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// VisitArgument will visit all parts of the AST -func VisitArgument(in Argument, f Visit) error { - _, err := f(in) - return err -} - -// rewriteArgument is part of the Rewrite implementation -func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc) error { +func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer replacerFunc) error { + if node == nil { + return nil + } cur := Cursor{ node: node, parent: parent, @@ -13452,15 +12691,10 @@ func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer re } return nil } - -// VisitBoolVal will visit all parts of the AST -func VisitBoolVal(in BoolVal, f Visit) error { - _, err := f(in) - return err -} - -// rewriteBoolVal is part of the Rewrite implementation -func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc) error { +func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc) error { + if node == nil { + return nil + } cur := Cursor{ node: node, parent: parent, @@ -13469,20 +12703,45 @@ func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer repl if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Update).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Update).TableExprs = newNode.(TableExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteUpdateExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*Update).Exprs = newNode.(UpdateExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Update).Where = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Update).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Update).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// VisitIsolationLevel will visit all parts of the AST -func VisitIsolationLevel(in IsolationLevel, f Visit) error { - _, err := f(in) - return err -} - -// rewriteIsolationLevel is part of the Rewrite implementation -func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc) error { +func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerFunc) error { + if node == nil { + return nil + } cur := Cursor{ node: node, parent: parent, @@ -13491,20 +12750,25 @@ func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Name = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// VisitReferenceAction will visit all parts of the AST -func VisitReferenceAction(in ReferenceAction, f Visit) error { - _, err := f(in) - return err -} - -// rewriteReferenceAction is part of the Rewrite implementation -func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc) error { +func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc) error { + if node == nil { + return nil + } cur := Cursor{ node: node, parent: parent, @@ -13513,157 +12777,59 @@ func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceActio if a.pre != nil && !a.pre(&cur) { return nil } + if errF := a.rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { + parent.(*Use).DBName = newNode.(TableIdent) + }); errF != nil { + return errF + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsSliceOfRefOfColumnDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfColumnDefinition(a, b []*ColumnDefinition) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfColumnDefinition(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSliceOfRefOfColumnDefinition creates a deep clone of the input. -func CloneSliceOfRefOfColumnDefinition(n []*ColumnDefinition) []*ColumnDefinition { - res := make([]*ColumnDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfColumnDefinition(x)) - } - return res -} - -// EqualsSliceOfCollateAndCharset does deep equals between the two objects. -func EqualsSliceOfCollateAndCharset(a, b []CollateAndCharset) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsCollateAndCharset(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSliceOfCollateAndCharset creates a deep clone of the input. -func CloneSliceOfCollateAndCharset(n []CollateAndCharset) []CollateAndCharset { - res := make([]CollateAndCharset, 0, len(n)) - for _, x := range n { - res = append(res, CloneCollateAndCharset(x)) - } - return res -} - -// EqualsSliceOfAlterOption does deep equals between the two objects. -func EqualsSliceOfAlterOption(a, b []AlterOption) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsAlterOption(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSliceOfAlterOption creates a deep clone of the input. -func CloneSliceOfAlterOption(n []AlterOption) []AlterOption { - res := make([]AlterOption, 0, len(n)) - for _, x := range n { - res = append(res, CloneAlterOption(x)) - } - return res -} - -// EqualsSliceOfColIdent does deep equals between the two objects. -func EqualsSliceOfColIdent(a, b []ColIdent) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsColIdent(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSliceOfColIdent creates a deep clone of the input. -func CloneSliceOfColIdent(n []ColIdent) []ColIdent { - res := make([]ColIdent, 0, len(n)) - for _, x := range n { - res = append(res, CloneColIdent(x)) +func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc) error { + if node == nil { + return nil } - return res -} - -// EqualsSliceOfRefOfWhen does deep equals between the two objects. -func EqualsSliceOfRefOfWhen(a, b []*When) bool { - if len(a) != len(b) { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - for i := 0; i < len(a); i++ { - if !EqualsRefOfWhen(a[i], b[i]) { - return false - } + if a.pre != nil && !a.pre(&cur) { + return nil } - return true -} - -// CloneSliceOfRefOfWhen creates a deep clone of the input. -func CloneSliceOfRefOfWhen(n []*When) []*When { - res := make([]*When, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfWhen(x)) + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*VStream).Comments = newNode.(Comments) + }); errF != nil { + return errF } - return res -} - -// EqualsRefOfColIdent does deep equals between the two objects. -func EqualsRefOfColIdent(a, b *ColIdent) bool { - if a == b { - return true + if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + parent.(*VStream).SelectExpr = newNode.(SelectExpr) + }); errF != nil { + return errF } - if a == nil || b == nil { - return false + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*VStream).Table = newNode.(TableName) + }); errF != nil { + return errF } - return a.val == b.val && - a.lowered == b.lowered && - a.at == b.at -} - -// CloneRefOfColIdent creates a deep clone of the input. -func CloneRefOfColIdent(n *ColIdent) *ColIdent { - if n == nil { - return nil + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*VStream).Where = newNode.(*Where) + }); errF != nil { + return errF } - out := *n - return &out -} - -// VisitRefOfColIdent will visit all parts of the AST -func VisitRefOfColIdent(in *ColIdent, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*VStream).Limit = newNode.(*Limit) + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfColIdent is part of the Rewrite implementation -func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc) error { +func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerFunc) error { if node == nil { return nil } @@ -13680,159 +12846,51 @@ func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, repla } return nil } - -// EqualsColumnType does deep equals between the two objects. -func EqualsColumnType(a, b ColumnType) bool { - return a.Type == b.Type && - a.Unsigned == b.Unsigned && - a.Zerofill == b.Zerofill && - a.Charset == b.Charset && - a.Collate == b.Collate && - EqualsRefOfColumnTypeOptions(a.Options, b.Options) && - EqualsRefOfLiteral(a.Length, b.Length) && - EqualsRefOfLiteral(a.Scale, b.Scale) && - EqualsSliceOfString(a.EnumValues, b.EnumValues) -} - -// CloneColumnType creates a deep clone of the input. -func CloneColumnType(n ColumnType) ColumnType { - return *CloneRefOfColumnType(&n) -} - -// EqualsRefOfColumnTypeOptions does deep equals between the two objects. -func EqualsRefOfColumnTypeOptions(a, b *ColumnTypeOptions) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.NotNull == b.NotNull && - a.Autoincrement == b.Autoincrement && - EqualsExpr(a.Default, b.Default) && - EqualsExpr(a.OnUpdate, b.OnUpdate) && - EqualsRefOfLiteral(a.Comment, b.Comment) && - a.KeyOpt == b.KeyOpt -} - -// CloneRefOfColumnTypeOptions creates a deep clone of the input. -func CloneRefOfColumnTypeOptions(n *ColumnTypeOptions) *ColumnTypeOptions { - if n == nil { +func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Default = CloneExpr(n.Default) - out.OnUpdate = CloneExpr(n.OnUpdate) - out.Comment = CloneRefOfLiteral(n.Comment) - return &out -} - -// EqualsSliceOfString does deep equals between the two objects. -func EqualsSliceOfString(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } - } - return true -} - -// CloneSliceOfString creates a deep clone of the input. -func CloneSliceOfString(n []string) []string { - res := make([]string, 0, len(n)) - copy(res, n) - return res -} - -// EqualsSliceOfRefOfIndexColumn does deep equals between the two objects. -func EqualsSliceOfRefOfIndexColumn(a, b []*IndexColumn) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfIndexColumn(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSliceOfRefOfIndexColumn creates a deep clone of the input. -func CloneSliceOfRefOfIndexColumn(n []*IndexColumn) []*IndexColumn { - res := make([]*IndexColumn, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfIndexColumn(x)) - } - return res -} - -// EqualsSliceOfRefOfIndexOption does deep equals between the two objects. -func EqualsSliceOfRefOfIndexOption(a, b []*IndexOption) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfIndexOption(a[i], b[i]) { - return false - } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return true -} - -// CloneSliceOfRefOfIndexOption creates a deep clone of the input. -func CloneSliceOfRefOfIndexOption(n []*IndexOption) []*IndexOption { - res := make([]*IndexOption, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfIndexOption(x)) + if a.pre != nil && !a.pre(&cur) { + return nil } - return res -} - -// EqualsRefOfJoinCondition does deep equals between the two objects. -func EqualsRefOfJoinCondition(a, b *JoinCondition) bool { - if a == b { - return true + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ValuesFuncExpr).Name = newNode.(*ColName) + }); errF != nil { + return errF } - if a == nil || b == nil { - return false + if a.post != nil && !a.post(&cur) { + return errAbort } - return EqualsExpr(a.On, b.On) && - EqualsColumns(a.Using, b.Using) + return nil } - -// CloneRefOfJoinCondition creates a deep clone of the input. -func CloneRefOfJoinCondition(n *JoinCondition) *JoinCondition { - if n == nil { +func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.On = CloneExpr(n.On) - out.Using = CloneColumns(n.Using) - return &out -} - -// VisitRefOfJoinCondition will visit all parts of the AST -func VisitRefOfJoinCondition(in *JoinCondition, f Visit) error { - if in == nil { - return nil + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil && !a.pre(&cur) { + return nil } - if err := VisitExpr(in.On, f); err != nil { - return err + if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + parent.(*VindexParam).Key = newNode.(ColIdent) + }); errF != nil { + return errF } - if err := VisitColumns(in.Using, f); err != nil { - return err + if a.post != nil && !a.post(&cur) { + return errAbort } return nil } - -// rewriteRefOfJoinCondition is part of the Rewrite implementation -func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer replacerFunc) error { +func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerFunc) error { if node == nil { return nil } @@ -13844,186 +12902,437 @@ func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondit if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { - parent.(*JoinCondition).On = newNode.(Expr) + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Name = newNode.(ColIdent) }); errF != nil { return errF } - if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { - parent.(*JoinCondition).Using = newNode.(Columns) + if errF := a.rewriteColIdent(node, node.Type, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Type = newNode.(ColIdent) }); errF != nil { return errF } + for i, el := range node.Params { + if errF := a.rewriteVindexParam(node, el, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Params[i] = newNode.(VindexParam) + }); errF != nil { + return errF + } + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsTableAndLockTypes does deep equals between the two objects. -func EqualsTableAndLockTypes(a, b TableAndLockTypes) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfTableAndLockType(a[i], b[i]) { - return false - } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return true -} - -// CloneTableAndLockTypes creates a deep clone of the input. -func CloneTableAndLockTypes(n TableAndLockTypes) TableAndLockTypes { - res := make(TableAndLockTypes, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfTableAndLockType(x)) + if a.pre != nil && !a.pre(&cur) { + return nil } - return res -} - -// EqualsSliceOfRefOfPartitionDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfPartitionDefinition(a, b []*PartitionDefinition) bool { - if len(a) != len(b) { - return false + if errF := a.rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { + parent.(*When).Cond = newNode.(Expr) + }); errF != nil { + return errF } - for i := 0; i < len(a); i++ { - if !EqualsRefOfPartitionDefinition(a[i], b[i]) { - return false - } + if errF := a.rewriteExpr(node, node.Val, func(newNode, parent SQLNode) { + parent.(*When).Val = newNode.(Expr) + }); errF != nil { + return errF } - return true -} - -// CloneSliceOfRefOfPartitionDefinition creates a deep clone of the input. -func CloneSliceOfRefOfPartitionDefinition(n []*PartitionDefinition) []*PartitionDefinition { - res := make([]*PartitionDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfPartitionDefinition(x)) + if a.post != nil && !a.post(&cur) { + return errAbort } - return res + return nil } - -// EqualsSliceOfRefOfRenameTablePair does deep equals between the two objects. -func EqualsSliceOfRefOfRenameTablePair(a, b []*RenameTablePair) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfRenameTablePair(a[i], b[i]) { - return false - } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return true -} - -// CloneSliceOfRefOfRenameTablePair creates a deep clone of the input. -func CloneSliceOfRefOfRenameTablePair(n []*RenameTablePair) []*RenameTablePair { - res := make([]*RenameTablePair, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfRenameTablePair(x)) + if a.pre != nil && !a.pre(&cur) { + return nil } - return res -} - -// EqualsRefOfBool does deep equals between the two objects. -func EqualsRefOfBool(a, b *bool) bool { - if a == b { - return true + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Where).Expr = newNode.(Expr) + }); errF != nil { + return errF } - if a == nil || b == nil { - return false + if a.post != nil && !a.post(&cur) { + return errAbort } - return *a == *b + return nil } - -// CloneRefOfBool creates a deep clone of the input. -func CloneRefOfBool(n *bool) *bool { - if n == nil { +func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - return &out -} - -// EqualsSliceOfCharacteristic does deep equals between the two objects. -func EqualsSliceOfCharacteristic(a, b []Characteristic) bool { - if len(a) != len(b) { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - for i := 0; i < len(a); i++ { - if !EqualsCharacteristic(a[i], b[i]) { - return false - } + if a.pre != nil && !a.pre(&cur) { + return nil } - return true -} - -// CloneSliceOfCharacteristic creates a deep clone of the input. -func CloneSliceOfCharacteristic(n []Characteristic) []Characteristic { - res := make([]Characteristic, 0, len(n)) - for _, x := range n { - res = append(res, CloneCharacteristic(x)) + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*XorExpr).Left = newNode.(Expr) + }); errF != nil { + return errF } - return res + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*XorExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsRefOfShowTablesOpt does deep equals between the two objects. -func EqualsRefOfShowTablesOpt(a, b *ShowTablesOpt) bool { - if a == b { - return true +func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc) error { + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if a == nil || b == nil { - return false + if a.pre != nil && !a.pre(&cur) { + return nil } - return a.Full == b.Full && - a.DbName == b.DbName && - EqualsRefOfShowFilter(a.Filter, b.Filter) + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// CloneRefOfShowTablesOpt creates a deep clone of the input. -func CloneRefOfShowTablesOpt(n *ShowTablesOpt) *ShowTablesOpt { - if n == nil { +func (a *application) rewriteSQLNode(parent SQLNode, node SQLNode, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case AccessMode: + return a.rewriteAccessMode(parent, node, replacer) + case *AddColumns: + return a.rewriteRefOfAddColumns(parent, node, replacer) + case *AddConstraintDefinition: + return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) + case *AddIndexDefinition: + return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) + case AlgorithmValue: + return a.rewriteAlgorithmValue(parent, node, replacer) + case *AliasedExpr: + return a.rewriteRefOfAliasedExpr(parent, node, replacer) + case *AliasedTableExpr: + return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) + case *AlterCharset: + return a.rewriteRefOfAlterCharset(parent, node, replacer) + case *AlterColumn: + return a.rewriteRefOfAlterColumn(parent, node, replacer) + case *AlterDatabase: + return a.rewriteRefOfAlterDatabase(parent, node, replacer) + case *AlterMigration: + return a.rewriteRefOfAlterMigration(parent, node, replacer) + case *AlterTable: + return a.rewriteRefOfAlterTable(parent, node, replacer) + case *AlterView: + return a.rewriteRefOfAlterView(parent, node, replacer) + case *AlterVschema: + return a.rewriteRefOfAlterVschema(parent, node, replacer) + case *AndExpr: + return a.rewriteRefOfAndExpr(parent, node, replacer) + case Argument: + return a.rewriteArgument(parent, node, replacer) + case *AutoIncSpec: + return a.rewriteRefOfAutoIncSpec(parent, node, replacer) + case *Begin: + return a.rewriteRefOfBegin(parent, node, replacer) + case *BinaryExpr: + return a.rewriteRefOfBinaryExpr(parent, node, replacer) + case BoolVal: + return a.rewriteBoolVal(parent, node, replacer) + case *CallProc: + return a.rewriteRefOfCallProc(parent, node, replacer) + case *CaseExpr: + return a.rewriteRefOfCaseExpr(parent, node, replacer) + case *ChangeColumn: + return a.rewriteRefOfChangeColumn(parent, node, replacer) + case *CheckConstraintDefinition: + return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) + case ColIdent: + return a.rewriteColIdent(parent, node, replacer) + case *ColName: + return a.rewriteRefOfColName(parent, node, replacer) + case *CollateExpr: + return a.rewriteRefOfCollateExpr(parent, node, replacer) + case *ColumnDefinition: + return a.rewriteRefOfColumnDefinition(parent, node, replacer) + case *ColumnType: + return a.rewriteRefOfColumnType(parent, node, replacer) + case Columns: + return a.rewriteColumns(parent, node, replacer) + case Comments: + return a.rewriteComments(parent, node, replacer) + case *Commit: + return a.rewriteRefOfCommit(parent, node, replacer) + case *ComparisonExpr: + return a.rewriteRefOfComparisonExpr(parent, node, replacer) + case *ConstraintDefinition: + return a.rewriteRefOfConstraintDefinition(parent, node, replacer) + case *ConvertExpr: + return a.rewriteRefOfConvertExpr(parent, node, replacer) + case *ConvertType: + return a.rewriteRefOfConvertType(parent, node, replacer) + case *ConvertUsingExpr: + return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) + case *CreateDatabase: + return a.rewriteRefOfCreateDatabase(parent, node, replacer) + case *CreateTable: + return a.rewriteRefOfCreateTable(parent, node, replacer) + case *CreateView: + return a.rewriteRefOfCreateView(parent, node, replacer) + case *CurTimeFuncExpr: + return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) + case *Default: + return a.rewriteRefOfDefault(parent, node, replacer) + case *Delete: + return a.rewriteRefOfDelete(parent, node, replacer) + case *DerivedTable: + return a.rewriteRefOfDerivedTable(parent, node, replacer) + case *DropColumn: + return a.rewriteRefOfDropColumn(parent, node, replacer) + case *DropDatabase: + return a.rewriteRefOfDropDatabase(parent, node, replacer) + case *DropKey: + return a.rewriteRefOfDropKey(parent, node, replacer) + case *DropTable: + return a.rewriteRefOfDropTable(parent, node, replacer) + case *DropView: + return a.rewriteRefOfDropView(parent, node, replacer) + case *ExistsExpr: + return a.rewriteRefOfExistsExpr(parent, node, replacer) + case *ExplainStmt: + return a.rewriteRefOfExplainStmt(parent, node, replacer) + case *ExplainTab: + return a.rewriteRefOfExplainTab(parent, node, replacer) + case Exprs: + return a.rewriteExprs(parent, node, replacer) + case *Flush: + return a.rewriteRefOfFlush(parent, node, replacer) + case *Force: + return a.rewriteRefOfForce(parent, node, replacer) + case *ForeignKeyDefinition: + return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) + case *FuncExpr: + return a.rewriteRefOfFuncExpr(parent, node, replacer) + case GroupBy: + return a.rewriteGroupBy(parent, node, replacer) + case *GroupConcatExpr: + return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) + case *IndexDefinition: + return a.rewriteRefOfIndexDefinition(parent, node, replacer) + case *IndexHints: + return a.rewriteRefOfIndexHints(parent, node, replacer) + case *IndexInfo: + return a.rewriteRefOfIndexInfo(parent, node, replacer) + case *Insert: + return a.rewriteRefOfInsert(parent, node, replacer) + case *IntervalExpr: + return a.rewriteRefOfIntervalExpr(parent, node, replacer) + case *IsExpr: + return a.rewriteRefOfIsExpr(parent, node, replacer) + case IsolationLevel: + return a.rewriteIsolationLevel(parent, node, replacer) + case JoinCondition: + return a.rewriteJoinCondition(parent, node, replacer) + case *JoinTableExpr: + return a.rewriteRefOfJoinTableExpr(parent, node, replacer) + case *KeyState: + return a.rewriteRefOfKeyState(parent, node, replacer) + case *Limit: + return a.rewriteRefOfLimit(parent, node, replacer) + case ListArg: + return a.rewriteListArg(parent, node, replacer) + case *Literal: + return a.rewriteRefOfLiteral(parent, node, replacer) + case *Load: + return a.rewriteRefOfLoad(parent, node, replacer) + case *LockOption: + return a.rewriteRefOfLockOption(parent, node, replacer) + case *LockTables: + return a.rewriteRefOfLockTables(parent, node, replacer) + case *MatchExpr: + return a.rewriteRefOfMatchExpr(parent, node, replacer) + case *ModifyColumn: + return a.rewriteRefOfModifyColumn(parent, node, replacer) + case *Nextval: + return a.rewriteRefOfNextval(parent, node, replacer) + case *NotExpr: + return a.rewriteRefOfNotExpr(parent, node, replacer) + case *NullVal: + return a.rewriteRefOfNullVal(parent, node, replacer) + case OnDup: + return a.rewriteOnDup(parent, node, replacer) + case *OptLike: + return a.rewriteRefOfOptLike(parent, node, replacer) + case *OrExpr: + return a.rewriteRefOfOrExpr(parent, node, replacer) + case *Order: + return a.rewriteRefOfOrder(parent, node, replacer) + case OrderBy: + return a.rewriteOrderBy(parent, node, replacer) + case *OrderByOption: + return a.rewriteRefOfOrderByOption(parent, node, replacer) + case *OtherAdmin: + return a.rewriteRefOfOtherAdmin(parent, node, replacer) + case *OtherRead: + return a.rewriteRefOfOtherRead(parent, node, replacer) + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *ParenTableExpr: + return a.rewriteRefOfParenTableExpr(parent, node, replacer) + case *PartitionDefinition: + return a.rewriteRefOfPartitionDefinition(parent, node, replacer) + case *PartitionSpec: + return a.rewriteRefOfPartitionSpec(parent, node, replacer) + case Partitions: + return a.rewritePartitions(parent, node, replacer) + case *RangeCond: + return a.rewriteRefOfRangeCond(parent, node, replacer) + case ReferenceAction: + return a.rewriteReferenceAction(parent, node, replacer) + case *Release: + return a.rewriteRefOfRelease(parent, node, replacer) + case *RenameIndex: + return a.rewriteRefOfRenameIndex(parent, node, replacer) + case *RenameTable: + return a.rewriteRefOfRenameTable(parent, node, replacer) + case *RenameTableName: + return a.rewriteRefOfRenameTableName(parent, node, replacer) + case *RevertMigration: + return a.rewriteRefOfRevertMigration(parent, node, replacer) + case *Rollback: + return a.rewriteRefOfRollback(parent, node, replacer) + case *SRollback: + return a.rewriteRefOfSRollback(parent, node, replacer) + case *Savepoint: + return a.rewriteRefOfSavepoint(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case SelectExprs: + return a.rewriteSelectExprs(parent, node, replacer) + case *SelectInto: + return a.rewriteRefOfSelectInto(parent, node, replacer) + case *Set: + return a.rewriteRefOfSet(parent, node, replacer) + case *SetExpr: + return a.rewriteRefOfSetExpr(parent, node, replacer) + case SetExprs: + return a.rewriteSetExprs(parent, node, replacer) + case *SetTransaction: + return a.rewriteRefOfSetTransaction(parent, node, replacer) + case *Show: + return a.rewriteRefOfShow(parent, node, replacer) + case *ShowBasic: + return a.rewriteRefOfShowBasic(parent, node, replacer) + case *ShowCreate: + return a.rewriteRefOfShowCreate(parent, node, replacer) + case *ShowFilter: + return a.rewriteRefOfShowFilter(parent, node, replacer) + case *ShowLegacy: + return a.rewriteRefOfShowLegacy(parent, node, replacer) + case *StarExpr: + return a.rewriteRefOfStarExpr(parent, node, replacer) + case *Stream: + return a.rewriteRefOfStream(parent, node, replacer) + case *Subquery: + return a.rewriteRefOfSubquery(parent, node, replacer) + case *SubstrExpr: + return a.rewriteRefOfSubstrExpr(parent, node, replacer) + case TableExprs: + return a.rewriteTableExprs(parent, node, replacer) + case TableIdent: + return a.rewriteTableIdent(parent, node, replacer) + case TableName: + return a.rewriteTableName(parent, node, replacer) + case TableNames: + return a.rewriteTableNames(parent, node, replacer) + case TableOptions: + return a.rewriteTableOptions(parent, node, replacer) + case *TableSpec: + return a.rewriteRefOfTableSpec(parent, node, replacer) + case *TablespaceOperation: + return a.rewriteRefOfTablespaceOperation(parent, node, replacer) + case *TimestampFuncExpr: + return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) + case *TruncateTable: + return a.rewriteRefOfTruncateTable(parent, node, replacer) + case *UnaryExpr: + return a.rewriteRefOfUnaryExpr(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + case *UnionSelect: + return a.rewriteRefOfUnionSelect(parent, node, replacer) + case *UnlockTables: + return a.rewriteRefOfUnlockTables(parent, node, replacer) + case *Update: + return a.rewriteRefOfUpdate(parent, node, replacer) + case *UpdateExpr: + return a.rewriteRefOfUpdateExpr(parent, node, replacer) + case UpdateExprs: + return a.rewriteUpdateExprs(parent, node, replacer) + case *Use: + return a.rewriteRefOfUse(parent, node, replacer) + case *VStream: + return a.rewriteRefOfVStream(parent, node, replacer) + case ValTuple: + return a.rewriteValTuple(parent, node, replacer) + case *Validation: + return a.rewriteRefOfValidation(parent, node, replacer) + case Values: + return a.rewriteValues(parent, node, replacer) + case *ValuesFuncExpr: + return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) + case VindexParam: + return a.rewriteVindexParam(parent, node, replacer) + case *VindexSpec: + return a.rewriteRefOfVindexSpec(parent, node, replacer) + case *When: + return a.rewriteRefOfWhen(parent, node, replacer) + case *Where: + return a.rewriteRefOfWhere(parent, node, replacer) + case *XorExpr: + return a.rewriteRefOfXorExpr(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.Filter = CloneRefOfShowFilter(n.Filter) - return &out -} - -// EqualsRefOfTableIdent does deep equals between the two objects. -func EqualsRefOfTableIdent(a, b *TableIdent) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.v == b.v } - -// CloneRefOfTableIdent creates a deep clone of the input. -func CloneRefOfTableIdent(n *TableIdent) *TableIdent { - if n == nil { +func (a *application) rewriteSelectExpr(parent SQLNode, node SelectExpr, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - return &out -} - -// VisitRefOfTableIdent will visit all parts of the AST -func VisitRefOfTableIdent(in *TableIdent, f Visit) error { - if in == nil { + switch node := node.(type) { + case *AliasedExpr: + return a.rewriteRefOfAliasedExpr(parent, node, replacer) + case *Nextval: + return a.rewriteRefOfNextval(parent, node, replacer) + case *StarExpr: + return a.rewriteRefOfStarExpr(parent, node, replacer) + default: + // this should never happen return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil } - -// rewriteRefOfTableIdent is part of the Rewrite implementation -func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerFunc) error { +func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc) error { if node == nil { return nil } @@ -14035,54 +13344,35 @@ func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, r if a.pre != nil && !a.pre(&cur) { return nil } + for i, el := range node { + if errF := a.rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { + parent.(SelectExprs)[i] = newNode.(SelectExpr) + }); errF != nil { + return errF + } + } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfTableName does deep equals between the two objects. -func EqualsRefOfTableName(a, b *TableName) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsTableIdent(a.Name, b.Name) && - EqualsTableIdent(a.Qualifier, b.Qualifier) -} - -// CloneRefOfTableName creates a deep clone of the input. -func CloneRefOfTableName(n *TableName) *TableName { - if n == nil { +func (a *application) rewriteSelectStatement(parent SQLNode, node SelectStatement, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Name = CloneTableIdent(n.Name) - out.Qualifier = CloneTableIdent(n.Qualifier) - return &out -} - -// VisitRefOfTableName will visit all parts of the AST -func VisitRefOfTableName(in *TableName, f Visit) error { - if in == nil { + switch node := node.(type) { + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + default: + // this should never happen return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableIdent(in.Name, f); err != nil { - return err - } - if err := VisitTableIdent(in.Qualifier, f); err != nil { - return err - } - return nil -} - -// rewriteRefOfTableName is part of the Rewrite implementation -func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFunc) error { +} +func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc) error { if node == nil { return nil } @@ -14094,151 +13384,227 @@ func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, rep if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { - parent.(*TableName).Name = newNode.(TableIdent) - }); errF != nil { - return errF - } - if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { - parent.(*TableName).Qualifier = newNode.(TableIdent) - }); errF != nil { - return errF + for i, el := range node { + if errF := a.rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { + parent.(SetExprs)[i] = newNode.(*SetExpr) + }); errF != nil { + return errF + } } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsRefOfTableOption does deep equals between the two objects. -func EqualsRefOfTableOption(a, b *TableOption) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func (a *application) rewriteShowInternal(parent SQLNode, node ShowInternal, replacer replacerFunc) error { + if node == nil { + return nil } - return a.Name == b.Name && - a.String == b.String && - EqualsRefOfLiteral(a.Value, b.Value) && - EqualsTableNames(a.Tables, b.Tables) -} - -// CloneRefOfTableOption creates a deep clone of the input. -func CloneRefOfTableOption(n *TableOption) *TableOption { - if n == nil { + switch node := node.(type) { + case *ShowBasic: + return a.rewriteRefOfShowBasic(parent, node, replacer) + case *ShowCreate: + return a.rewriteRefOfShowCreate(parent, node, replacer) + case *ShowLegacy: + return a.rewriteRefOfShowLegacy(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.Value = CloneRefOfLiteral(n.Value) - out.Tables = CloneTableNames(n.Tables) - return &out } - -// EqualsSliceOfRefOfIndexDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfIndexDefinition(a, b []*IndexDefinition) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteSimpleTableExpr(parent SQLNode, node SimpleTableExpr, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfIndexDefinition(a[i], b[i]) { - return false - } + switch node := node.(type) { + case *DerivedTable: + return a.rewriteRefOfDerivedTable(parent, node, replacer) + case TableName: + return a.rewriteTableName(parent, node, replacer) + default: + // this should never happen + return nil } - return true } - -// CloneSliceOfRefOfIndexDefinition creates a deep clone of the input. -func CloneSliceOfRefOfIndexDefinition(n []*IndexDefinition) []*IndexDefinition { - res := make([]*IndexDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfIndexDefinition(x)) +func (a *application) rewriteStatement(parent SQLNode, node Statement, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterDatabase: + return a.rewriteRefOfAlterDatabase(parent, node, replacer) + case *AlterMigration: + return a.rewriteRefOfAlterMigration(parent, node, replacer) + case *AlterTable: + return a.rewriteRefOfAlterTable(parent, node, replacer) + case *AlterView: + return a.rewriteRefOfAlterView(parent, node, replacer) + case *AlterVschema: + return a.rewriteRefOfAlterVschema(parent, node, replacer) + case *Begin: + return a.rewriteRefOfBegin(parent, node, replacer) + case *CallProc: + return a.rewriteRefOfCallProc(parent, node, replacer) + case *Commit: + return a.rewriteRefOfCommit(parent, node, replacer) + case *CreateDatabase: + return a.rewriteRefOfCreateDatabase(parent, node, replacer) + case *CreateTable: + return a.rewriteRefOfCreateTable(parent, node, replacer) + case *CreateView: + return a.rewriteRefOfCreateView(parent, node, replacer) + case *Delete: + return a.rewriteRefOfDelete(parent, node, replacer) + case *DropDatabase: + return a.rewriteRefOfDropDatabase(parent, node, replacer) + case *DropTable: + return a.rewriteRefOfDropTable(parent, node, replacer) + case *DropView: + return a.rewriteRefOfDropView(parent, node, replacer) + case *ExplainStmt: + return a.rewriteRefOfExplainStmt(parent, node, replacer) + case *ExplainTab: + return a.rewriteRefOfExplainTab(parent, node, replacer) + case *Flush: + return a.rewriteRefOfFlush(parent, node, replacer) + case *Insert: + return a.rewriteRefOfInsert(parent, node, replacer) + case *Load: + return a.rewriteRefOfLoad(parent, node, replacer) + case *LockTables: + return a.rewriteRefOfLockTables(parent, node, replacer) + case *OtherAdmin: + return a.rewriteRefOfOtherAdmin(parent, node, replacer) + case *OtherRead: + return a.rewriteRefOfOtherRead(parent, node, replacer) + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *Release: + return a.rewriteRefOfRelease(parent, node, replacer) + case *RenameTable: + return a.rewriteRefOfRenameTable(parent, node, replacer) + case *RevertMigration: + return a.rewriteRefOfRevertMigration(parent, node, replacer) + case *Rollback: + return a.rewriteRefOfRollback(parent, node, replacer) + case *SRollback: + return a.rewriteRefOfSRollback(parent, node, replacer) + case *Savepoint: + return a.rewriteRefOfSavepoint(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case *Set: + return a.rewriteRefOfSet(parent, node, replacer) + case *SetTransaction: + return a.rewriteRefOfSetTransaction(parent, node, replacer) + case *Show: + return a.rewriteRefOfShow(parent, node, replacer) + case *Stream: + return a.rewriteRefOfStream(parent, node, replacer) + case *TruncateTable: + return a.rewriteRefOfTruncateTable(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + case *UnlockTables: + return a.rewriteRefOfUnlockTables(parent, node, replacer) + case *Update: + return a.rewriteRefOfUpdate(parent, node, replacer) + case *Use: + return a.rewriteRefOfUse(parent, node, replacer) + case *VStream: + return a.rewriteRefOfVStream(parent, node, replacer) + default: + // this should never happen + return nil } - return res } - -// EqualsSliceOfRefOfConstraintDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfConstraintDefinition(a, b []*ConstraintDefinition) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteTableExpr(parent SQLNode, node TableExpr, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfConstraintDefinition(a[i], b[i]) { - return false - } + switch node := node.(type) { + case *AliasedTableExpr: + return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) + case *JoinTableExpr: + return a.rewriteRefOfJoinTableExpr(parent, node, replacer) + case *ParenTableExpr: + return a.rewriteRefOfParenTableExpr(parent, node, replacer) + default: + // this should never happen + return nil } - return true } - -// CloneSliceOfRefOfConstraintDefinition creates a deep clone of the input. -func CloneSliceOfRefOfConstraintDefinition(n []*ConstraintDefinition) []*ConstraintDefinition { - res := make([]*ConstraintDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfConstraintDefinition(x)) +func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc) error { + if node == nil { + return nil } - return res -} - -// EqualsSliceOfRefOfUnionSelect does deep equals between the two objects. -func EqualsSliceOfRefOfUnionSelect(a, b []*UnionSelect) bool { - if len(a) != len(b) { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - for i := 0; i < len(a); i++ { - if !EqualsRefOfUnionSelect(a[i], b[i]) { - return false + if a.pre != nil && !a.pre(&cur) { + return nil + } + for i, el := range node { + if errF := a.rewriteTableExpr(node, el, func(newNode, parent SQLNode) { + parent.(TableExprs)[i] = newNode.(TableExpr) + }); errF != nil { + return errF } } - return true -} - -// CloneSliceOfRefOfUnionSelect creates a deep clone of the input. -func CloneSliceOfRefOfUnionSelect(n []*UnionSelect) []*UnionSelect { - res := make([]*UnionSelect, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfUnionSelect(x)) + if a.post != nil && !a.post(&cur) { + return errAbort } - return res + return nil } - -// EqualsRefOfVindexParam does deep equals between the two objects. -func EqualsRefOfVindexParam(a, b *VindexParam) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return a.Val == b.Val && - EqualsColIdent(a.Key, b.Key) -} - -// CloneRefOfVindexParam creates a deep clone of the input. -func CloneRefOfVindexParam(n *VindexParam) *VindexParam { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.Key = CloneColIdent(n.Key) - return &out + if err != nil { + return err + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// VisitRefOfVindexParam will visit all parts of the AST -func VisitRefOfVindexParam(in *VindexParam, f Visit) error { - if in == nil { +func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Name' on 'TableName'") + }); errF != nil { + return errF } - if err := VisitColIdent(in.Key, f); err != nil { + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Qualifier' on 'TableName'") + }); errF != nil { + return errF + } + if err != nil { return err } + if a.post != nil && !a.post(&cur) { + return errAbort + } return nil } - -// rewriteRefOfVindexParam is part of the Rewrite implementation -func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replacerFunc) error { +func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc) error { if node == nil { return nil } @@ -14250,161 +13616,127 @@ func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, if a.pre != nil && !a.pre(&cur) { return nil } - if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { - parent.(*VindexParam).Key = newNode.(ColIdent) - }); errF != nil { - return errF + for i, el := range node { + if errF := a.rewriteTableName(node, el, func(newNode, parent SQLNode) { + parent.(TableNames)[i] = newNode.(TableName) + }); errF != nil { + return errF + } } if a.post != nil && !a.post(&cur) { return errAbort } return nil } - -// EqualsSliceOfVindexParam does deep equals between the two objects. -func EqualsSliceOfVindexParam(a, b []VindexParam) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsVindexParam(a[i], b[i]) { - return false - } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return true -} - -// CloneSliceOfVindexParam creates a deep clone of the input. -func CloneSliceOfVindexParam(n []VindexParam) []VindexParam { - res := make([]VindexParam, 0, len(n)) - for _, x := range n { - res = append(res, CloneVindexParam(x)) + if a.pre != nil && !a.pre(&cur) { + return nil } - return res -} - -// EqualsCollateAndCharset does deep equals between the two objects. -func EqualsCollateAndCharset(a, b CollateAndCharset) bool { - return a.IsDefault == b.IsDefault && - a.Value == b.Value && - a.Type == b.Type -} - -// CloneCollateAndCharset creates a deep clone of the input. -func CloneCollateAndCharset(n CollateAndCharset) CollateAndCharset { - return *CloneRefOfCollateAndCharset(&n) + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// EqualsRefOfIndexColumn does deep equals between the two objects. -func EqualsRefOfIndexColumn(a, b *IndexColumn) bool { - if a == b { - return true +func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return EqualsColIdent(a.Column, b.Column) && - EqualsRefOfLiteral(a.Length, b.Length) && - a.Direction == b.Direction -} - -// CloneRefOfIndexColumn creates a deep clone of the input. -func CloneRefOfIndexColumn(n *IndexColumn) *IndexColumn { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.Column = CloneColIdent(n.Column) - out.Length = CloneRefOfLiteral(n.Length) - return &out -} - -// EqualsRefOfIndexOption does deep equals between the two objects. -func EqualsRefOfIndexOption(a, b *IndexOption) bool { - if a == b { - return true + for i, el := range node { + if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + parent.(UpdateExprs)[i] = newNode.(*UpdateExpr) + }); errF != nil { + return errF + } } - if a == nil || b == nil { - return false + if a.post != nil && !a.post(&cur) { + return errAbort } - return a.Name == b.Name && - a.String == b.String && - EqualsRefOfLiteral(a.Value, b.Value) + return nil } - -// CloneRefOfIndexOption creates a deep clone of the input. -func CloneRefOfIndexOption(n *IndexOption) *IndexOption { - if n == nil { +func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Value = CloneRefOfLiteral(n.Value) - return &out -} - -// EqualsRefOfTableAndLockType does deep equals between the two objects. -func EqualsRefOfTableAndLockType(a, b *TableAndLockType) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - return EqualsTableExpr(a.Table, b.Table) && - a.Lock == b.Lock -} - -// CloneRefOfTableAndLockType creates a deep clone of the input. -func CloneRefOfTableAndLockType(n *TableAndLockType) *TableAndLockType { - if n == nil { + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - out.Table = CloneTableExpr(n.Table) - return &out -} - -// EqualsRefOfRenameTablePair does deep equals between the two objects. -func EqualsRefOfRenameTablePair(a, b *RenameTablePair) bool { - if a == b { - return true + for i, el := range node { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(ValTuple)[i] = newNode.(Expr) + }); errF != nil { + return errF + } } - if a == nil || b == nil { - return false + if a.post != nil && !a.post(&cur) { + return errAbort } - return EqualsTableName(a.FromTable, b.FromTable) && - EqualsTableName(a.ToTable, b.ToTable) + return nil } - -// CloneRefOfRenameTablePair creates a deep clone of the input. -func CloneRefOfRenameTablePair(n *RenameTablePair) *RenameTablePair { - if n == nil { +func (a *application) rewriteValues(parent SQLNode, node Values, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.FromTable = CloneTableName(n.FromTable) - out.ToTable = CloneTableName(n.ToTable) - return &out -} - -// EqualsRefOfCollateAndCharset does deep equals between the two objects. -func EqualsRefOfCollateAndCharset(a, b *CollateAndCharset) bool { - if a == b { - return true + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, } - if a == nil || b == nil { - return false + if a.pre != nil && !a.pre(&cur) { + return nil } - return a.IsDefault == b.IsDefault && - a.Value == b.Value && - a.Type == b.Type + for i, el := range node { + if errF := a.rewriteValTuple(node, el, func(newNode, parent SQLNode) { + parent.(Values)[i] = newNode.(ValTuple) + }); errF != nil { + return errF + } + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } - -// CloneRefOfCollateAndCharset creates a deep clone of the input. -func CloneRefOfCollateAndCharset(n *CollateAndCharset) *CollateAndCharset { - if n == nil { +func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc) error { + var err error + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if a.pre != nil && !a.pre(&cur) { return nil } - out := *n - return &out + if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Key' on 'VindexParam'") + }); errF != nil { + return errF + } + if err != nil { + return err + } + if a.post != nil && !a.post(&cur) { + return errAbort + } + return nil } From 2a241a5b54c3d9875b16000e2d970451686e6b18 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 20 Mar 2021 09:20:03 +0100 Subject: [PATCH 11/15] move the cursor into the application to save on object creation Signed-off-by: Andres Taylor --- .../asthelpergen/integration/ast_helper.go | 225 +- go/tools/asthelpergen/integration/types.go | 1 + go/tools/asthelpergen/rewrite_gen.go | 105 +- go/vt/sqlparser/ast_helper.go | 2265 +++++++++-------- go/vt/sqlparser/rewriter_api.go | 2 +- go/vt/sqlparser/walker_test.go | 2 +- 6 files changed, 1369 insertions(+), 1231 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index 1c4cb991de7..d2775185d0b 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -797,15 +797,16 @@ func (a *application) rewriteAST(parent AST, node AST, replacer replacerFunc) er } } func (a *application) rewriteBasicType(parent AST, node BasicType, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -814,33 +815,35 @@ func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -849,12 +852,10 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -864,7 +865,10 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -873,12 +877,10 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -888,7 +890,10 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -897,15 +902,16 @@ func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *Interface if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -914,15 +920,16 @@ func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -931,15 +938,16 @@ func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -948,12 +956,10 @@ func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { @@ -966,7 +972,10 @@ func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -975,12 +984,10 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node.ASTElements { @@ -997,7 +1004,10 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -1006,12 +1016,10 @@ func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSubIface(node, node.inner, func(newNode, parent AST) { @@ -1019,7 +1027,10 @@ func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer re }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -1028,12 +1039,10 @@ func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContaine if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { @@ -1046,7 +1055,10 @@ func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContaine }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -1055,12 +1067,10 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node.ASTElements { @@ -1077,7 +1087,10 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -1096,12 +1109,10 @@ func (a *application) rewriteSubIface(parent AST, node SubIface, replacer replac } func (a *application) rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { @@ -1117,19 +1128,20 @@ func (a *application) rewriteValueContainer(parent AST, node ValueContainer, rep if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for _, el := range node.ASTElements { @@ -1149,7 +1161,10 @@ func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceCont if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index b45b6950839..020a777248c 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -178,4 +178,5 @@ var errAbort = fmt.Errorf("this error is to abort the rewriter, it is not an act type application struct { pre, post ApplyFunc + cur Cursor } diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 56cf2dcffe1..3d74570ff89 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -89,17 +89,13 @@ func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generato return nil } - stmts := []jen.Code{ - jen.Var().Id("err").Error(), - createCursor(), - executePre(), - } + stmts := []jen.Code{jen.Var().Id("err").Error()} + stmts = append(stmts, executePre()...) stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, true)...) - stmts = append(stmts, - jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err())), - executePost(), - returnNil(), - ) + stmts = append(stmts, jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err()))) + stmts = append(stmts, executePost()...) + stmts = append(stmts, returnNil()) + e.rewriteFunc(t, stmts, spi) return nil @@ -110,49 +106,26 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen return nil } - stmts := []jen.Code{ - /* - if node == nil { return nil } - */ - jen.If(jen.Id("node == nil").Block(returnNil())), - - /* - cur := Cursor{ - parent: parent, - replacer: replacer, - node: node, - } - */ - createCursor(), - - /* - if !pre(&cur) { - return nil - } - */ - executePre(), - } + /* + if node == nil { return nil } + */ + stmts := []jen.Code{jen.If(jen.Id("node == nil").Block(returnNil()))} + /* + if !pre(&cur) { + return nil + } + */ + stmts = append(stmts, executePre()...) stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, false)...) + stmts = append(stmts, executePost()...) + stmts = append(stmts, returnNil()) - stmts = append(stmts, - executePost(), - returnNil(), - ) e.rewriteFunc(t, stmts, spi) return nil } -func createCursor() *jen.Statement { - return jen.Id("cur := Cursor").Values( - jen.Dict{ - jen.Id("parent"): jen.Id("parent"), - jen.Id("replacer"): jen.Id("replacer"), - jen.Id("node"): jen.Id("node"), - }) -} - func (e rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { if !shouldAdd(t, spi.iface()) { return nil @@ -189,9 +162,8 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS */ stmts := []jen.Code{ jen.If(jen.Id("node == nil").Block(returnNil())), - createCursor(), - executePre(), } + stmts = append(stmts, executePre()...) if shouldAdd(slice.Elem(), spi.iface()) { /* @@ -208,27 +180,29 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS Block(e.rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) } - stmts = append(stmts, - /* - if !post(&cur) { - return errAbort - } - return nil + stmts = append(stmts, executePost()...) + stmts = append(stmts, returnNil()) - */ - executePost(), - returnNil(), - ) e.rewriteFunc(t, stmts, spi) return nil } -func executePre() *jen.Statement { - return jen.If(jen.Id("a.pre!= nil && !a.pre(&cur)")).Block(returnNil()) +func executePre() []jen.Code { + return []jen.Code{ + jen.Id("a.cur.replacer = replacer"), + jen.Id("a.cur.parent = parent"), + jen.Id("a.cur.node = node"), + jen.If(jen.Id("a.pre!= nil && !a.pre(&a.cur)")).Block(returnNil()), + } } -func executePost() *jen.Statement { - return jen.If(jen.Id("a.post != nil && !a.post(&cur)")).Block(jen.Return(jen.Id(abort))) +func executePost() []jen.Code { + return []jen.Code{ + jen.Id("a.cur.replacer = replacer"), + jen.Id("a.cur.parent = parent"), + jen.Id("a.cur.node = node"), + jen.If(jen.Id("a.post != nil && !a.post(&a.cur)")).Block(jen.Return(jen.Id(abort))), + } } func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { @@ -236,12 +210,9 @@ func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) return nil } - stmts := []jen.Code{ - createCursor(), - executePre(), - executePost(), - returnNil(), - } + stmts := executePre() + stmts = append(stmts, executePost()...) + stmts = append(stmts, returnNil()) e.rewriteFunc(t, stmts, spi) return nil diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 1f45bd76d86..2cb85eaeda6 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -9297,29 +9297,31 @@ func VisitVindexParam(in VindexParam, f Visit) error { return nil } func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9373,29 +9375,31 @@ func (a *application) rewriteAlterOption(parent SQLNode, node AlterOption, repla } } func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9416,18 +9420,19 @@ func (a *application) rewriteCharacteristic(parent SQLNode, node Characteristic, } func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9452,12 +9457,10 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -9467,7 +9470,10 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9476,15 +9482,16 @@ func (a *application) rewriteComments(parent SQLNode, node Comments, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9635,12 +9642,10 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -9650,7 +9655,10 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9659,12 +9667,10 @@ func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -9674,7 +9680,10 @@ func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer repl return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9698,27 +9707,26 @@ func (a *application) rewriteInsertRows(parent SQLNode, node InsertRows, replace } } func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { @@ -9734,7 +9742,10 @@ func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, r if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9743,15 +9754,16 @@ func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9760,12 +9772,10 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -9775,7 +9785,10 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9784,12 +9797,10 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -9799,7 +9810,10 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9808,12 +9822,10 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -9823,7 +9835,10 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9832,12 +9847,10 @@ func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node.Columns { @@ -9857,7 +9870,10 @@ func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9866,12 +9882,10 @@ func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node * if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { @@ -9879,7 +9893,10 @@ func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node * }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9888,12 +9905,10 @@ func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIn if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { @@ -9901,7 +9916,10 @@ func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIn }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9910,12 +9928,10 @@ func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -9928,7 +9944,10 @@ func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9937,12 +9956,10 @@ func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *Aliased if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -9965,7 +9982,10 @@ func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *Aliased }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9974,15 +9994,16 @@ func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharse if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -9991,12 +10012,10 @@ func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { @@ -10009,7 +10028,10 @@ func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10018,15 +10040,16 @@ func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatab if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10035,15 +10058,16 @@ func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigr if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10052,12 +10076,10 @@ func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -10077,7 +10099,10 @@ func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10086,12 +10111,10 @@ func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { @@ -10109,7 +10132,10 @@ func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10118,12 +10144,10 @@ func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschem if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -10148,7 +10172,10 @@ func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschem }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10157,12 +10184,10 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -10175,7 +10200,10 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10184,12 +10212,10 @@ func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { @@ -10202,7 +10228,10 @@ func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10211,15 +10240,16 @@ func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10228,12 +10258,10 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -10246,7 +10274,10 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10255,12 +10286,10 @@ func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { @@ -10273,7 +10302,10 @@ func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, repla }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10282,12 +10314,10 @@ func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -10307,7 +10337,10 @@ func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, repla }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10316,12 +10349,10 @@ func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColum if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { @@ -10344,7 +10375,10 @@ func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColum }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10353,12 +10387,10 @@ func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -10366,7 +10398,10 @@ func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10375,15 +10410,16 @@ func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10392,12 +10428,10 @@ func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -10410,7 +10444,10 @@ func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10419,12 +10456,10 @@ func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -10432,7 +10467,10 @@ func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10441,12 +10479,10 @@ func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnD if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -10454,7 +10490,10 @@ func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnD }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10463,12 +10502,10 @@ func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { @@ -10481,7 +10518,10 @@ func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10490,15 +10530,16 @@ func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10507,12 +10548,10 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -10530,7 +10569,10 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10539,12 +10581,10 @@ func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *Con if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { @@ -10552,7 +10592,10 @@ func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *Con }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10561,12 +10604,10 @@ func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -10579,7 +10620,10 @@ func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10588,12 +10632,10 @@ func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { @@ -10606,7 +10648,10 @@ func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10615,12 +10660,10 @@ func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *Convert if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -10628,7 +10671,10 @@ func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *Convert }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10637,12 +10683,10 @@ func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDat if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -10650,7 +10694,10 @@ func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDat }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10659,12 +10706,10 @@ func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -10682,7 +10727,10 @@ func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10691,12 +10739,10 @@ func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { @@ -10714,7 +10760,10 @@ func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10723,12 +10772,10 @@ func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeF if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -10741,7 +10788,10 @@ func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeF }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10750,15 +10800,16 @@ func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10767,12 +10818,10 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -10810,7 +10859,10 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10819,12 +10871,10 @@ func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTabl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -10832,7 +10882,10 @@ func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTabl }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10841,12 +10894,10 @@ func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -10854,7 +10905,10 @@ func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10863,12 +10917,10 @@ func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabas if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -10876,7 +10928,10 @@ func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabas }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10885,15 +10940,16 @@ func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10902,12 +10958,10 @@ func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { @@ -10915,7 +10969,10 @@ func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10924,12 +10981,10 @@ func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { @@ -10937,7 +10992,10 @@ func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, repla }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10946,12 +11004,10 @@ func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { @@ -10959,7 +11015,10 @@ func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10968,12 +11027,10 @@ func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { @@ -10981,7 +11038,10 @@ func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -10990,12 +11050,10 @@ func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -11003,7 +11061,10 @@ func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11012,12 +11073,10 @@ func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { @@ -11025,7 +11084,10 @@ func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer re }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11034,15 +11096,16 @@ func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11051,12 +11114,10 @@ func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *For if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { @@ -11084,7 +11145,10 @@ func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *For }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11093,12 +11157,10 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { @@ -11116,7 +11178,10 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11125,12 +11190,10 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { @@ -11148,7 +11211,10 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11157,12 +11223,10 @@ func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDef if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { @@ -11170,7 +11234,10 @@ func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDef }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11179,12 +11246,10 @@ func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node.Indexes { @@ -11194,7 +11259,10 @@ func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, r return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11203,12 +11271,10 @@ func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -11221,7 +11287,10 @@ func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11230,12 +11299,10 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -11268,7 +11335,10 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11277,12 +11347,10 @@ func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExp if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -11290,7 +11358,10 @@ func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExp }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11299,12 +11370,10 @@ func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -11312,7 +11381,10 @@ func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11321,12 +11393,10 @@ func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondit if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { @@ -11339,7 +11409,10 @@ func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondit }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11348,12 +11421,10 @@ func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableE if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { @@ -11371,7 +11442,10 @@ func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableE }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11380,15 +11454,16 @@ func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11397,12 +11472,10 @@ func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { @@ -11415,7 +11488,10 @@ func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer re }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11424,15 +11500,16 @@ func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11441,15 +11518,16 @@ func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11458,15 +11536,16 @@ func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11475,15 +11554,16 @@ func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11492,12 +11572,10 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { @@ -11510,7 +11588,10 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11519,12 +11600,10 @@ func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColum if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { @@ -11542,7 +11621,10 @@ func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColum }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11551,12 +11633,10 @@ func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -11564,7 +11644,10 @@ func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11573,12 +11656,10 @@ func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -11586,7 +11667,10 @@ func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11595,15 +11679,16 @@ func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11612,12 +11697,10 @@ func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { @@ -11625,7 +11708,10 @@ func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11634,12 +11720,10 @@ func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -11652,7 +11736,10 @@ func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11661,12 +11748,10 @@ func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -11674,7 +11759,10 @@ func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer re }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11683,12 +11771,10 @@ func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOpt if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { @@ -11696,7 +11782,10 @@ func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOpt }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11705,15 +11794,16 @@ func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11722,15 +11812,16 @@ func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11739,12 +11830,10 @@ func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -11752,7 +11841,10 @@ func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11761,12 +11853,10 @@ func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTabl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { @@ -11774,7 +11864,10 @@ func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTabl }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11783,12 +11876,10 @@ func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *Part if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -11801,7 +11892,10 @@ func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *Part }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11810,12 +11904,10 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { @@ -11840,7 +11932,10 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11849,12 +11944,10 @@ func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -11872,7 +11965,10 @@ func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11881,12 +11977,10 @@ func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -11894,7 +11988,10 @@ func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11903,15 +12000,16 @@ func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11920,15 +12018,16 @@ func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11937,12 +12036,10 @@ func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTa if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -11950,7 +12047,10 @@ func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTa }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11959,15 +12059,16 @@ func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMi if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11976,15 +12077,16 @@ func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -11993,12 +12095,10 @@ func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -12006,7 +12106,10 @@ func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12015,12 +12118,10 @@ func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -12028,7 +12129,10 @@ func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12037,12 +12141,10 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -12090,7 +12192,10 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12099,15 +12204,16 @@ func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12116,12 +12222,10 @@ func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replac if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -12134,7 +12238,10 @@ func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replac }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12143,12 +12250,10 @@ func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -12161,7 +12266,10 @@ func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12170,12 +12278,10 @@ func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransa if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { @@ -12195,7 +12301,10 @@ func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransa return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12204,12 +12313,10 @@ func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { @@ -12217,7 +12324,10 @@ func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer repl }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12226,12 +12336,10 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { @@ -12244,7 +12352,10 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12253,12 +12364,10 @@ func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { @@ -12266,7 +12375,10 @@ func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12275,12 +12387,10 @@ func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { @@ -12288,7 +12398,10 @@ func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12297,12 +12410,10 @@ func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { @@ -12320,7 +12431,10 @@ func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12329,12 +12443,10 @@ func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { @@ -12342,7 +12454,10 @@ func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, repla }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12351,12 +12466,10 @@ func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -12374,7 +12487,10 @@ func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12383,12 +12499,10 @@ func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -12396,7 +12510,10 @@ func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, repla }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12405,12 +12522,10 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -12433,7 +12548,10 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12442,15 +12560,16 @@ func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12459,12 +12578,10 @@ func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -12477,7 +12594,10 @@ func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12486,12 +12606,10 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node.Columns { @@ -12520,7 +12638,10 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12529,15 +12650,16 @@ func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *Tabl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12546,12 +12668,10 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { @@ -12564,7 +12684,10 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12573,12 +12696,10 @@ func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTa if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -12586,7 +12707,10 @@ func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTa }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12595,12 +12719,10 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -12608,7 +12730,10 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12617,12 +12742,10 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { @@ -12647,7 +12770,10 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12656,12 +12782,10 @@ func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { @@ -12669,7 +12793,10 @@ func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12678,15 +12805,16 @@ func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTable if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12695,12 +12823,10 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -12733,7 +12859,10 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12742,12 +12871,10 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -12760,7 +12887,10 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12769,12 +12899,10 @@ func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replac if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { @@ -12782,7 +12910,10 @@ func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replac }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12791,12 +12922,10 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -12824,7 +12953,10 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12833,15 +12965,16 @@ func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12850,12 +12983,10 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { @@ -12863,7 +12994,10 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12872,12 +13006,10 @@ func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { @@ -12885,7 +13017,10 @@ func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12894,12 +13029,10 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -12919,7 +13052,10 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12928,12 +13064,10 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { @@ -12946,7 +13080,10 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12955,12 +13092,10 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -12968,7 +13103,10 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -12977,12 +13115,10 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -12995,21 +13131,25 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace }); errF != nil { return errF } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc) error { - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13336,12 +13476,10 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13351,7 +13489,10 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13376,12 +13517,10 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13391,7 +13530,10 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13538,12 +13680,10 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13553,37 +13693,39 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { @@ -13599,7 +13741,10 @@ func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13608,12 +13753,10 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13623,7 +13766,10 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13632,15 +13778,16 @@ func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, rep if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13649,12 +13796,10 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13664,7 +13809,10 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13673,12 +13821,10 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13688,7 +13834,10 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil @@ -13697,12 +13846,10 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac if node == nil { return nil } - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } for i, el := range node { @@ -13712,19 +13859,20 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac return errF } } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc) error { var err error - cur := Cursor{ - node: node, - parent: parent, - replacer: replacer, - } - if a.pre != nil && !a.pre(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.pre != nil && !a.pre(&a.cur) { return nil } if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { @@ -13735,7 +13883,10 @@ func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, repla if err != nil { return err } - if a.post != nil && !a.post(&cur) { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if a.post != nil && !a.post(&a.cur) { return errAbort } return nil diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index bd8dd1efbff..cd4f3d957dc 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -96,5 +96,5 @@ type replacerFunc func(newNode, parent SQLNode) // application carries all the shared data so we can pass it around cheaply. type application struct { pre, post ApplyFunc - cursor Cursor + cur Cursor } diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index 386e8736eac..e0a05cc1bdb 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -40,7 +40,7 @@ func BenchmarkWalkLargeExpression(b *testing.B) { } func BenchmarkRewriteLargeExpression(b *testing.B) { - for i := 0; i < 10; i++ { + for i := 0; i < 2; i++ { b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { exp := newGenerator(int64(i*100), 5).expression() count := 0 From 428b572071b27d25f713912cf6d9952ea143e78e Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Sat, 20 Mar 2021 14:57:03 +0530 Subject: [PATCH 12/15] add cursor initialization only when required Signed-off-by: Harshit Gangal --- .../asthelpergen/integration/ast_helper.go | 306 +- go/tools/asthelpergen/rewrite_gen.go | 46 +- go/vt/sqlparser/ast_helper.go | 3293 ++++++++++------- go/vt/sqlparser/walker_test.go | 2 +- 4 files changed, 2051 insertions(+), 1596 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index d2775185d0b..5adae62f511 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -797,15 +797,14 @@ func (a *application) rewriteAST(parent AST, node AST, replacer replacerFunc) er } } func (a *application) rewriteBasicType(parent AST, node BasicType, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -815,15 +814,14 @@ func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -831,18 +829,17 @@ func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc } func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -852,11 +849,13 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteAST(node, el, func(newNode, parent AST) { @@ -865,11 +864,13 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -877,11 +878,13 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { @@ -890,11 +893,13 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -902,15 +907,14 @@ func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *Interface if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -920,15 +924,14 @@ func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -938,15 +941,14 @@ func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -956,11 +958,13 @@ func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*RefContainer).ASTType = newNode.(AST) @@ -972,11 +976,13 @@ func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -984,11 +990,13 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node.ASTElements { if errF := a.rewriteAST(node, el, func(newNode, parent AST) { @@ -1004,11 +1012,13 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -1016,22 +1026,26 @@ func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSubIface(node, node.inner, func(newNode, parent AST) { parent.(*SubImpl).inner = newNode.(SubIface) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -1039,11 +1053,13 @@ func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContaine if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { parent.(*ValueContainer).ASTType = newNode.(AST) @@ -1055,11 +1071,13 @@ func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContaine }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -1067,11 +1085,13 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node.ASTElements { if errF := a.rewriteAST(node, el, func(newNode, parent AST) { @@ -1087,11 +1107,13 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -1109,11 +1131,13 @@ func (a *application) rewriteSubIface(parent AST, node SubIface, replacer replac } func (a *application) rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTType' on 'ValueContainer'") @@ -1128,21 +1152,25 @@ func (a *application) rewriteValueContainer(parent AST, node ValueContainer, rep if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for _, el := range node.ASTElements { if errF := a.rewriteAST(node, el, func(newNode, parent AST) { @@ -1161,11 +1189,13 @@ func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceCont if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 3d74570ff89..339861a2cc9 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -88,12 +88,13 @@ func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generato if !shouldAdd(t, spi.iface()) { return nil } + fields := e.rewriteAllStructFields(t, strct, spi, true) stmts := []jen.Code{jen.Var().Id("err").Error()} - stmts = append(stmts, executePre()...) - stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, true)...) + stmts = append(stmts, executePre()) + stmts = append(stmts, fields...) stmts = append(stmts, jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err()))) - stmts = append(stmts, executePost()...) + stmts = append(stmts, executePost(len(fields) > 0)) stmts = append(stmts, returnNil()) e.rewriteFunc(t, stmts, spi) @@ -116,9 +117,10 @@ func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen return nil } */ - stmts = append(stmts, executePre()...) - stmts = append(stmts, e.rewriteAllStructFields(t, strct, spi, false)...) - stmts = append(stmts, executePost()...) + stmts = append(stmts, executePre()) + fields := e.rewriteAllStructFields(t, strct, spi, false) + stmts = append(stmts, fields...) + stmts = append(stmts, executePost(len(fields) > 0)) stmts = append(stmts, returnNil()) e.rewriteFunc(t, stmts, spi) @@ -163,8 +165,9 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS stmts := []jen.Code{ jen.If(jen.Id("node == nil").Block(returnNil())), } - stmts = append(stmts, executePre()...) + stmts = append(stmts, executePre()) + addCur := false if shouldAdd(slice.Elem(), spi.iface()) { /* for i, el := range node { @@ -175,34 +178,40 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS } } */ + addCur = true stmts = append(stmts, jen.For(jen.Id("i, el").Op(":=").Id("range node")). Block(e.rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) } - stmts = append(stmts, executePost()...) + stmts = append(stmts, executePost(addCur)) stmts = append(stmts, returnNil()) e.rewriteFunc(t, stmts, spi) return nil } -func executePre() []jen.Code { +func setupCursor() []jen.Code { return []jen.Code{ jen.Id("a.cur.replacer = replacer"), jen.Id("a.cur.parent = parent"), jen.Id("a.cur.node = node"), - jen.If(jen.Id("a.pre!= nil && !a.pre(&a.cur)")).Block(returnNil()), } } +func executePre() jen.Code { + curStmts := setupCursor() + curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnNil())) + return jen.If(jen.Id("a.pre!= nil").Block(curStmts...)) +} -func executePost() []jen.Code { - return []jen.Code{ - jen.Id("a.cur.replacer = replacer"), - jen.Id("a.cur.parent = parent"), - jen.Id("a.cur.node = node"), - jen.If(jen.Id("a.post != nil && !a.post(&a.cur)")).Block(jen.Return(jen.Id(abort))), +func executePost(addCur bool) jen.Code { + if addCur { + curStmts := setupCursor() + curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(returnNil())) + return jen.If(jen.Id("a.post!= nil").Block(curStmts...)) } + + return jen.If(jen.Id("a.post != nil && !a.post(&a.cur)")).Block(jen.Return(jen.Id(abort))) } func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { @@ -210,10 +219,7 @@ func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) return nil } - stmts := executePre() - stmts = append(stmts, executePost()...) - stmts = append(stmts, returnNil()) - + stmts := []jen.Code{executePre(), executePost(false), returnNil()} e.rewriteFunc(t, stmts, spi) return nil } diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 2cb85eaeda6..06c70036160 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -9297,30 +9297,28 @@ func VisitVindexParam(in VindexParam, f Visit) error { return nil } func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -9375,30 +9373,28 @@ func (a *application) rewriteAlterOption(parent SQLNode, node AlterOption, repla } } func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } return nil } func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -9420,18 +9416,17 @@ func (a *application) rewriteCharacteristic(parent SQLNode, node Characteristic, } func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -9457,11 +9452,13 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { @@ -9470,11 +9467,13 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9482,15 +9481,14 @@ func (a *application) rewriteComments(parent SQLNode, node Comments, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -9642,11 +9640,13 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { @@ -9655,11 +9655,13 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9667,11 +9669,13 @@ func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { @@ -9680,11 +9684,13 @@ func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer repl return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9707,15 +9713,14 @@ func (a *application) rewriteInsertRows(parent SQLNode, node InsertRows, replace } } func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -9723,11 +9728,13 @@ func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, } func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'On' on 'JoinCondition'") @@ -9742,11 +9749,13 @@ func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, r if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9754,15 +9763,14 @@ func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -9772,11 +9780,13 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { @@ -9785,11 +9795,13 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9797,11 +9809,13 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { @@ -9810,11 +9824,13 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9822,11 +9838,13 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { @@ -9835,11 +9853,13 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9847,11 +9867,13 @@ func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node.Columns { if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { @@ -9870,11 +9892,13 @@ func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9882,22 +9906,26 @@ func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node * if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9905,22 +9933,26 @@ func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIn if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9928,11 +9960,13 @@ func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*AliasedExpr).Expr = newNode.(Expr) @@ -9944,11 +9978,13 @@ func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9956,11 +9992,13 @@ func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *Aliased if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) @@ -9982,11 +10020,13 @@ func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *Aliased }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -9994,15 +10034,14 @@ func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharse if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10012,11 +10051,13 @@ func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { parent.(*AlterColumn).Column = newNode.(*ColName) @@ -10028,11 +10069,13 @@ func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10040,15 +10083,14 @@ func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatab if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10058,15 +10100,14 @@ func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigr if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10076,11 +10117,13 @@ func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*AlterTable).Table = newNode.(TableName) @@ -10099,11 +10142,13 @@ func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10111,11 +10156,13 @@ func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { parent.(*AlterView).ViewName = newNode.(TableName) @@ -10132,11 +10179,13 @@ func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10144,11 +10193,13 @@ func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschem if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*AlterVschema).Table = newNode.(TableName) @@ -10172,11 +10223,13 @@ func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschem }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10184,11 +10237,13 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*AndExpr).Left = newNode.(Expr) @@ -10200,11 +10255,13 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10212,11 +10269,13 @@ func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { parent.(*AutoIncSpec).Column = newNode.(ColIdent) @@ -10228,11 +10287,13 @@ func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10240,15 +10301,14 @@ func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10258,11 +10318,13 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*BinaryExpr).Left = newNode.(Expr) @@ -10274,11 +10336,13 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10286,11 +10350,13 @@ func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { parent.(*CallProc).Name = newNode.(TableName) @@ -10302,11 +10368,13 @@ func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, repla }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10314,11 +10382,13 @@ func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*CaseExpr).Expr = newNode.(Expr) @@ -10337,11 +10407,13 @@ func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, repla }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10349,11 +10421,13 @@ func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColum if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { parent.(*ChangeColumn).OldColumn = newNode.(*ColName) @@ -10375,11 +10449,13 @@ func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColum }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10387,22 +10463,26 @@ func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10410,15 +10490,14 @@ func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10428,11 +10507,13 @@ func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*ColName).Name = newNode.(ColIdent) @@ -10444,11 +10525,13 @@ func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replace }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10456,22 +10539,26 @@ func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*CollateExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10479,22 +10566,26 @@ func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnD if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*ColumnDefinition).Name = newNode.(ColIdent) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10502,11 +10593,13 @@ func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { parent.(*ColumnType).Length = newNode.(*Literal) @@ -10518,11 +10611,13 @@ func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10530,15 +10625,14 @@ func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10548,11 +10642,13 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*ComparisonExpr).Left = newNode.(Expr) @@ -10569,11 +10665,13 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10581,22 +10679,26 @@ func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *Con if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10604,11 +10706,13 @@ func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*ConvertExpr).Expr = newNode.(Expr) @@ -10620,11 +10724,13 @@ func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10632,11 +10738,13 @@ func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { parent.(*ConvertType).Length = newNode.(*Literal) @@ -10648,11 +10756,13 @@ func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10660,22 +10770,26 @@ func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *Convert if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*ConvertUsingExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10683,22 +10797,26 @@ func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDat if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*CreateDatabase).Comments = newNode.(Comments) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10706,11 +10824,13 @@ func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*CreateTable).Table = newNode.(TableName) @@ -10727,11 +10847,13 @@ func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10739,11 +10861,13 @@ func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { parent.(*CreateView).ViewName = newNode.(TableName) @@ -10760,11 +10884,13 @@ func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10772,11 +10898,13 @@ func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeF if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) @@ -10788,11 +10916,13 @@ func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeF }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10800,15 +10930,14 @@ func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10818,11 +10947,13 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Delete).Comments = newNode.(Comments) @@ -10859,11 +10990,13 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10871,22 +11004,26 @@ func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTabl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*DerivedTable).Select = newNode.(SelectStatement) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10894,22 +11031,26 @@ func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*DropColumn).Name = newNode.(*ColName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10917,22 +11058,26 @@ func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabas if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*DropDatabase).Comments = newNode.(Comments) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10940,15 +11085,14 @@ func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -10958,22 +11102,26 @@ func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { parent.(*DropTable).FromTables = newNode.(TableNames) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -10981,22 +11129,26 @@ func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { parent.(*DropView).FromTables = newNode.(TableNames) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11004,22 +11156,26 @@ func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { parent.(*ExistsExpr).Subquery = newNode.(*Subquery) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11027,22 +11183,26 @@ func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { parent.(*ExplainStmt).Statement = newNode.(Statement) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11050,22 +11210,26 @@ func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*ExplainTab).Table = newNode.(TableName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11073,22 +11237,26 @@ func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { parent.(*Flush).TableNames = newNode.(TableNames) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11096,15 +11264,14 @@ func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11114,11 +11281,13 @@ func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *For if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { parent.(*ForeignKeyDefinition).Source = newNode.(Columns) @@ -11145,11 +11314,13 @@ func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *For }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11157,11 +11328,13 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { parent.(*FuncExpr).Qualifier = newNode.(TableIdent) @@ -11178,11 +11351,13 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11190,11 +11365,13 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) @@ -11211,11 +11388,13 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11223,22 +11402,26 @@ func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDef if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { parent.(*IndexDefinition).Info = newNode.(*IndexInfo) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11246,11 +11429,13 @@ func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node.Indexes { if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { @@ -11259,11 +11444,13 @@ func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, r return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11271,11 +11458,13 @@ func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*IndexInfo).Name = newNode.(ColIdent) @@ -11287,11 +11476,13 @@ func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11299,11 +11490,13 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Insert).Comments = newNode.(Comments) @@ -11335,11 +11528,13 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11347,22 +11542,26 @@ func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExp if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*IntervalExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11370,22 +11569,26 @@ func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*IsExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11393,11 +11596,13 @@ func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondit if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { parent.(*JoinCondition).On = newNode.(Expr) @@ -11409,11 +11614,13 @@ func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondit }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11421,11 +11628,13 @@ func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableE if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) @@ -11442,11 +11651,13 @@ func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableE }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11454,15 +11665,14 @@ func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11472,11 +11682,13 @@ func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { parent.(*Limit).Offset = newNode.(Expr) @@ -11488,11 +11700,13 @@ func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer re }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11500,15 +11714,14 @@ func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11518,15 +11731,14 @@ func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11536,15 +11748,14 @@ func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11554,15 +11765,14 @@ func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11572,11 +11782,13 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { parent.(*MatchExpr).Columns = newNode.(SelectExprs) @@ -11588,11 +11800,13 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11600,11 +11814,13 @@ func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColum if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) @@ -11621,11 +11837,13 @@ func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColum }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11633,22 +11851,26 @@ func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*Nextval).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11656,22 +11878,26 @@ func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*NotExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11679,15 +11905,14 @@ func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11697,22 +11922,26 @@ func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { parent.(*OptLike).LikeTable = newNode.(TableName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11720,11 +11949,13 @@ func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*OrExpr).Left = newNode.(Expr) @@ -11736,11 +11967,13 @@ func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11748,22 +11981,26 @@ func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*Order).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11771,22 +12008,26 @@ func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOpt if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { parent.(*OrderByOption).Cols = newNode.(Columns) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11794,15 +12035,14 @@ func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11812,15 +12052,14 @@ func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -11830,22 +12069,26 @@ func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*ParenSelect).Select = newNode.(SelectStatement) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11853,34 +12096,40 @@ func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTabl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort - } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } + } return nil } func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc) error { if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*PartitionDefinition).Name = newNode.(ColIdent) @@ -11892,11 +12141,13 @@ func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *Part }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11904,11 +12155,13 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { parent.(*PartitionSpec).Names = newNode.(Partitions) @@ -11932,11 +12185,13 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11944,11 +12199,13 @@ func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*RangeCond).Left = newNode.(Expr) @@ -11965,11 +12222,13 @@ func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -11977,22 +12236,26 @@ func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*Release).Name = newNode.(ColIdent) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12000,15 +12263,14 @@ func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12018,15 +12280,14 @@ func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12036,22 +12297,26 @@ func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTa if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*RenameTableName).Table = newNode.(TableName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12059,15 +12324,14 @@ func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMi if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12077,15 +12341,14 @@ func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12095,22 +12358,26 @@ func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*SRollback).Name = newNode.(ColIdent) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12118,22 +12385,26 @@ func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*Savepoint).Name = newNode.(ColIdent) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12141,11 +12412,13 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Select).Comments = newNode.(Comments) @@ -12192,11 +12465,13 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12204,15 +12479,14 @@ func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12222,11 +12496,13 @@ func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replac if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Set).Comments = newNode.(Comments) @@ -12238,11 +12514,13 @@ func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replac }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12250,11 +12528,13 @@ func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*SetExpr).Name = newNode.(ColIdent) @@ -12266,11 +12546,13 @@ func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replace }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12278,11 +12560,13 @@ func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransa if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { parent.(*SetTransaction).SQLNode = newNode.(SQLNode) @@ -12301,11 +12585,13 @@ func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransa return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12313,22 +12599,26 @@ func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { parent.(*Show).Internal = newNode.(ShowInternal) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12336,11 +12626,13 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { parent.(*ShowBasic).Tbl = newNode.(TableName) @@ -12352,11 +12644,13 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12364,22 +12658,26 @@ func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { parent.(*ShowCreate).Op = newNode.(TableName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12387,22 +12685,26 @@ func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { parent.(*ShowFilter).Filter = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12410,11 +12712,13 @@ func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { parent.(*ShowLegacy).OnTable = newNode.(TableName) @@ -12431,11 +12735,13 @@ func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12443,22 +12749,26 @@ func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { parent.(*StarExpr).TableName = newNode.(TableName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12466,11 +12776,13 @@ func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Stream).Comments = newNode.(Comments) @@ -12487,11 +12799,13 @@ func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12499,22 +12813,26 @@ func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { parent.(*Subquery).Select = newNode.(SelectStatement) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12522,11 +12840,13 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*SubstrExpr).Name = newNode.(*ColName) @@ -12548,11 +12868,13 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12560,15 +12882,14 @@ func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12578,11 +12899,13 @@ func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*TableName).Name = newNode.(TableIdent) @@ -12594,11 +12917,13 @@ func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12606,11 +12931,13 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node.Columns { if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { @@ -12638,11 +12965,13 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12650,15 +12979,14 @@ func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *Tabl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12668,11 +12996,13 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) @@ -12684,11 +13014,13 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12696,22 +13028,26 @@ func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTa if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { parent.(*TruncateTable).Table = newNode.(TableName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12719,22 +13055,26 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*UnaryExpr).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12742,11 +13082,13 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { parent.(*Union).FirstStatement = newNode.(SelectStatement) @@ -12770,11 +13112,13 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12782,22 +13126,26 @@ func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { parent.(*UnionSelect).Statement = newNode.(SelectStatement) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12805,15 +13153,14 @@ func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTable if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12823,11 +13170,13 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*Update).Comments = newNode.(Comments) @@ -12859,11 +13208,13 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12871,11 +13222,13 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*UpdateExpr).Name = newNode.(*ColName) @@ -12887,11 +13240,13 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12899,22 +13254,26 @@ func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replac if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { parent.(*Use).DBName = newNode.(TableIdent) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12922,11 +13281,13 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { parent.(*VStream).Comments = newNode.(Comments) @@ -12953,11 +13314,13 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -12965,15 +13328,14 @@ func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -12983,22 +13345,26 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { parent.(*ValuesFuncExpr).Name = newNode.(*ColName) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13006,22 +13372,26 @@ func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { parent.(*VindexParam).Key = newNode.(ColIdent) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13029,11 +13399,13 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { parent.(*VindexSpec).Name = newNode.(ColIdent) @@ -13052,11 +13424,13 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13064,11 +13438,13 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { parent.(*When).Cond = newNode.(Expr) @@ -13080,11 +13456,13 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13092,22 +13470,26 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { parent.(*Where).Expr = newNode.(Expr) }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13115,11 +13497,13 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { parent.(*XorExpr).Left = newNode.(Expr) @@ -13131,24 +13515,25 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace }); errF != nil { return errF } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc) error { - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -13476,11 +13861,13 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { @@ -13489,11 +13876,13 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13517,11 +13906,13 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { @@ -13530,11 +13921,13 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13680,11 +14073,13 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteTableExpr(node, el, func(newNode, parent SQLNode) { @@ -13693,28 +14088,29 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -13722,11 +14118,13 @@ func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replace } func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Name' on 'TableName'") @@ -13741,11 +14139,13 @@ func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13753,11 +14153,13 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteTableName(node, el, func(newNode, parent SQLNode) { @@ -13766,11 +14168,13 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13778,15 +14182,14 @@ func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, rep if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node if a.post != nil && !a.post(&a.cur) { return errAbort } @@ -13796,11 +14199,13 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { @@ -13809,11 +14214,13 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13821,11 +14228,13 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { @@ -13834,11 +14243,13 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } @@ -13846,11 +14257,13 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac if node == nil { return nil } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } for i, el := range node { if errF := a.rewriteValTuple(node, el, func(newNode, parent SQLNode) { @@ -13859,21 +14272,25 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac return errF } } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc) error { var err error - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.pre != nil && !a.pre(&a.cur) { - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Key' on 'VindexParam'") @@ -13883,11 +14300,13 @@ func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, repla if err != nil { return err } - a.cur.replacer = replacer - a.cur.parent = parent - a.cur.node = node - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return nil + } } return nil } diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index e0a05cc1bdb..386e8736eac 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -40,7 +40,7 @@ func BenchmarkWalkLargeExpression(b *testing.B) { } func BenchmarkRewriteLargeExpression(b *testing.B) { - for i := 0; i < 2; i++ { + for i := 0; i < 10; i++ { b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { exp := newGenerator(int64(i*100), 5).expression() count := 0 From 4d5cdf77170a2303d97352e2a0deb467ec8107c1 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 20 Mar 2021 11:40:59 +0100 Subject: [PATCH 13/15] speed up benchmark Signed-off-by: Andres Taylor --- go/vt/sqlparser/walker_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index 386e8736eac..04bc864d68f 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -52,7 +52,9 @@ func BenchmarkRewriteLargeExpression(b *testing.B) { count-- return true }) - require.NoError(b, err) + if err != nil { + b.Fatal(err) + } } }) } From 35cddfe214c60c3764b4c48ab348b17ede49a2ca Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 20 Mar 2021 11:43:34 +0100 Subject: [PATCH 14/15] update make target Signed-off-by: Andres Taylor --- .github/workflows/check_make_visitor.yml | 2 +- Makefile | 5 ++++- go/vt/sqlparser/walker_test.go | 4 ++-- misc/git/hooks/{visitorgen => asthelpers} | 0 4 files changed, 7 insertions(+), 4 deletions(-) rename misc/git/hooks/{visitorgen => asthelpers} (100%) diff --git a/.github/workflows/check_make_visitor.yml b/.github/workflows/check_make_visitor.yml index 219e1ece2dc..a23a8680a2e 100644 --- a/.github/workflows/check_make_visitor.yml +++ b/.github/workflows/check_make_visitor.yml @@ -31,5 +31,5 @@ jobs: - name: check_make_visitor run: | - misc/git/hooks/visitorgen + misc/git/hooks/asthelpers diff --git a/Makefile b/Makefile index cecb52b0e21..e83727c811d 100644 --- a/Makefile +++ b/Makefile @@ -103,6 +103,10 @@ parser: make -C go/vt/sqlparser visitor: + >&2 echo "make visitor has been replaced by make asthelpers" + exit 1 + +asthelpers: go run ./go/tools/asthelpergen/main -in ./go/vt/sqlparser -iface vitess.io/vitess/go/vt/sqlparser.SQLNode -except "*ColName" sizegen: @@ -123,7 +127,6 @@ clean: go clean -i ./go/... rm -rf third_party/acolyte rm -rf go/vt/.proto.tmp - rm -rf ./visitorgen # Remove everything including stuff pulled down by bootstrap.sh cleanall: clean diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index 04bc864d68f..ec7727a0832 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -40,9 +40,9 @@ func BenchmarkWalkLargeExpression(b *testing.B) { } func BenchmarkRewriteLargeExpression(b *testing.B) { - for i := 0; i < 10; i++ { + for i := 1; i < 7; i++ { b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { - exp := newGenerator(int64(i*100), 5).expression() + exp := newGenerator(int64(i*100), i).expression() count := 0 for i := 0; i < b.N; i++ { _, err := Rewrite(exp, func(_ *Cursor) bool { diff --git a/misc/git/hooks/visitorgen b/misc/git/hooks/asthelpers similarity index 100% rename from misc/git/hooks/visitorgen rename to misc/git/hooks/asthelpers From deac101f7420fc86534013bd68c1bc5d04d8f7bb Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 20 Mar 2021 13:40:14 +0100 Subject: [PATCH 15/15] setup the cursor for post visit Signed-off-by: Andres Taylor --- .../asthelpergen/integration/ast_helper.go | 84 ++- .../integration/integration_rewriter_test.go | 30 +- go/tools/asthelpergen/rewrite_gen.go | 23 +- go/vt/sqlparser/ast_helper.go | 635 ++++++++++++------ 4 files changed, 553 insertions(+), 219 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index 5adae62f511..63ee62f998d 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -805,8 +805,15 @@ func (a *application) rewriteBasicType(parent AST, node BasicType, replacer repl return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -822,8 +829,15 @@ func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -840,8 +854,15 @@ func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContai if err != nil { return err } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -869,7 +890,7 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -898,7 +919,7 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -915,8 +936,15 @@ func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *Interface return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -932,8 +960,15 @@ func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacer return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -949,8 +984,15 @@ func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, rep return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -981,7 +1023,7 @@ func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -1017,7 +1059,7 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -1044,7 +1086,7 @@ func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -1076,7 +1118,7 @@ func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContaine a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -1112,7 +1154,7 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -1157,7 +1199,7 @@ func (a *application) rewriteValueContainer(parent AST, node ValueContainer, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -1194,7 +1236,7 @@ func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceCont a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil diff --git a/go/tools/asthelpergen/integration/integration_rewriter_test.go b/go/tools/asthelpergen/integration/integration_rewriter_test.go index df204dc37f2..a5ad57ef9ab 100644 --- a/go/tools/asthelpergen/integration/integration_rewriter_test.go +++ b/go/tools/asthelpergen/integration/integration_rewriter_test.go @@ -226,6 +226,34 @@ func TestRewriteVisitValueContainerReplace2(t *testing.T) { require.Error(t, err) } +func TestRewriteVisitRefContainerPreOrPostOnly(t *testing.T) { + leaf1 := &Leaf{1} + leaf2 := &Leaf{2} + container := &RefContainer{ASTType: leaf1, ASTImplementationType: leaf2} + containerContainer := &RefContainer{ASTType: container} + + tv := &rewriteTestVisitor{} + + _, err := Rewrite(containerContainer, tv.pre, nil) + require.NoError(t, err) + tv.assertEquals(t, []step{ + Pre{containerContainer}, + Pre{container}, + Pre{leaf1}, + Pre{leaf2}, + }) + + tv = &rewriteTestVisitor{} + _, err = Rewrite(containerContainer, nil, tv.post) + require.NoError(t, err) + tv.assertEquals(t, []step{ + Post{leaf1}, + Post{leaf2}, + Post{container}, + Post{containerContainer}, + }) +} + func rewriteLeaf(from, to int) func(*Cursor) bool { return func(cursor *Cursor) bool { leaf, ok := cursor.node.(*Leaf) @@ -270,7 +298,7 @@ func (r Pre) String() string { return fmt.Sprintf("Pre(%s)", r.el.String()) } func (r Post) String() string { - return fmt.Sprintf("Pre(%s)", r.el.String()) + return fmt.Sprintf("Post(%s)", r.el.String()) } type Post struct { diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 339861a2cc9..42a33ab4b60 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -167,7 +167,7 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS } stmts = append(stmts, executePre()) - addCur := false + haveChildren := false if shouldAdd(slice.Elem(), spi.iface()) { /* for i, el := range node { @@ -178,13 +178,13 @@ func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorS } } */ - addCur = true + haveChildren = true stmts = append(stmts, jen.For(jen.Id("i, el").Op(":=").Id("range node")). Block(e.rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) } - stmts = append(stmts, executePost(addCur)) + stmts = append(stmts, executePost(haveChildren)) stmts = append(stmts, returnNil()) e.rewriteFunc(t, stmts, spi) @@ -204,14 +204,19 @@ func executePre() jen.Code { return jen.If(jen.Id("a.pre!= nil").Block(curStmts...)) } -func executePost(addCur bool) jen.Code { - if addCur { - curStmts := setupCursor() - curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(returnNil())) - return jen.If(jen.Id("a.post!= nil").Block(curStmts...)) +func executePost(seenChildren bool) jen.Code { + var curStmts []jen.Code + if seenChildren { + // if we have visited children, we have to write to the cursor fields + curStmts = setupCursor() + } else { + curStmts = append(curStmts, + jen.If(jen.Id("a.pre == nil")).Block(setupCursor()...)) } - return jen.If(jen.Id("a.post != nil && !a.post(&a.cur)")).Block(jen.Return(jen.Id(abort))) + curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(jen.Return(jen.Id(abort)))) + + return jen.If(jen.Id("a.post != nil")).Block(curStmts...) } func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 06c70036160..8172a93fdb6 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -9305,8 +9305,15 @@ func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replace return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9319,8 +9326,15 @@ func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9381,8 +9395,15 @@ func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer re return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9395,8 +9416,15 @@ func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer repl return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9427,8 +9455,15 @@ func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer re if err != nil { return err } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9472,7 +9507,7 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9489,8 +9524,15 @@ func (a *application) rewriteComments(parent SQLNode, node Comments, replacer re return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9660,7 +9702,7 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9689,7 +9731,7 @@ func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer repl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9721,8 +9763,15 @@ func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9754,7 +9803,7 @@ func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9771,8 +9820,15 @@ func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer repl return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -9800,7 +9856,7 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9829,7 +9885,7 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9858,7 +9914,7 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9897,7 +9953,7 @@ func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9924,7 +9980,7 @@ func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node * a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9951,7 +10007,7 @@ func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIn a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -9983,7 +10039,7 @@ func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10025,7 +10081,7 @@ func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *Aliased a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10042,8 +10098,15 @@ func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharse return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10074,7 +10137,7 @@ func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10091,8 +10154,15 @@ func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatab return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10108,8 +10178,15 @@ func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigr return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10147,7 +10224,7 @@ func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10184,7 +10261,7 @@ func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10228,7 +10305,7 @@ func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschem a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10260,7 +10337,7 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10292,7 +10369,7 @@ func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10309,8 +10386,15 @@ func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer re return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10341,7 +10425,7 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10373,7 +10457,7 @@ func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10412,7 +10496,7 @@ func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10454,7 +10538,7 @@ func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColum a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10481,7 +10565,7 @@ func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10498,8 +10582,15 @@ func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, repla return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10530,7 +10621,7 @@ func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10557,7 +10648,7 @@ func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10584,7 +10675,7 @@ func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnD a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10616,7 +10707,7 @@ func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10633,8 +10724,15 @@ func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10670,7 +10768,7 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10697,7 +10795,7 @@ func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *Con a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10729,7 +10827,7 @@ func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10761,7 +10859,7 @@ func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10788,7 +10886,7 @@ func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *Convert a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10815,7 +10913,7 @@ func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDat a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10852,7 +10950,7 @@ func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10889,7 +10987,7 @@ func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10921,7 +11019,7 @@ func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeF a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -10938,8 +11036,15 @@ func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replace return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -10995,7 +11100,7 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11022,7 +11127,7 @@ func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTabl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11049,7 +11154,7 @@ func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11076,7 +11181,7 @@ func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabas a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11093,8 +11198,15 @@ func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replace return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11120,7 +11232,7 @@ func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11147,7 +11259,7 @@ func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11174,7 +11286,7 @@ func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11201,7 +11313,7 @@ func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11228,7 +11340,7 @@ func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11255,7 +11367,7 @@ func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11272,8 +11384,15 @@ func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer re return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11319,7 +11438,7 @@ func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *For a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11356,7 +11475,7 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11393,7 +11512,7 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11420,7 +11539,7 @@ func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDef a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11449,7 +11568,7 @@ func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11481,7 +11600,7 @@ func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11533,7 +11652,7 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11560,7 +11679,7 @@ func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExp a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11587,7 +11706,7 @@ func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11619,7 +11738,7 @@ func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondit a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11656,7 +11775,7 @@ func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableE a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11673,8 +11792,15 @@ func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, repla return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11705,7 +11831,7 @@ func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11722,8 +11848,15 @@ func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replace return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11739,8 +11872,15 @@ func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer repl return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11756,8 +11896,15 @@ func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, r return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11773,8 +11920,15 @@ func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, r return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11805,7 +11959,7 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11842,7 +11996,7 @@ func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColum a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11869,7 +12023,7 @@ func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11896,7 +12050,7 @@ func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11913,8 +12067,15 @@ func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replace return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -11940,7 +12101,7 @@ func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11972,7 +12133,7 @@ func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -11999,7 +12160,7 @@ func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12026,7 +12187,7 @@ func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOpt a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12043,8 +12204,15 @@ func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, r return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12060,8 +12228,15 @@ func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, rep return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12087,7 +12262,7 @@ func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12114,7 +12289,7 @@ func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTabl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12146,7 +12321,7 @@ func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *Part a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12190,7 +12365,7 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12227,7 +12402,7 @@ func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12254,7 +12429,7 @@ func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12271,8 +12446,15 @@ func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12288,8 +12470,15 @@ func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12315,7 +12504,7 @@ func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTa a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12332,8 +12521,15 @@ func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMi return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12349,8 +12545,15 @@ func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, repla return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12376,7 +12579,7 @@ func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12403,7 +12606,7 @@ func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12470,7 +12673,7 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12487,8 +12690,15 @@ func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, r return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12519,7 +12729,7 @@ func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replac a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12551,7 +12761,7 @@ func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12590,7 +12800,7 @@ func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransa a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12617,7 +12827,7 @@ func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer repl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12649,7 +12859,7 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12676,7 +12886,7 @@ func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12703,7 +12913,7 @@ func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12740,7 +12950,7 @@ func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12767,7 +12977,7 @@ func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12804,7 +13014,7 @@ func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12831,7 +13041,7 @@ func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12873,7 +13083,7 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12890,8 +13100,15 @@ func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, r return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -12922,7 +13139,7 @@ func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12970,7 +13187,7 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -12987,8 +13204,15 @@ func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *Tabl return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -13019,7 +13243,7 @@ func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *Timest a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13046,7 +13270,7 @@ func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTa a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13073,7 +13297,7 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13117,7 +13341,7 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13144,7 +13368,7 @@ func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13161,8 +13385,15 @@ func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTable return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -13213,7 +13444,7 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13245,7 +13476,7 @@ func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13272,7 +13503,7 @@ func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replac a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13319,7 +13550,7 @@ func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13336,8 +13567,15 @@ func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, r return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -13363,7 +13601,7 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13390,7 +13628,7 @@ func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13429,7 +13667,7 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13461,7 +13699,7 @@ func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer repl a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13488,7 +13726,7 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13520,7 +13758,7 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13534,8 +13772,15 @@ func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceActio return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -13881,7 +14126,7 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -13926,7 +14171,7 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14093,7 +14338,7 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14111,8 +14356,15 @@ func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replace if err != nil { return err } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -14144,7 +14396,7 @@ func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14173,7 +14425,7 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14190,8 +14442,15 @@ func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, rep return nil } } - if a.post != nil && !a.post(&a.cur) { - return errAbort + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } return nil } @@ -14219,7 +14478,7 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14248,7 +14507,7 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14277,7 +14536,7 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil @@ -14305,7 +14564,7 @@ func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, repla a.cur.parent = parent a.cur.node = node if !a.post(&a.cur) { - return nil + return errAbort } } return nil