-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
371 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
package checker | ||
|
||
import ( | ||
"go/ast" | ||
"go/token" | ||
"strconv" | ||
"strings" | ||
) | ||
|
||
// pkgPath -> funcName -> alternative | ||
type deprecatedFuncMap map[string]map[string]string | ||
|
||
// DeprecatedFunc represents a deprecated function. | ||
type DeprecatedFunc struct { | ||
Package string | ||
Function string | ||
Alternative string | ||
Position token.Position | ||
} | ||
|
||
// DeprecatedFuncChecker checks for deprecated functions. | ||
type DeprecatedFuncChecker struct { | ||
deprecatedFuncs deprecatedFuncMap | ||
} | ||
|
||
func NewDeprecatedFuncChecker() *DeprecatedFuncChecker { | ||
return &DeprecatedFuncChecker{ | ||
deprecatedFuncs: make(deprecatedFuncMap), | ||
} | ||
} | ||
|
||
func (d *DeprecatedFuncChecker) Register(pkgName, funcName, alternative string) { | ||
if _, ok := d.deprecatedFuncs[pkgName]; !ok { | ||
d.deprecatedFuncs[pkgName] = make(map[string]string) | ||
} | ||
d.deprecatedFuncs[pkgName][funcName] = alternative | ||
} | ||
|
||
// Check checks a AST node for deprecated functions. | ||
// | ||
// TODO: use this in the linter rule implementation | ||
func (d *DeprecatedFuncChecker) Check( | ||
filename string, | ||
node *ast.File, | ||
fset *token.FileSet, | ||
) ([]DeprecatedFunc, error) { | ||
var found []DeprecatedFunc | ||
|
||
packageAliases := make(map[string]string) | ||
for _, imp := range node.Imports { | ||
path, err := strconv.Unquote(imp.Path.Value) | ||
if err != nil { | ||
continue | ||
} | ||
name := "" | ||
if imp.Name != nil { | ||
name = imp.Name.Name | ||
} else { | ||
parts := strings.Split(path, "/") | ||
name = parts[len(parts)-1] | ||
} | ||
packageAliases[name] = path | ||
} | ||
|
||
ast.Inspect(node, func(n ast.Node) bool { | ||
call, ok := n.(*ast.CallExpr) | ||
if !ok { | ||
return true | ||
} | ||
|
||
switch fun := call.Fun.(type) { | ||
case *ast.SelectorExpr: | ||
ident, ok := fun.X.(*ast.Ident) | ||
if !ok { | ||
return true | ||
} | ||
pkgAlias := ident.Name | ||
funcName := fun.Sel.Name | ||
|
||
pkgPath, ok := packageAliases[pkgAlias] | ||
if !ok { | ||
// Not a package alias, possibly a method call | ||
return true | ||
} | ||
|
||
if funcs, ok := d.deprecatedFuncs[pkgPath]; ok { | ||
if alt, ok := funcs[funcName]; ok { | ||
found = append(found, DeprecatedFunc{ | ||
Package: pkgPath, | ||
Function: funcName, | ||
Alternative: alt, | ||
Position: fset.Position(call.Pos()), | ||
}) | ||
} | ||
} | ||
case *ast.Ident: | ||
// Handle functions imported via dot imports | ||
funcName := fun.Name | ||
// Check dot-imported packages | ||
for alias, pkgPath := range packageAliases { | ||
if alias != "." { | ||
continue | ||
} | ||
if funcs, ok := d.deprecatedFuncs[pkgPath]; ok { | ||
if alt, ok := funcs[funcName]; ok { | ||
found = append(found, DeprecatedFunc{ | ||
Package: pkgPath, | ||
Function: funcName, | ||
Alternative: alt, | ||
Position: fset.Position(call.Pos()), | ||
}) | ||
break | ||
} | ||
} | ||
} | ||
} | ||
return true | ||
}) | ||
|
||
return found, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
package checker | ||
|
||
import ( | ||
"go/parser" | ||
"go/token" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestRegisterDeprecatedFunctions(t *testing.T) { | ||
t.Parallel() | ||
checker := NewDeprecatedFuncChecker() | ||
|
||
checker.Register("fmt", "Println", "fmt.Print") | ||
checker.Register("os", "Remove", "os.RemoveAll") | ||
|
||
expected := deprecatedFuncMap{ | ||
"fmt": {"Println": "fmt.Print"}, | ||
"os": {"Remove": "os.RemoveAll"}, | ||
} | ||
|
||
assert.Equal(t, expected, checker.deprecatedFuncs) | ||
} | ||
|
||
func TestCheck(t *testing.T) { | ||
t.Parallel() | ||
src := ` | ||
package main | ||
import ( | ||
"fmt" | ||
"os" | ||
) | ||
func main() { | ||
fmt.Println("Hello, World!") | ||
os.Remove("some_file.txt") | ||
} | ||
` | ||
|
||
fset := token.NewFileSet() | ||
node, err := parser.ParseFile(fset, "example.go", src, 0) | ||
if err != nil { | ||
t.Fatalf("Failed to parse file: %v", err) | ||
} | ||
|
||
checker := NewDeprecatedFuncChecker() | ||
checker.Register("fmt", "Println", "fmt.Print") | ||
checker.Register("os", "Remove", "os.RemoveAll") | ||
|
||
deprecated, err := checker.Check("example.go", node, fset) | ||
if err != nil { | ||
t.Fatalf("Check failed with error: %v", err) | ||
} | ||
|
||
expected := []DeprecatedFunc{ | ||
{ | ||
Package: "fmt", | ||
Function: "Println", | ||
Alternative: "fmt.Print", | ||
Position: token.Position{ | ||
Filename: "example.go", | ||
Offset: 55, | ||
Line: 10, | ||
Column: 2, | ||
}, | ||
}, | ||
{ | ||
Package: "os", | ||
Function: "Remove", | ||
Alternative: "os.RemoveAll", | ||
Position: token.Position{ | ||
Filename: "example.go", | ||
Offset: 85, | ||
Line: 11, | ||
Column: 2, | ||
}, | ||
}, | ||
} | ||
|
||
assert.Equal(t, expected, deprecated) | ||
} | ||
|
||
func TestCheckNoDeprecated(t *testing.T) { | ||
t.Parallel() | ||
src := ` | ||
package main | ||
import "fmt" | ||
func main() { | ||
fmt.Printf("Hello, %s\n", "World") | ||
} | ||
` | ||
|
||
fset := token.NewFileSet() | ||
node, err := parser.ParseFile(fset, "example.go", src, 0) | ||
if err != nil { | ||
t.Fatalf("Failed to parse file: %v", err) | ||
} | ||
|
||
checker := NewDeprecatedFuncChecker() | ||
checker.Register("fmt", "Println", "fmt.Print") | ||
checker.Register("os", "Remove", "os.RemoveAll") | ||
|
||
deprecated, err := checker.Check("example.go", node, fset) | ||
if err != nil { | ||
t.Fatalf("Check failed with error: %v", err) | ||
} | ||
|
||
assert.Equal(t, 0, len(deprecated)) | ||
} | ||
|
||
func TestCheckMultipleDeprecatedCalls(t *testing.T) { | ||
t.Parallel() | ||
src := ` | ||
package main | ||
import ( | ||
"fmt" | ||
"os" | ||
) | ||
func main() { | ||
fmt.Println("Hello") | ||
fmt.Println("World") | ||
os.Remove("file1.txt") | ||
os.Remove("file2.txt") | ||
} | ||
` | ||
|
||
fset := token.NewFileSet() | ||
node, err := parser.ParseFile(fset, "example.go", src, 0) | ||
if err != nil { | ||
t.Fatalf("Failed to parse file: %v", err) | ||
} | ||
|
||
checker := NewDeprecatedFuncChecker() | ||
checker.Register("fmt", "Println", "fmt.Print") | ||
checker.Register("os", "Remove", "os.RemoveAll") | ||
|
||
deprecated, err := checker.Check("example.go", node, fset) | ||
if err != nil { | ||
t.Fatalf("Check failed with error: %v", err) | ||
} | ||
|
||
expected := []DeprecatedFunc{ | ||
{Package: "fmt", Function: "Println", Alternative: "fmt.Print"}, | ||
{Package: "fmt", Function: "Println", Alternative: "fmt.Print"}, | ||
{Package: "os", Function: "Remove", Alternative: "os.RemoveAll"}, | ||
{Package: "os", Function: "Remove", Alternative: "os.RemoveAll"}, | ||
} | ||
|
||
assert.Equal(t, len(expected), len(deprecated)) | ||
for i, exp := range expected { | ||
assertDeprecatedFuncEqual(t, exp, deprecated[i]) | ||
} | ||
} | ||
|
||
func TestDeprecatedFuncCheckerWithAlias(t *testing.T) { | ||
t.Parallel() | ||
|
||
c := NewDeprecatedFuncChecker() | ||
c.Register("math", "Sqrt", "math.Pow") | ||
|
||
const src = ` | ||
package main | ||
import ( | ||
m "math" | ||
"fmt" | ||
) | ||
type MyStruct struct{} | ||
func (s *MyStruct) Method() {} | ||
func main() { | ||
result := m.Sqrt(42) | ||
_ = result | ||
fmt.Println("Hello") | ||
s := &MyStruct{} | ||
s.Method() | ||
} | ||
` | ||
|
||
fset := token.NewFileSet() | ||
node, err := parser.ParseFile(fset, "sample.go", src, parser.ParseComments) | ||
assert.NoError(t, err) | ||
|
||
results, err := c.Check("sample.go", node, fset) | ||
assert.NoError(t, err) | ||
|
||
assert.Equal(t, 1, len(results)) | ||
|
||
expected := DeprecatedFunc{ | ||
Package: "math", | ||
Function: "Sqrt", | ||
Alternative: "math.Pow", | ||
} | ||
|
||
assertDeprecatedFuncEqual(t, expected, results[0]) | ||
} | ||
|
||
func TestDeprecatedFuncChecker_Check_DotImport(t *testing.T) { | ||
t.Parallel() | ||
|
||
checker := NewDeprecatedFuncChecker() | ||
checker.Register("fmt", "Println", "Use fmt.Print instead") | ||
|
||
src := ` | ||
package main | ||
import . "fmt" | ||
func main() { | ||
Println("Hello, World!") | ||
} | ||
` | ||
|
||
fset := token.NewFileSet() | ||
f, err := parser.ParseFile(fset, "test.go", src, 0) | ||
assert.NoError(t, err) | ||
|
||
found, err := checker.Check("test.go", f, fset) | ||
assert.NoError(t, err) | ||
|
||
assert.Equal(t, 1, len(found)) | ||
|
||
if len(found) > 0 { | ||
df := found[0] | ||
if df.Package != "fmt" || df.Function != "Println" || df.Alternative != "Use fmt.Print instead" { | ||
t.Errorf("unexpected deprecated function info: %+v", df) | ||
} | ||
} | ||
} | ||
|
||
func assertDeprecatedFuncEqual(t *testing.T, expected, actual DeprecatedFunc) { | ||
t.Helper() | ||
assert.Equal(t, expected.Package, actual.Package) | ||
assert.Equal(t, expected.Function, actual.Function) | ||
assert.Equal(t, expected.Alternative, actual.Alternative) | ||
assert.NotEmpty(t, actual.Position.Filename) | ||
assert.Greater(t, actual.Position.Offset, 0) | ||
assert.Greater(t, actual.Position.Line, 0) | ||
assert.Greater(t, actual.Position.Column, 0) | ||
} |