From b411d9d1448a190854121b719544c33cf1f382d0 Mon Sep 17 00:00:00 2001 From: meteorgan Date: Wed, 17 Apr 2024 21:41:08 +0800 Subject: [PATCH] fix(type assert): add more check for type assert --- errcheck/errcheck.go | 37 ++++++++++++++++++++++++---- errcheck/testdata/src/assert/main.go | 17 +++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/errcheck/errcheck.go b/errcheck/errcheck.go index 437f8c2..a95a335 100644 --- a/errcheck/errcheck.go +++ b/errcheck/errcheck.go @@ -586,18 +586,29 @@ func (v *visitor) Visit(node ast.Node) ast.Visitor { for _, name := range vspec.Names { lhs = append(lhs, ast.Expr(name)) } - v.checkAssignment(lhs, vspec.Values) + followed := v.checkAssignment(lhs, vspec.Values) + if !followed { + return nil + } } case *ast.AssignStmt: - v.checkAssignment(stmt.Lhs, stmt.Rhs) + followed := v.checkAssignment(stmt.Lhs, stmt.Rhs) + if !followed { + return nil + } + + case *ast.TypeAssertExpr: + v.checkAssertExpr(stmt) + return nil default: } return v } -func (v *visitor) checkAssignment(lhs, rhs []ast.Expr) { +func (v *visitor) checkAssignment(lhs, rhs []ast.Expr) (followed bool) { + followed = true if len(rhs) == 1 { // single value on rhs; check against lhs identifiers if call, ok := rhs[0].(*ast.CallExpr); ok { @@ -619,11 +630,11 @@ func (v *visitor) checkAssignment(lhs, rhs []ast.Expr) { } } else if assert, ok := rhs[0].(*ast.TypeAssertExpr); ok { if !v.asserts { - return + return false } if assert.Type == nil { // type switch - return + return false } if len(lhs) < 2 { // assertion result not read @@ -632,6 +643,7 @@ func (v *visitor) checkAssignment(lhs, rhs []ast.Expr) { // assertion result ignored v.addErrorAtPosition(id.NamePos, nil) } + return false } } else { // multiple value on rhs; in this case a call can't return @@ -661,6 +673,21 @@ func (v *visitor) checkAssignment(lhs, rhs []ast.Expr) { } } } + + return +} + +func (v *visitor) checkAssertExpr(expr *ast.TypeAssertExpr) bool { + if !v.asserts { + return false + } + if expr.Type == nil { + // type switch + return false + } + v.addErrorAtPosition(expr.Pos(), nil) + + return false } func isErrorType(t types.Type) bool { diff --git a/errcheck/testdata/src/assert/main.go b/errcheck/testdata/src/assert/main.go index 77feb6d..7eacbf4 100644 --- a/errcheck/testdata/src/assert/main.go +++ b/errcheck/testdata/src/assert/main.go @@ -3,4 +3,21 @@ package assert func main() { var i interface{} _ = i.(string) // want "unchecked error" + + handleInterface(i.(string)) // want "unchecked error" + + if i.(string) == "hello" { // want "unchecked error" + // + } + + switch i.(type) { + case string: + case int: + _ = i.(int) // want "unchecked error" + case nil: + } +} + +func handleInterface(i interface{}) string { + return i.(string) // want "unchecked error" }