Skip to content

Commit

Permalink
all: make syntax patterns aware of stdlib packages (#260)
Browse files Browse the repository at this point in the history
When you write `fmt.Sprint`, you probably want to only match
`Sprint` function from `fmt` package, not a `Sprint` method
call on `fmt` variable or whatever.

If we need to match a variable, a filter expression in `Where` can
explicitly state that.

For now, we only recognize stdlib packages.
In the future, I would consider matching `m.Import()`-ed packages
in a smart way as well. That could simplify the rules writing
and make them more correct by default.
  • Loading branch information
quasilyte authored Oct 10, 2021
1 parent b789ea4 commit 519222b
Show file tree
Hide file tree
Showing 17 changed files with 490 additions and 335 deletions.
1 change: 1 addition & 0 deletions analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var tests = []struct {
{name: "matching"},
{name: "dgryski"},
{name: "comments"},
{name: "stdlib"},
{name: "goversion", flags: map[string]string{"go": "1.16"}},
}

Expand Down
20 changes: 20 additions & 0 deletions analyzer/testdata/src/stdlib/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package stdlib

import "io"

type foo struct{}

func (foo) WriteString(args ...interface{}) {}

func (foo) Sprint(args ...interface{}) string { return "" }

func sink(args ...interface{}) {}

func test(w io.Writer) {
io.WriteString(w, "") // want `\QWriteString from stdlib`

{
var io foo
io.WriteString(w, "")
}
}
17 changes: 17 additions & 0 deletions analyzer/testdata/src/stdlib/file2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package stdlib

import (
"fmt"
io "myio"
)

func _(w io.Writer) {
io.WriteString(w, "")

sink(fmt.Sprint(1), fmt.Sprint("ok")) // want `\Qsink with two Sprint from stdlib`

{
var fmt foo
sink(fmt.Sprint(1), fmt.Sprint("ok"))
}
}
18 changes: 18 additions & 0 deletions analyzer/testdata/src/stdlib/file3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package stdlib

import (
iorenamed "io"
)

func _(w iorenamed.Writer) {
iorenamed.WriteString(w, "") // want `\QWriteString from stdlib`

{
var io foo
io.WriteString(w, "")
}
{
var iorenamed foo
iorenamed.WriteString(w, "")
}
}
13 changes: 13 additions & 0 deletions analyzer/testdata/src/stdlib/rules.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// +build ignore

package gorules

import (
"github.com/quasilyte/go-ruleguard/dsl"
)

func testRules(m dsl.Matcher) {
m.Match(`io.WriteString($*_)`).Report(`WriteString from stdlib`)

m.Match(`sink(fmt.Sprint($_), fmt.Sprint($_))`).Report(`sink with two Sprint from stdlib`)
}
9 changes: 9 additions & 0 deletions analyzer/testdata/src/stdlib/vendor/myio/myio.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package myio

import "io"

type Writer interface {
io.Writer
}

func WriteString(w Writer, s string) {}
29 changes: 28 additions & 1 deletion internal/gogrep/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"go/ast"
"go/token"

"github.com/quasilyte/go-ruleguard/internal/stdinfo"
)

type compileError string
Expand Down Expand Up @@ -415,13 +417,38 @@ func (c *compiler) compileCallExpr(n *ast.CallExpr) {
}

c.emitInstOp(op)
c.compileExpr(n.Fun)
c.compileSymbol(n.Fun)
for _, arg := range n.Args {
c.compileExpr(arg)
}
c.emitInstOp(opEnd)
}

// compileSymbol is mostly like a normal compileExpr, but it's used
// in places where we can find a type/function symbol.
//
// For example, in function call expressions a called function expression
// can look like `fmt.Sprint`. It will be compiled as a special
// selector expression that requires `fmt` to be a package as opposed
// to only check that it's an identifier with "fmt" value.
func (c *compiler) compileSymbol(fn ast.Expr) {
if e, ok := fn.(*ast.SelectorExpr); ok {
if ident, ok := e.X.(*ast.Ident); ok && stdinfo.Packages[ident.Name] != "" {
c.emitInst(instruction{
op: opSimpleSelectorExpr,
valueIndex: c.internString(e.Sel, e.Sel.String()),
})
c.emitInst(instruction{
op: opStdlibPkg,
valueIndex: c.internString(ident, ident.Name),
})
return
}
}

c.compileExpr(fn)
}

func (c *compiler) compileUnaryExpr(n *ast.UnaryExpr) {
c.prog.insts = append(c.prog.insts, instruction{
op: opUnaryExpr,
Expand Down
17 changes: 17 additions & 0 deletions internal/gogrep/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,23 @@ func TestCompile(t *testing.T) {
` • • End`,
` • End`,
},

`fmt.Println()`: {
`NonVariadicCallExpr`,
` • SimpleSelectorExpr Println`,
` • • StdlibPkg fmt`,
` • End`,
},

`x = fmt.Sprint(y)`: {
`AssignStmt =`,
` • Ident x`,
` • NonVariadicCallExpr`,
` • • SimpleSelectorExpr Sprint`,
` • • • StdlibPkg fmt`,
` • • Ident y`,
` • • End`,
},
})

for i := range tests {
Expand Down
5 changes: 3 additions & 2 deletions internal/gogrep/gogrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gogrep
import (
"go/ast"
"go/token"
"go/types"

"github.com/quasilyte/go-ruleguard/nodetag"
)
Expand Down Expand Up @@ -41,8 +42,8 @@ func (p *Pattern) NodeTag() nodetag.Value {
}

// MatchNode calls cb if n matches a pattern.
func (p *Pattern) MatchNode(n ast.Node, cb func(MatchData)) {
p.m.MatchNode(n, cb)
func (p *Pattern) MatchNode(info *types.Info, n ast.Node, cb func(MatchData)) {
p.m.MatchNode(info, n, cb)
}

// Clone creates a pattern copy.
Expand Down
20 changes: 19 additions & 1 deletion internal/gogrep/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"fmt"
"go/ast"
"go/token"
"go/types"
"strconv"

"github.com/go-toolsmith/astequal"
"github.com/quasilyte/go-ruleguard/internal/stdinfo"
)

type matcher struct {
Expand All @@ -18,6 +20,8 @@ type matcher struct {
// node values recorded by name, excluding "_" (used only by the
// actual matching phase)
capture []CapturedNode

types *types.Info
}

func newMatcher(prog *program) *matcher {
Expand All @@ -42,8 +46,9 @@ func (m *matcher) ifaceValue(inst instruction) interface{} {
return m.prog.ifaces[inst.valueIndex]
}

func (m *matcher) MatchNode(n ast.Node, accept func(MatchData)) {
func (m *matcher) MatchNode(info *types.Info, n ast.Node, accept func(MatchData)) {
m.pc = 0
m.types = info
inst := m.nextInst()
switch inst.op {
case opMultiStmt:
Expand Down Expand Up @@ -150,6 +155,19 @@ func (m *matcher) matchNodeWithInst(inst instruction, n ast.Node) bool {
n, ok := n.(*ast.Ident)
return ok && m.stringValue(inst) == n.Name

case opStdlibPkg:
n, ok := n.(*ast.Ident)
if !ok {
return false
}
obj := m.types.ObjectOf(n)
if obj == nil {
return false
}
pkgName, ok := obj.(*types.PkgName)
return ok && m.stringValue(inst) == pkgName.Imported().Name() &&
pkgName.Imported().Path() == stdinfo.Packages[pkgName.Imported().Name()]

case opBinaryExpr:
n, ok := n.(*ast.BinaryExpr)
return ok && n.Op == token.Token(inst.value) &&
Expand Down
2 changes: 1 addition & 1 deletion internal/gogrep/match_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ func testAllMatches(p *Pattern, target ast.Node, cb func(MatchData)) {
if n == nil {
return false
}
p.MatchNode(n, cb)
p.MatchNode(nil, n, cb)
return true
}
ast.Inspect(target, visit)
Expand Down
Loading

0 comments on commit 519222b

Please sign in to comment.