diff --git a/compile.go b/compile.go index a00a39c..31b60df 100644 --- a/compile.go +++ b/compile.go @@ -122,16 +122,19 @@ func (c *compiler) compileNode(n ast.Node) { c.compileStmt(n) case *ast.ValueSpec: c.compileValueSpec(n) - case stmtSlice: - c.compileStmtSlice(n) - case declSlice: - c.compileDeclSlice(n) - case ExprSlice: - c.compileExprSlice(n) case *rangeClause: c.compileRangeClause(n) case *rangeHeader: c.compileRangeHeader(n) + case *NodeSlice: + switch n.Kind { + case StmtNodeSlice: + c.compileStmtSlice(n.stmtSlice) + case DeclNodeSlice: + c.compileDeclSlice(n.declSlice) + case ExprNodeSlice: + c.compileExprSlice(n.exprSlice) + } default: panic(c.errorf(n, "compileNode: unexpected %T", n)) } @@ -1191,7 +1194,7 @@ func (c *compiler) compileSendStmt(n *ast.SendStmt) { c.compileExpr(n.Value) } -func (c *compiler) compileDeclSlice(decls declSlice) { +func (c *compiler) compileDeclSlice(decls []ast.Decl) { c.emitInstOp(opMultiDecl) for _, n := range decls { c.compileDecl(n) @@ -1199,7 +1202,7 @@ func (c *compiler) compileDeclSlice(decls declSlice) { c.emitInstOp(opEnd) } -func (c *compiler) compileStmtSlice(stmts stmtSlice) { +func (c *compiler) compileStmtSlice(stmts []ast.Stmt) { c.emitInstOp(opMultiStmt) insideStmtList := c.insideStmtList c.insideStmtList = true @@ -1210,7 +1213,7 @@ func (c *compiler) compileStmtSlice(stmts stmtSlice) { c.emitInstOp(opEnd) } -func (c *compiler) compileExprSlice(exprs ExprSlice) { +func (c *compiler) compileExprSlice(exprs []ast.Expr) { c.emitInstOp(opMultiExpr) for _, n := range exprs { c.compileExpr(n) diff --git a/gogrep.go b/gogrep.go index 313a9a2..47a03f9 100644 --- a/gogrep.go +++ b/gogrep.go @@ -11,7 +11,7 @@ import ( ) func IsEmptyNodeSlice(n ast.Node) bool { - if list, ok := n.(NodeSlice); ok { + if list, ok := n.(*NodeSlice); ok { return list.Len() == 0 } return false @@ -62,6 +62,9 @@ type MatcherState struct { // actual matching phase) capture []CapturedNode + nodeSlices []NodeSlice + nodeSlicesUsed int + pc int partial PartialNode @@ -69,7 +72,8 @@ type MatcherState struct { func NewMatcherState() MatcherState { return MatcherState{ - capture: make([]CapturedNode, 0, 8), + capture: make([]CapturedNode, 0, 8), + nodeSlices: make([]NodeSlice, 16), } } @@ -143,34 +147,37 @@ func Compile(config CompileConfig) (*Pattern, PatternInfo, error) { } func Walk(root ast.Node, fn func(n ast.Node) bool) { - switch root := root.(type) { - case ExprSlice: - for _, e := range root { - ast.Inspect(e, fn) - } - case stmtSlice: - for _, e := range root { - ast.Inspect(e, fn) - } - case fieldSlice: - for _, e := range root { - ast.Inspect(e, fn) - } - case identSlice: - for _, e := range root { - ast.Inspect(e, fn) + if root, ok := root.(*NodeSlice); ok { + switch root.Kind { + case ExprNodeSlice: + for _, e := range root.exprSlice { + ast.Inspect(e, fn) + } + case StmtNodeSlice: + for _, e := range root.stmtSlice { + ast.Inspect(e, fn) + } + case FieldNodeSlice: + for _, e := range root.fieldSlice { + ast.Inspect(e, fn) + } + case IdentNodeSlice: + for _, e := range root.identSlice { + ast.Inspect(e, fn) + } + case SpecNodeSlice: + for _, e := range root.specSlice { + ast.Inspect(e, fn) + } + default: + for _, e := range root.declSlice { + ast.Inspect(e, fn) + } } - case specSlice: - for _, e := range root { - ast.Inspect(e, fn) - } - case declSlice: - for _, e := range root { - ast.Inspect(e, fn) - } - default: - ast.Inspect(root, fn) + return } + + ast.Inspect(root, fn) } func newPatternInfo() PatternInfo { diff --git a/match.go b/match.go index d4e3243..d4b317b 100644 --- a/match.go +++ b/match.go @@ -45,8 +45,36 @@ func (m *matcher) resetCapture(state *MatcherState) { } } +func (m *matcher) toStmtSlice(state *MatcherState, nodes ...ast.Node) *NodeSlice { + slice := m.allocNodeSlice(state) + var stmts []ast.Stmt + for _, node := range nodes { + switch x := node.(type) { + case nil: + case ast.Stmt: + stmts = append(stmts, x) + case ast.Expr: + stmts = append(stmts, &ast.ExprStmt{X: x}) + default: + panic(fmt.Sprintf("unexpected node type: %T", x)) + } + } + slice.assignStmtSlice(stmts) + return slice +} + +func (m *matcher) allocNodeSlice(state *MatcherState) *NodeSlice { + if state.nodeSlicesUsed < len(state.nodeSlices) { + i := state.nodeSlicesUsed + state.nodeSlicesUsed++ + return &state.nodeSlices[i] + } + return &NodeSlice{} +} + func (m *matcher) MatchNode(state *MatcherState, n ast.Node, accept func(MatchData)) { state.pc = 0 + state.nodeSlicesUsed = 0 inst := m.nextInst(state) switch inst.op { case opMultiStmt: @@ -91,24 +119,32 @@ func (m *matcher) MatchNode(state *MatcherState, n ast.Node, accept func(MatchDa } func (m *matcher) walkDeclSlice(state *MatcherState, decls []ast.Decl, accept func(MatchData)) { - m.walkNodeSlice(state, declSlice(decls), accept) + slice := m.allocNodeSlice(state) + slice.assignDeclSlice(decls) + m.walkNodeSlice(state, slice, accept) } func (m *matcher) walkExprSlice(state *MatcherState, exprs []ast.Expr, accept func(MatchData)) { - m.walkNodeSlice(state, ExprSlice(exprs), accept) + slice := m.allocNodeSlice(state) + slice.assignExprSlice(exprs) + m.walkNodeSlice(state, slice, accept) } func (m *matcher) walkStmtSlice(state *MatcherState, stmts []ast.Stmt, accept func(MatchData)) { - m.walkNodeSlice(state, stmtSlice(stmts), accept) + slice := m.allocNodeSlice(state) + slice.assignStmtSlice(stmts) + m.walkNodeSlice(state, slice, accept) } -func (m *matcher) walkNodeSlice(state *MatcherState, nodes NodeSlice, accept func(MatchData)) { +func (m *matcher) walkNodeSlice(state *MatcherState, nodes *NodeSlice, accept func(MatchData)) { sliceLen := nodes.Len() from := 0 + tmpSlice := m.allocNodeSlice(state) for { state.pc = 1 // FIXME: this is a kludge m.resetCapture(state) - matched, offset := m.matchNodeList(state, nodes.slice(from, sliceLen), true) + nodes.SliceInto(tmpSlice, from, sliceLen) + matched, offset := m.matchNodeList(state, tmpSlice, true) if matched == nil { break } @@ -422,11 +458,11 @@ func (m *matcher) matchNodeWithInst(state *MatcherState, inst instruction, n ast case opIfNamedOptStmt: n, ok := n.(*ast.IfStmt) return ok && n.Else == nil && m.matchNode(state, n.Body) && - m.matchNamed(state, m.stringValue(inst), toStmtSlice(n.Cond, n.Init)) + m.matchNamed(state, m.stringValue(inst), m.toStmtSlice(state, n.Cond, n.Init)) case opIfNamedOptElseStmt: n, ok := n.(*ast.IfStmt) return ok && n.Else != nil && m.matchNode(state, n.Body) && m.matchNode(state, n.Else) && - m.matchNamed(state, m.stringValue(inst), toStmtSlice(n.Cond, n.Init)) + m.matchNamed(state, m.stringValue(inst), m.toStmtSlice(state, n.Cond, n.Init)) case opCaseClause: n, ok := n.(*ast.CaseClause) @@ -641,33 +677,43 @@ func (m *matcher) matchArgList(state *MatcherState, exprs []ast.Expr) bool { } func (m *matcher) matchStmtSlice(state *MatcherState, stmts []ast.Stmt) bool { - matched, _ := m.matchNodeList(state, stmtSlice(stmts), false) + slice := m.allocNodeSlice(state) + slice.assignStmtSlice(stmts) + matched, _ := m.matchNodeList(state, slice, false) return matched != nil } func (m *matcher) matchExprSlice(state *MatcherState, exprs []ast.Expr) bool { - matched, _ := m.matchNodeList(state, ExprSlice(exprs), false) + slice := m.allocNodeSlice(state) + slice.assignExprSlice(exprs) + matched, _ := m.matchNodeList(state, slice, false) return matched != nil } func (m *matcher) matchFieldSlice(state *MatcherState, fields []*ast.Field) bool { - matched, _ := m.matchNodeList(state, fieldSlice(fields), false) + slice := m.allocNodeSlice(state) + slice.assignFieldSlice(fields) + matched, _ := m.matchNodeList(state, slice, false) return matched != nil } func (m *matcher) matchIdentSlice(state *MatcherState, idents []*ast.Ident) bool { - matched, _ := m.matchNodeList(state, identSlice(idents), false) + slice := m.allocNodeSlice(state) + slice.assignIdentSlice(idents) + matched, _ := m.matchNodeList(state, slice, false) return matched != nil } func (m *matcher) matchSpecSlice(state *MatcherState, specs []ast.Spec) bool { - matched, _ := m.matchNodeList(state, specSlice(specs), false) + slice := m.allocNodeSlice(state) + slice.assignSpecSlice(specs) + matched, _ := m.matchNodeList(state, slice, false) return matched != nil } // matchNodeList matches two lists of nodes. It uses a common algorithm to match // wildcard patterns with any number of nodes without recursion. -func (m *matcher) matchNodeList(state *MatcherState, nodes NodeSlice, partial bool) (matched ast.Node, offset int) { +func (m *matcher) matchNodeList(state *MatcherState, nodes *NodeSlice, partial bool) (matched ast.Node, offset int) { sliceLen := nodes.Len() inst := m.nextInst(state) if inst.op == opEnd { @@ -727,7 +773,9 @@ func (m *matcher) matchNodeList(state *MatcherState, nodes NodeSlice, partial bo case "", "_": return true } - return m.matchNamed(state, wildName, nodes.slice(wildStart, j)) + slice := m.allocNodeSlice(state) + nodes.SliceInto(slice, wildStart, j) + return m.matchNamed(state, wildName, slice) } for ; inst.op != opEnd || j < sliceLen; inst = m.nextInst(state) { if inst.op != opEnd { @@ -776,7 +824,9 @@ func (m *matcher) matchNodeList(state *MatcherState, nodes NodeSlice, partial bo if !wouldMatch() { return nil, -1 } - return nodes.slice(partialStart, partialEnd), partialEnd + 1 + slice := m.allocNodeSlice(state) + nodes.SliceInto(slice, partialStart, partialEnd) + return slice, partialEnd + 1 } func (m *matcher) matchRangeClause(state *MatcherState, n ast.Node, accept func(MatchData)) { @@ -919,58 +969,56 @@ func equalNodes(x, y ast.Node) bool { if x == nil || y == nil { return x == y } - switch x := x.(type) { - case stmtSlice: - y, ok := y.(stmtSlice) - if !ok || len(x) != len(y) { + if x, ok := x.(*NodeSlice); ok { + y, ok := y.(*NodeSlice) + if !ok || x.Kind != y.Kind || x.Len() != y.Len() { return false } - for i := range x { - if !astequal.Stmt(x[i], y[i]) { - return false + switch x.Kind { + case ExprNodeSlice: + for i, n1 := range x.exprSlice { + n2 := y.exprSlice[i] + if !astequal.Expr(n1, n2) { + return false + } } - } - return true - case ExprSlice: - y, ok := y.(ExprSlice) - if !ok || len(x) != len(y) { - return false - } - for i := range x { - if !astequal.Expr(x[i], y[i]) { - return false + case StmtNodeSlice: + for i, n1 := range x.stmtSlice { + n2 := y.stmtSlice[i] + if !astequal.Stmt(n1, n2) { + return false + } } - } - return true - case declSlice: - y, ok := y.(declSlice) - if !ok || len(x) != len(y) { - return false - } - for i := range x { - if !astequal.Decl(x[i], y[i]) { - return false + case FieldNodeSlice: + for i, n1 := range x.fieldSlice { + n2 := y.fieldSlice[i] + if !astequal.Node(n1, n2) { + return false + } + } + case IdentNodeSlice: + for i, n1 := range x.identSlice { + n2 := y.identSlice[i] + if n1.Name != n2.Name { + return false + } + } + case SpecNodeSlice: + for i, n1 := range x.specSlice { + n2 := y.specSlice[i] + if !astequal.Node(n1, n2) { + return false + } + } + case DeclNodeSlice: + for i, n1 := range x.declSlice { + n2 := y.declSlice[i] + if !astequal.Decl(n1, n2) { + return false + } } } return true - - default: - return astequal.Node(x, y) - } -} - -func toStmtSlice(nodes ...ast.Node) stmtSlice { - var stmts []ast.Stmt - for _, node := range nodes { - switch x := node.(type) { - case nil: - case ast.Stmt: - stmts = append(stmts, x) - case ast.Expr: - stmts = append(stmts, &ast.ExprStmt{X: x}) - default: - panic(fmt.Sprintf("unexpected node type: %T", x)) - } } - return stmtSlice(stmts) + return astequal.Node(x, y) } diff --git a/match_perf_test.go b/match_perf_test.go index 2680365..34e5891 100644 --- a/match_perf_test.go +++ b/match_perf_test.go @@ -157,6 +157,11 @@ func BenchmarkMatch(b *testing.B) { pat: `f($x, $*ys)`, input: `f(1, 2, 3, 4, 5, 6)`, }, + { + name: `exprList`, + pat: `g(f($*_, $x, $x), $*_, 0)`, + input: `g(f(1, 2, 3, 4, 5, 6, 6), -1, -1, 0)`, + }, } for i := range tests { diff --git a/parse.go b/parse.go index aa5ffbf..3c6854b 100644 --- a/parse.go +++ b/parse.go @@ -174,7 +174,9 @@ func parseDetectingNode(fset *token.FileSet, src string) (ast.Node, error) { if len(cl.Elts) == 1 { return cl.Elts[0], nil } - return ExprSlice(cl.Elts), nil + slice := &NodeSlice{} + slice.assignExprSlice(cl.Elts) + return slice, nil } // then try as statements @@ -185,7 +187,9 @@ func parseDetectingNode(fset *token.FileSet, src string) (ast.Node, error) { if len(bl.List) == 1 { return bl.List[0], nil } - return stmtSlice(bl.List), nil + slice := &NodeSlice{} + slice.assignStmtSlice(bl.List) + return slice, nil } // Statements is what covers most cases, so it will give // the best overall error message. Show positions @@ -199,7 +203,9 @@ func parseDetectingNode(fset *token.FileSet, src string) (ast.Node, error) { if len(f.Decls) == 1 { return f.Decls[0], nil } - return declSlice(f.Decls), nil + slice := &NodeSlice{} + slice.assignDeclSlice(f.Decls) + return slice, nil } // try as a whole file diff --git a/slices.go b/slices.go index 13775a8..fb969b5 100644 --- a/slices.go +++ b/slices.go @@ -5,54 +5,146 @@ import ( "go/token" ) -type NodeSlice interface { - At(i int) ast.Node - Len() int - slice(from, to int) NodeSlice - ast.Node -} +type NodeSliceKind uint32 + +const ( + ExprNodeSlice NodeSliceKind = iota + StmtNodeSlice + FieldNodeSlice + IdentNodeSlice + SpecNodeSlice + DeclNodeSlice +) + +type NodeSlice struct { + Kind NodeSliceKind -type ( - ExprSlice []ast.Expr + exprSlice []ast.Expr stmtSlice []ast.Stmt fieldSlice []*ast.Field identSlice []*ast.Ident specSlice []ast.Spec declSlice []ast.Decl -) +} + +func (s *NodeSlice) GetExprSlice() []ast.Expr { return s.exprSlice } +func (s *NodeSlice) GetStmtSlice() []ast.Stmt { return s.stmtSlice } +func (s *NodeSlice) GetFieldSlice() []*ast.Field { return s.fieldSlice } +func (s *NodeSlice) GetIdentSlice() []*ast.Ident { return s.identSlice } +func (s *NodeSlice) GetSpecSlice() []ast.Spec { return s.specSlice } +func (s *NodeSlice) GetDeclSlice() []ast.Decl { return s.declSlice } + +func (s *NodeSlice) assignExprSlice(xs []ast.Expr) { + s.Kind = ExprNodeSlice + s.exprSlice = xs +} + +func (s *NodeSlice) assignStmtSlice(xs []ast.Stmt) { + s.Kind = StmtNodeSlice + s.stmtSlice = xs +} + +func (s *NodeSlice) assignFieldSlice(xs []*ast.Field) { + s.Kind = FieldNodeSlice + s.fieldSlice = xs +} + +func (s *NodeSlice) assignIdentSlice(xs []*ast.Ident) { + s.Kind = IdentNodeSlice + s.identSlice = xs +} + +func (s *NodeSlice) assignSpecSlice(xs []ast.Spec) { + s.Kind = SpecNodeSlice + s.specSlice = xs +} + +func (s *NodeSlice) assignDeclSlice(xs []ast.Decl) { + s.Kind = DeclNodeSlice + s.declSlice = xs +} + +func (s *NodeSlice) Len() int { + switch s.Kind { + case ExprNodeSlice: + return len(s.exprSlice) + case StmtNodeSlice: + return len(s.stmtSlice) + case FieldNodeSlice: + return len(s.fieldSlice) + case IdentNodeSlice: + return len(s.identSlice) + case SpecNodeSlice: + return len(s.specSlice) + default: + return len(s.declSlice) + } +} -func (l ExprSlice) Len() int { return len(l) } -func (l ExprSlice) At(i int) ast.Node { return l[i] } -func (l ExprSlice) slice(i, j int) NodeSlice { return l[i:j] } -func (l ExprSlice) Pos() token.Pos { return l[0].Pos() } -func (l ExprSlice) End() token.Pos { return l[len(l)-1].End() } - -func (l stmtSlice) Len() int { return len(l) } -func (l stmtSlice) At(i int) ast.Node { return l[i] } -func (l stmtSlice) slice(i, j int) NodeSlice { return l[i:j] } -func (l stmtSlice) Pos() token.Pos { return l[0].Pos() } -func (l stmtSlice) End() token.Pos { return l[len(l)-1].End() } - -func (l fieldSlice) Len() int { return len(l) } -func (l fieldSlice) At(i int) ast.Node { return l[i] } -func (l fieldSlice) slice(i, j int) NodeSlice { return l[i:j] } -func (l fieldSlice) Pos() token.Pos { return l[0].Pos() } -func (l fieldSlice) End() token.Pos { return l[len(l)-1].End() } - -func (l identSlice) Len() int { return len(l) } -func (l identSlice) At(i int) ast.Node { return l[i] } -func (l identSlice) slice(i, j int) NodeSlice { return l[i:j] } -func (l identSlice) Pos() token.Pos { return l[0].Pos() } -func (l identSlice) End() token.Pos { return l[len(l)-1].End() } - -func (l specSlice) Len() int { return len(l) } -func (l specSlice) At(i int) ast.Node { return l[i] } -func (l specSlice) slice(i, j int) NodeSlice { return l[i:j] } -func (l specSlice) Pos() token.Pos { return l[0].Pos() } -func (l specSlice) End() token.Pos { return l[len(l)-1].End() } - -func (l declSlice) Len() int { return len(l) } -func (l declSlice) At(i int) ast.Node { return l[i] } -func (l declSlice) slice(i, j int) NodeSlice { return l[i:j] } -func (l declSlice) Pos() token.Pos { return l[0].Pos() } -func (l declSlice) End() token.Pos { return l[len(l)-1].End() } +func (s *NodeSlice) At(i int) ast.Node { + switch s.Kind { + case ExprNodeSlice: + return s.exprSlice[i] + case StmtNodeSlice: + return s.stmtSlice[i] + case FieldNodeSlice: + return s.fieldSlice[i] + case IdentNodeSlice: + return s.identSlice[i] + case SpecNodeSlice: + return s.specSlice[i] + default: + return s.declSlice[i] + } +} + +func (s *NodeSlice) SliceInto(dst *NodeSlice, i, j int) { + switch s.Kind { + case ExprNodeSlice: + dst.assignExprSlice(s.exprSlice[i:j]) + case StmtNodeSlice: + dst.assignStmtSlice(s.stmtSlice[i:j]) + case FieldNodeSlice: + dst.assignFieldSlice(s.fieldSlice[i:j]) + case IdentNodeSlice: + dst.assignIdentSlice(s.identSlice[i:j]) + case SpecNodeSlice: + dst.assignSpecSlice(s.specSlice[i:j]) + default: + dst.assignDeclSlice(s.declSlice[i:j]) + } +} + +func (s *NodeSlice) Pos() token.Pos { + switch s.Kind { + case ExprNodeSlice: + return s.exprSlice[0].Pos() + case StmtNodeSlice: + return s.stmtSlice[0].Pos() + case FieldNodeSlice: + return s.fieldSlice[0].Pos() + case IdentNodeSlice: + return s.identSlice[0].Pos() + case SpecNodeSlice: + return s.specSlice[0].Pos() + default: + return s.declSlice[0].Pos() + } +} + +func (s *NodeSlice) End() token.Pos { + switch s.Kind { + case ExprNodeSlice: + return s.exprSlice[len(s.exprSlice)-1].End() + case StmtNodeSlice: + return s.stmtSlice[len(s.stmtSlice)-1].End() + case FieldNodeSlice: + return s.fieldSlice[len(s.fieldSlice)-1].End() + case IdentNodeSlice: + return s.identSlice[len(s.identSlice)-1].End() + case SpecNodeSlice: + return s.specSlice[len(s.specSlice)-1].End() + default: + return s.declSlice[len(s.declSlice)-1].End() + } +} diff --git a/slices_perf_test.go b/slices_perf_test.go new file mode 100644 index 0000000..1716d0c --- /dev/null +++ b/slices_perf_test.go @@ -0,0 +1,51 @@ +package gogrep + +import ( + "go/ast" + "testing" +) + +func BenchmarkExprSlice(b *testing.B) { + slice := &NodeSlice{ + Kind: ExprNodeSlice, + exprSlice: []ast.Expr{ + &ast.Ident{Name: "a"}, + &ast.Ident{Name: "b"}, + &ast.Ident{Name: "c"}, + &ast.Ident{Name: "d"}, + }, + } + + b.Run("get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + l := slice.Len() + for j := 0; j < l; j++ { + n := slice.At(j) + if n == nil { + b.Fail() + } + } + } + }) + + b.Run("slice", func(b *testing.B) { + var dst NodeSlice + for i := 0; i < b.N; i++ { + slice.SliceInto(&dst, 0, 2) + } + if dst.Len() == 0 { + b.Fail() + } + }) + + b.Run("pos", func(b *testing.B) { + total := 0 + for i := 0; i < b.N; i++ { + total += int(slice.Pos()) + total += int(slice.End()) + } + if total == 0 { + b.Fail() + } + }) +}