diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 437dc1f2c47f..f48d3a9d4d17 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -427,6 +427,7 @@ ALL_TESTS = [ "//pkg/testutils/lint/passes/errwrap:errwrap_test", "//pkg/testutils/lint/passes/fmtsafe:fmtsafe_test", "//pkg/testutils/lint/passes/forbiddenmethod:forbiddenmethod_test", + "//pkg/testutils/lint/passes/goroutinerefcapture:goroutinerefcapture_test", "//pkg/testutils/lint/passes/hash:hash_test", "//pkg/testutils/lint/passes/leaktestcall:leaktestcall_test", "//pkg/testutils/lint/passes/nilness:nilness_test", diff --git a/pkg/cmd/roachvet/BUILD.bazel b/pkg/cmd/roachvet/BUILD.bazel index f9b2ca205862..8c46f868d598 100644 --- a/pkg/cmd/roachvet/BUILD.bazel +++ b/pkg/cmd/roachvet/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//pkg/testutils/lint/passes/errwrap", "//pkg/testutils/lint/passes/fmtsafe", "//pkg/testutils/lint/passes/forbiddenmethod", + "//pkg/testutils/lint/passes/goroutinerefcapture", "//pkg/testutils/lint/passes/hash", "//pkg/testutils/lint/passes/leaktestcall", "//pkg/testutils/lint/passes/nilness", diff --git a/pkg/cmd/roachvet/main.go b/pkg/cmd/roachvet/main.go index 40fb3e5010d9..772409af7b2a 100644 --- a/pkg/cmd/roachvet/main.go +++ b/pkg/cmd/roachvet/main.go @@ -18,6 +18,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/errwrap" "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/fmtsafe" "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/forbiddenmethod" + "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/goroutinerefcapture" "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/hash" "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/leaktestcall" "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/nilness" @@ -67,6 +68,7 @@ func main() { errcmp.Analyzer, nilness.Analyzer, errwrap.Analyzer, + goroutinerefcapture.Analyzer, ) // Standard go vet analyzers: diff --git a/pkg/testutils/lint/passes/goroutinerefcapture/BUILD.bazel b/pkg/testutils/lint/passes/goroutinerefcapture/BUILD.bazel new file mode 100644 index 000000000000..fb9df4aa2097 --- /dev/null +++ b/pkg/testutils/lint/passes/goroutinerefcapture/BUILD.bazel @@ -0,0 +1,32 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "goroutinerefcapture", + srcs = [ + "goroutinerefcapture.go", + "loop.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/goroutinerefcapture", + visibility = ["//visibility:public"], + deps = [ + "//pkg/testutils/lint/passes/passesutil", + "@org_golang_x_tools//go/analysis", + "@org_golang_x_tools//go/analysis/passes/inspect", + "@org_golang_x_tools//go/ast/inspector", + ], +) + +go_test( + name = "goroutinerefcapture_test", + srcs = ["goroutinerefcapture_test.go"], + data = glob(["testdata/**"]) + [ + "@go_sdk//:files", + ], + deps = [ + ":goroutinerefcapture", + "//pkg/build/bazel", + "//pkg/testutils", + "//pkg/testutils/skip", + "@org_golang_x_tools//go/analysis/analysistest", + ], +) diff --git a/pkg/testutils/lint/passes/goroutinerefcapture/goroutinerefcapture.go b/pkg/testutils/lint/passes/goroutinerefcapture/goroutinerefcapture.go new file mode 100644 index 000000000000..2b8697d0610f --- /dev/null +++ b/pkg/testutils/lint/passes/goroutinerefcapture/goroutinerefcapture.go @@ -0,0 +1,269 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package goroutinerefcapture + +import ( + "fmt" + "go/ast" + "strings" + + "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/passesutil" + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + astinspector "golang.org/x/tools/go/ast/inspector" +) + +const ( + name = "goroutinerefcapture" + + doc = `check for loop variables captured by reference in Goroutines, +which often leads to data races.` +) + +// Analyzer implements this linter, looking for loop variables +// captured by reference in closures called in Go routines +var Analyzer = &analysis.Analyzer{ + Name: name, + Doc: doc, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +// run is the linter entrypoint +func run(pass *analysis.Pass) (interface{}, error) { + inspector := pass.ResultOf[inspect.Analyzer].(*astinspector.Inspector) + loops := []ast.Node{ + (*ast.RangeStmt)(nil), + (*ast.ForStmt)(nil), + } + + // visit every `for` and `range` loops; when a loop is found, we + // stop traversing (return false) since the logic responsible for + // checking the body of loops is implemented in the `Visitor` struct + inspector.Nodes(loops, func(n ast.Node, _ bool) bool { + loop := NewLoop(n) + if loop.IsEmpty() { + return true + } + + v := NewVisitor(pass, loop) + for _, issue := range v.FindCaptures() { + pass.Report(issue) + } + return false + }) + + return nil, nil +} + +// Visitor implements the logic of checking for use of loop variables +// in Go routines either directly (referencing a loop variable in the +// function literal passed to `go`) or indirectly (calling a local +// function that captures loop variables by reference). +type Visitor struct { + loop *Loop + pass *analysis.Pass + + // closures keeps a map from local closures that capture loop + // variables by reference, to the captured loop variable. + closures map[*ast.Object]*ast.Ident + // issues accumulates issues found in a loop + issues []analysis.Diagnostic +} + +// NewVisitor creates a new Visitor instance for the given loop. +func NewVisitor(pass *analysis.Pass, loop *Loop) *Visitor { + return &Visitor{ + loop: loop, + pass: pass, + closures: map[*ast.Object]*ast.Ident{}, + } +} + +// FindCaptures returns a list of Diagnostic instances to be reported +// to the user +func (v *Visitor) FindCaptures() []analysis.Diagnostic { + ast.Inspect(v.loop.Body, v.visitLoopBody) + return v.issues +} + +// visitLoopBody ignores everything but `go` statements and +// assignments. +// +// When an assignment to a function literal is found, we check if the +// closure captures any of the loop variables; in case it does, the +// 'closures' map is updated. +// +// When a `go` statement is found, we look for function literals +// (closures) in either the function being called itself, or in +// parameters in the function call. +// +// In other words, both of the following scenarios are problematic and +// reported by this linter: +// +// 1: +// for k, v := range myMap { +// go func() { +// fmt.Printf("k = %v, v = %v\n", k, v) +// }() +// } +// +// 2: +// for k, v := range myMap { +// go doWork(func() { +// doMoreWork(k, v) +// }) +// } +// +// If a go routine calls a previously-defined closure that captures a +// loop variable, that is also reported. +func (v *Visitor) visitLoopBody(n ast.Node) bool { + switch stmt := n.(type) { + case *ast.GoStmt: + ast.Inspect(stmt.Call, v.funcLitInspector(v.addIssue)) + switch ident := stmt.Call.Fun.(type) { + case *ast.Ident: + if _, ok := v.closures[ident.Obj]; ok { + v.addIssue(ident) + } + } + + case *ast.AssignStmt: + for i, rhs := range stmt.Rhs { + lhs, ok := stmt.Lhs[i].(*ast.Ident) + if !ok || lhs.Obj == nil { + continue + } + + ast.Inspect(rhs, v.funcLitInspector(func(id *ast.Ident) { + v.closures[lhs.Obj] = id + })) + } + + default: + return true + } + + return false +} + +// funcLitInspector returns a function that can be passed to +// `ast.Inspect`. When a function literal that references a loop +// variable is found, the `onInvalidReference` function is called. +func (v *Visitor) funcLitInspector(onInvalidReference func(*ast.Ident)) func(ast.Node) bool { + return func(n ast.Node) bool { + funcLit, ok := n.(*ast.FuncLit) + if !ok { + return true + } + + ast.Inspect(funcLit.Body, v.findVariableCaptures(onInvalidReference)) + return false + } +} + +// findVariableCaptures returns a function that can be passed to +// `ast.Inspect`. When a reference to a loop variable is found, or +// when a function that is known to capture a loop variable by +// reference is called, the `onInvalidReference` function passed is +// called. +func (v *Visitor) findVariableCaptures(onInvalidReference func(*ast.Ident)) func(ast.Node) bool { + return func(n ast.Node) bool { + switch expr := n.(type) { + case *ast.Ident: + if expr.Obj == nil { + return true + } + + for _, loopVar := range v.loop.Vars { + if expr.Obj == loopVar.Obj { + onInvalidReference(expr) + return false + } + } + return false + + case *ast.CallExpr: + target, ok := expr.Fun.(*ast.Ident) + if !ok || target.Obj == nil { + return true + } + + if _, ok := v.closures[target.Obj]; ok { + onInvalidReference(target) + return false + } + + return true + } + + return true + } +} + +// addIssue adds a new issue in the `issues` field of the visitor +// associated with the identifier passed. The message is slightly +// different depending on whether the identifier is a loop variable +// directly, or invoking a closure that captures a loop variable by +// reference. In the latter case, the chain of calls that lead to the +// capture is included in the diagnostic. If a `//nolint` comment is +// associated with the use of this identifier, no issue is reported. +func (v *Visitor) addIssue(id *ast.Ident) { + if passesutil.HasNolintComment(v.pass, id, name) { + return + } + + var ( + chain = []*ast.Ident{id} + currentIdent = id + ok bool + ) + + for currentIdent, ok = v.closures[currentIdent.Obj]; ok; currentIdent, ok = v.closures[currentIdent.Obj] { + chain = append(chain, currentIdent) + } + + v.issues = append(v.issues, analysis.Diagnostic{ + Pos: id.Pos(), + Message: reportMessage(chain), + }) +} + +// reportMessage constructs the message to be reported to the user +// based on the chain of identifiers that lead to the loop variable +// being captured. The last identifier in the chain is always the loop +// variable being captured; everything else is the chain of closure +// calls that lead to the capture. +func reportMessage(chain []*ast.Ident) string { + if len(chain) == 1 { + return fmt.Sprintf("loop variable '%s' captured by reference, often leading to data races", chain[0].String()) + } + + functionName := chain[0] + loopVar := chain[len(chain)-1] + + var path []string + for i := 1; i < len(chain)-1; i++ { + path = append(path, fmt.Sprintf("'%s'", chain[i].String())) + } + + var pathMsg string + if len(path) > 0 { + pathMsg = fmt.Sprintf(" (via %s)", strings.Join(path, " -> ")) + } + + return fmt.Sprintf( + "'%s' function captures loop variable '%s'%s by reference, often leading to data races", + functionName.String(), + loopVar.String(), + pathMsg, + ) +} diff --git a/pkg/testutils/lint/passes/goroutinerefcapture/goroutinerefcapture_test.go b/pkg/testutils/lint/passes/goroutinerefcapture/goroutinerefcapture_test.go new file mode 100644 index 000000000000..e8e9d41e4bf4 --- /dev/null +++ b/pkg/testutils/lint/passes/goroutinerefcapture/goroutinerefcapture_test.go @@ -0,0 +1,34 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package goroutinerefcapture_test + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/build/bazel" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/goroutinerefcapture" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "golang.org/x/tools/go/analysis/analysistest" +) + +func init() { + if bazel.BuiltWithBazel() { + bazel.SetGoEnv() + } +} + +func TestAnalyzer(t *testing.T) { + skip.UnderStress(t) + testdata := testutils.TestDataPath(t) + analysistest.TestData = func() string { return testdata } + analysistest.Run(t, testdata, goroutinerefcapture.Analyzer, "p") +} diff --git a/pkg/testutils/lint/passes/goroutinerefcapture/loop.go b/pkg/testutils/lint/passes/goroutinerefcapture/loop.go new file mode 100644 index 000000000000..6caa0b4da162 --- /dev/null +++ b/pkg/testutils/lint/passes/goroutinerefcapture/loop.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package goroutinerefcapture + +import ( + "fmt" + "go/ast" +) + +// Loop abstracts away the type of loop (`for` loop with index +// variable vs `range` loops) +type Loop struct { + Vars []*ast.Ident + Body *ast.BlockStmt +} + +// NewLoop creates a new Loop struct according to the node passed. If +// the node does not represent either a `for` loop or a `range` loop, +// this function will panic. +func NewLoop(n ast.Node) *Loop { + switch node := n.(type) { + case *ast.ForStmt: + return newForLoop(node) + case *ast.RangeStmt: + return newRange(node) + default: + panic(fmt.Errorf("unexpected loop node: %#v", n)) + } +} + +// IsEmpty returns whether the loop is empty for the purposes of this +// linter; in other words, whether there no loop variables, or whether +// the loop has zero statements. +func (l *Loop) IsEmpty() bool { + return len(l.Vars) == 0 || len(l.Body.List) == 0 +} + +func newForLoop(stmt *ast.ForStmt) *Loop { + loop := Loop{Body: stmt.Body} + + switch post := stmt.Post.(type) { + case *ast.AssignStmt: + for _, lhs := range post.Lhs { + loop.addVar(lhs) + } + + case *ast.IncDecStmt: + loop.addVar(post.X) + } + + return &loop +} + +func newRange(stmt *ast.RangeStmt) *Loop { + loop := Loop{Body: stmt.Body} + loop.addVar(stmt.Key) + loop.addVar(stmt.Value) + + return &loop +} + +func (l *Loop) addVar(e ast.Expr) { + if ident, ok := e.(*ast.Ident); ok && ident.Obj != nil { + l.Vars = append(l.Vars, ident) + } +} diff --git a/pkg/testutils/lint/passes/goroutinerefcapture/testdata/src/p/p.go b/pkg/testutils/lint/passes/goroutinerefcapture/testdata/src/p/p.go new file mode 100644 index 000000000000..450b5de212ae --- /dev/null +++ b/pkg/testutils/lint/passes/goroutinerefcapture/testdata/src/p/p.go @@ -0,0 +1,216 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package p + +import ( + "fmt" + "net" + "sync" + "testing" +) + +var ( + intID = func(n int) int { return n } + doWork = func() {} + runFunc = func(f func()) { f() } + + collection = []int{1, 2, 3} +) + +func OutOfScope() { + var i int + for range collection { + i++ + go func() { + intID(i) // valid data race, but out of scope for this linter + }() + } +} + +// TableDriven ensures that we are able to flag a common pattern in +// table-driven tests. If a Go routine is spawned while iterating over +// the test cases, it's easy to accidentally reference the test case +// variable, leading to flaky tests if the test cases run in parallel. +func TableDriven(t *testing.T) { + values := [][]byte{{0x08, 0x00, 0x00, 0xff, 0xff}} + + for _, tc := range values { + t.Run("", func(t *testing.T) { + w, _ := net.Pipe() + errChan := make(chan error, 1) + + go func() { + if _, err := w.Write(tc); err != nil { // want `loop variable 'tc' captured by reference` + errChan <- err + return + } + }() + }) + } +} + +// Conditional ensures that even when a Go routine is created in more +// syntactically complex subtrees, it's still flagged if it captures a +// loop variable. In this case, the code is technically safe since the +// Go routine is only created in the last iteration of the loop, but +// it is believed that the variable should not be captured either way +// to avoid the chance of introducing bugs when this code is changed +// (it's also not possible to statically determine when using a loop +// variable inside a Go routine is safe, so we err on the side of +// caution). +func Conditional() { + for i, n := range collection { + intID(n) + if i == len(collection)-1 { + go func() { + fmt.Printf("i = %d\n", i) // want `loop variable 'i' captured by reference` + }() + } + } + + for j := 0; j < 10; j++ { + go func() { + doWork() + fmt.Printf("done: %d\n", j) // want `loop variable 'j' captured by reference` + }() + } +} + +// FuncLitArg ensures that function literals (closures) passed as +// argument to a function call in a 'go' statement should also be +// flagged if they capture a loop variable. +func FuncLitArg() { + for _, n := range collection { + doWork() + go runFunc(func() { + intID(n) // want `loop variable 'n' captured by reference` + }) + + go intID(n) // this is OK + doWork() + } + + for j := 0; j < len(collection); j++ { + doWork() + go runFunc(func() { + intID(collection[j]) // want `loop variable 'j' captured by reference` + }) + + go intID(collection[j]) // this is OK + } +} + +// Synchronization is another example of a technically safe use of a +// loop variable in a Go routine that we decide to flag anyway. +func Synchronization() { + for _, n := range collection { + var wg sync.WaitGroup + go func() { + defer wg.Done() + intID(n) // want `loop variable 'n' captured by reference` + }() + + wg.Wait() + } +} + +// IndirectClosure makes sure that closures that capture loop +// variables cannot be called in a Go routine. +func IndirectClosure() { + for i := range collection { + badClosure := func() { fmt.Printf("finished iteration %d\n", i+1) } + goodClosure := func(i int) { fmt.Printf("finished iteration %d\n", i+1) } + + wrapper1 := func() { badClosure() } + wrapper2 := func() { wrapper1() } + wrapper3 := func() { goodClosure(i) } + + iCopy := i + go func() { + defer badClosure() // want `'badClosure' function captures loop variable 'i' by reference` + doWork() + + // referencing a closure without invoking it is fine + if badClosure != nil { + wrapper1() // want `'wrapper1' function captures loop variable 'i' \(via 'badClosure'\)` + doWork() + wrapper2() // want `'wrapper2' function captures loop variable 'i' \(via 'wrapper1' -> 'badClosure'\)` + + wrapper3() // want `'wrapper3' function captures loop variable 'i' by reference` + + // copying here does not solve the problem + k := i // want `loop variable 'i' captured by reference` + goodClosure(k) // still problematic + + goodClosure(iCopy) // this is OK + } + }() + + go badClosure() // want `'badClosure' function captures loop variable 'i' by reference` + go wrapper2() // want `'wrapper2' function captures loop variable 'i' \(via 'wrapper1' -> 'badClosure'\)` + } + + for j := 0; j < len(collection); j++ { + showProgress := func() { + fmt.Printf("finished iteration %d\n", j+1) + } + + go func() { + doWork() + showProgress() // want `'showProgress' function captures loop variable 'j' by reference` + }() + } +} + +// FixedFunction tests that common patterns to fix loop variable +// capture by reference in Go routines work: namely, passing the loop +// variable as an argument to the function called asynchronously; or +// creating a scoped copy of the loop variable within the loop. +func FixedFunction() { + for _, n := range collection { + doWork() + go func(n int) { + intID(n) // this is OK + }(n) + } + + for j := 0; j < len(collection); j++ { + j := j + go func() { + intID(j) // this is OK + }() + } +} + +// RespectsNolintComments makes sure that developers are able to +// silence the linter using their own judgement. +func RespectsNolintComments() { + for _, n := range collection { + var wg sync.WaitGroup + wg.Add(1) + + badClosure := func() { fmt.Printf("n = %d\n", n) } + + go func() { + defer wg.Done() + //nolint:goroutinerefcapture + intID(n) + + //nolint:goroutinerefcapture + badClosure() + }() + + //nolint:goroutinerefcapture + go badClosure() + + wg.Wait() + } +}