Skip to content

Commit

Permalink
deprecated function checker
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon committed Oct 7, 2024
1 parent ecce4c3 commit d46e471
Show file tree
Hide file tree
Showing 2 changed files with 371 additions and 0 deletions.
121 changes: 121 additions & 0 deletions internal/checker/deprecate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package checker

import (
"go/ast"
"go/token"
"strconv"
"strings"
)

// pkgPath -> funcName -> alternative
type deprecatedFuncMap map[string]map[string]string

// DeprecatedFunc represents a deprecated function.
type DeprecatedFunc struct {
Package string
Function string
Alternative string
Position token.Position
}

// DeprecatedFuncChecker checks for deprecated functions.
type DeprecatedFuncChecker struct {
deprecatedFuncs deprecatedFuncMap
}

func NewDeprecatedFuncChecker() *DeprecatedFuncChecker {
return &DeprecatedFuncChecker{
deprecatedFuncs: make(deprecatedFuncMap),
}
}

func (d *DeprecatedFuncChecker) Register(pkgName, funcName, alternative string) {
if _, ok := d.deprecatedFuncs[pkgName]; !ok {
d.deprecatedFuncs[pkgName] = make(map[string]string)
}
d.deprecatedFuncs[pkgName][funcName] = alternative
}

// Check checks a AST node for deprecated functions.
//
// TODO: use this in the linter rule implementation
func (d *DeprecatedFuncChecker) Check(
filename string,
node *ast.File,
fset *token.FileSet,
) ([]DeprecatedFunc, error) {
var found []DeprecatedFunc

packageAliases := make(map[string]string)
for _, imp := range node.Imports {
path, err := strconv.Unquote(imp.Path.Value)
if err != nil {
continue
}
name := ""
if imp.Name != nil {
name = imp.Name.Name
} else {
parts := strings.Split(path, "/")
name = parts[len(parts)-1]
}
packageAliases[name] = path
}

ast.Inspect(node, func(n ast.Node) bool {
call, ok := n.(*ast.CallExpr)
if !ok {
return true
}

switch fun := call.Fun.(type) {
case *ast.SelectorExpr:
ident, ok := fun.X.(*ast.Ident)
if !ok {
return true
}
pkgAlias := ident.Name
funcName := fun.Sel.Name

pkgPath, ok := packageAliases[pkgAlias]
if !ok {
// Not a package alias, possibly a method call
return true
}

if funcs, ok := d.deprecatedFuncs[pkgPath]; ok {
if alt, ok := funcs[funcName]; ok {
found = append(found, DeprecatedFunc{
Package: pkgPath,
Function: funcName,
Alternative: alt,
Position: fset.Position(call.Pos()),
})
}
}
case *ast.Ident:
// Handle functions imported via dot imports
funcName := fun.Name
// Check dot-imported packages
for alias, pkgPath := range packageAliases {
if alias != "." {
continue
}
if funcs, ok := d.deprecatedFuncs[pkgPath]; ok {
if alt, ok := funcs[funcName]; ok {
found = append(found, DeprecatedFunc{
Package: pkgPath,
Function: funcName,
Alternative: alt,
Position: fset.Position(call.Pos()),
})
break
}
}
}
}
return true
})

return found, nil
}
250 changes: 250 additions & 0 deletions internal/checker/deprecate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
package checker

import (
"go/parser"
"go/token"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRegisterDeprecatedFunctions(t *testing.T) {
t.Parallel()
checker := NewDeprecatedFuncChecker()

checker.Register("fmt", "Println", "fmt.Print")
checker.Register("os", "Remove", "os.RemoveAll")

expected := deprecatedFuncMap{
"fmt": {"Println": "fmt.Print"},
"os": {"Remove": "os.RemoveAll"},
}

assert.Equal(t, expected, checker.deprecatedFuncs)
}

func TestCheck(t *testing.T) {
t.Parallel()
src := `
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Hello, World!")
os.Remove("some_file.txt")
}
`

fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "example.go", src, 0)
if err != nil {
t.Fatalf("Failed to parse file: %v", err)
}

checker := NewDeprecatedFuncChecker()
checker.Register("fmt", "Println", "fmt.Print")
checker.Register("os", "Remove", "os.RemoveAll")

deprecated, err := checker.Check("example.go", node, fset)
if err != nil {
t.Fatalf("Check failed with error: %v", err)
}

expected := []DeprecatedFunc{
{
Package: "fmt",
Function: "Println",
Alternative: "fmt.Print",
Position: token.Position{
Filename: "example.go",
Offset: 55,
Line: 10,
Column: 2,
},
},
{
Package: "os",
Function: "Remove",
Alternative: "os.RemoveAll",
Position: token.Position{
Filename: "example.go",
Offset: 85,
Line: 11,
Column: 2,
},
},
}

assert.Equal(t, expected, deprecated)
}

func TestCheckNoDeprecated(t *testing.T) {
t.Parallel()
src := `
package main
import "fmt"
func main() {
fmt.Printf("Hello, %s\n", "World")
}
`

fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "example.go", src, 0)
if err != nil {
t.Fatalf("Failed to parse file: %v", err)
}

checker := NewDeprecatedFuncChecker()
checker.Register("fmt", "Println", "fmt.Print")
checker.Register("os", "Remove", "os.RemoveAll")

deprecated, err := checker.Check("example.go", node, fset)
if err != nil {
t.Fatalf("Check failed with error: %v", err)
}

assert.Equal(t, 0, len(deprecated))
}

func TestCheckMultipleDeprecatedCalls(t *testing.T) {
t.Parallel()
src := `
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Hello")
fmt.Println("World")
os.Remove("file1.txt")
os.Remove("file2.txt")
}
`

fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "example.go", src, 0)
if err != nil {
t.Fatalf("Failed to parse file: %v", err)
}

checker := NewDeprecatedFuncChecker()
checker.Register("fmt", "Println", "fmt.Print")
checker.Register("os", "Remove", "os.RemoveAll")

deprecated, err := checker.Check("example.go", node, fset)
if err != nil {
t.Fatalf("Check failed with error: %v", err)
}

expected := []DeprecatedFunc{
{Package: "fmt", Function: "Println", Alternative: "fmt.Print"},
{Package: "fmt", Function: "Println", Alternative: "fmt.Print"},
{Package: "os", Function: "Remove", Alternative: "os.RemoveAll"},
{Package: "os", Function: "Remove", Alternative: "os.RemoveAll"},
}

assert.Equal(t, len(expected), len(deprecated))
for i, exp := range expected {
assertDeprecatedFuncEqual(t, exp, deprecated[i])
}
}

func TestDeprecatedFuncCheckerWithAlias(t *testing.T) {
t.Parallel()

c := NewDeprecatedFuncChecker()
c.Register("math", "Sqrt", "math.Pow")

const src = `
package main
import (
m "math"
"fmt"
)
type MyStruct struct{}
func (s *MyStruct) Method() {}
func main() {
result := m.Sqrt(42)
_ = result
fmt.Println("Hello")
s := &MyStruct{}
s.Method()
}
`

fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "sample.go", src, parser.ParseComments)
assert.NoError(t, err)

results, err := c.Check("sample.go", node, fset)
assert.NoError(t, err)

assert.Equal(t, 1, len(results))

expected := DeprecatedFunc{
Package: "math",
Function: "Sqrt",
Alternative: "math.Pow",
}

assertDeprecatedFuncEqual(t, expected, results[0])
}

func TestDeprecatedFuncChecker_Check_DotImport(t *testing.T) {
t.Parallel()

checker := NewDeprecatedFuncChecker()
checker.Register("fmt", "Println", "Use fmt.Print instead")

src := `
package main
import . "fmt"
func main() {
Println("Hello, World!")
}
`

fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "test.go", src, 0)
assert.NoError(t, err)

found, err := checker.Check("test.go", f, fset)
assert.NoError(t, err)

assert.Equal(t, 1, len(found))

if len(found) > 0 {
df := found[0]
if df.Package != "fmt" || df.Function != "Println" || df.Alternative != "Use fmt.Print instead" {
t.Errorf("unexpected deprecated function info: %+v", df)
}
}
}

func assertDeprecatedFuncEqual(t *testing.T, expected, actual DeprecatedFunc) {
t.Helper()
assert.Equal(t, expected.Package, actual.Package)
assert.Equal(t, expected.Function, actual.Function)
assert.Equal(t, expected.Alternative, actual.Alternative)
assert.NotEmpty(t, actual.Position.Filename)
assert.Greater(t, actual.Position.Offset, 0)
assert.Greater(t, actual.Position.Line, 0)
assert.Greater(t, actual.Position.Column, 0)
}

0 comments on commit d46e471

Please sign in to comment.