diff --git a/internal/source/source.go b/internal/source/source.go index 4dbc1bc..2686ded 100644 --- a/internal/source/source.go +++ b/internal/source/source.go @@ -10,12 +10,8 @@ import ( "go/token" "os" "runtime" - "strconv" - "strings" ) -const baseStackIndex = 1 - // FormattedCallExprArg returns the argument from an ast.CallExpr at the // index in the call stack. The argument is formatted using FormatNode. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { @@ -32,28 +28,26 @@ func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at // the index in the call stack. func CallExprArgs(stackIndex int) ([]ast.Expr, error) { - _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex) + _, filename, line, ok := runtime.Caller(stackIndex + 1) if !ok { return nil, errors.New("failed to get call stack") } - debug("call stack position: %s:%d", filename, lineNum) + debug("call stack position: %s:%d", filename, line) - node, err := getNodeAtLine(filename, lineNum) - if err != nil { - return nil, err - } - debug("found node: %s", debugFormatNode{node}) - - return getCallExprArgs(node) -} - -func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { fileset := token.NewFileSet() astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors) if err != nil { return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err) } + expr, err := getCallExprArgs(fileset, astFile, line) + if err != nil { + return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err) + } + return expr, nil +} + +func getNodeAtLine(fileset *token.FileSet, astFile *ast.File, lineNum int) (ast.Node, error) { if node := scanToLine(fileset, astFile, lineNum); node != nil { return node, nil } @@ -63,8 +57,7 @@ func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { return node, err } } - return nil, fmt.Errorf( - "failed to find an expression on line %d in %s", lineNum, filename) + return nil, nil } func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { @@ -73,7 +66,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { switch { case node == nil || matchedNode != nil: return false - case nodePosition(fileset, node).Line == lineNum: + case fileset.Position(node.Pos()).Line == lineNum: matchedNode = node return false } @@ -82,46 +75,17 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { return matchedNode } -// In golang 1.9 the line number changed from being the line where the statement -// ended to the line where the statement began. -func nodePosition(fileset *token.FileSet, node ast.Node) token.Position { - if goVersionBefore19 { - return fileset.Position(node.End()) - } - return fileset.Position(node.Pos()) -} - -// GoVersionLessThan returns true if runtime.Version() is semantically less than -// version major.minor. Returns false if a release version can not be parsed from -// runtime.Version(). -func GoVersionLessThan(major, minor int64) bool { - version := runtime.Version() - // not a release version - if !strings.HasPrefix(version, "go") { - return false - } - version = strings.TrimPrefix(version, "go") - parts := strings.Split(version, ".") - if len(parts) < 2 { - return false - } - rMajor, err := strconv.ParseInt(parts[0], 10, 32) - if err != nil { - return false - } - if rMajor != major { - return rMajor < major - } - rMinor, err := strconv.ParseInt(parts[1], 10, 32) - if err != nil { - return false +func getCallExprArgs(fileset *token.FileSet, astFile *ast.File, line int) ([]ast.Expr, error) { + node, err := getNodeAtLine(fileset, astFile, line) + switch { + case err != nil: + return nil, err + case node == nil: + return nil, fmt.Errorf("failed to find an expression") } - return rMinor < minor -} -var goVersionBefore19 = GoVersionLessThan(1, 9) + debug("found node: %s", debugFormatNode{node}) -func getCallExprArgs(node ast.Node) ([]ast.Expr, error) { visitor := &callExprVisitor{} ast.Walk(visitor, node) if visitor.expr == nil { @@ -172,6 +136,9 @@ type debugFormatNode struct { } func (n debugFormatNode) String() string { + if n.Node == nil { + return "none" + } out, err := FormatNode(n.Node) if err != nil { return fmt.Sprintf("failed to format %s: %s", n.Node, err) diff --git a/internal/source/version.go b/internal/source/version.go new file mode 100644 index 0000000..5fa8a90 --- /dev/null +++ b/internal/source/version.go @@ -0,0 +1,35 @@ +package source + +import ( + "runtime" + "strconv" + "strings" +) + +// GoVersionLessThan returns true if runtime.Version() is semantically less than +// version major.minor. Returns false if a release version can not be parsed from +// runtime.Version(). +func GoVersionLessThan(major, minor int64) bool { + version := runtime.Version() + // not a release version + if !strings.HasPrefix(version, "go") { + return false + } + version = strings.TrimPrefix(version, "go") + parts := strings.Split(version, ".") + if len(parts) < 2 { + return false + } + rMajor, err := strconv.ParseInt(parts[0], 10, 32) + if err != nil { + return false + } + if rMajor != major { + return rMajor < major + } + rMinor, err := strconv.ParseInt(parts[1], 10, 32) + if err != nil { + return false + } + return rMinor < minor +}