diff --git a/mockgen/parse.go b/mockgen/parse.go index 808fd144..8f31a71d 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -17,6 +17,7 @@ package main // This file contains the model construction by parsing source files. import ( + "errors" "flag" "fmt" "go/ast" @@ -49,7 +50,7 @@ func sourceMode(source string) (*model.Package, error) { return nil, fmt.Errorf("failed getting source directory: %v", err) } - packageImport, err := parsePackageImport(source, srcDir) + packageImport, err := parsePackageImport(srcDir) if err != nil { return nil, err } @@ -539,28 +540,29 @@ func packageNameOfDir(srcDir string) (string, error) { return "", fmt.Errorf("go source file not found %s", srcDir) } - packageImport, err := parsePackageImport(goFilePath, srcDir) + packageImport, err := parsePackageImport(srcDir) if err != nil { return "", err } return packageImport, nil } +var ErrOutsideGoPath = errors.New("Source directory is outside GOPATH") + // parseImportPackage get package import path via source file // an alternative implementation is to use: // cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} // pkgs, err := packages.Load(cfg, "file="+source) // However, it will call "go list" and slow down the performance -func parsePackageImport(source, srcDir string) (string, error) { +func parsePackageImport(srcDir string) (string, error) { moduleMode := os.Getenv("GO111MODULE") // trying to find the module if moduleMode != "off" { currentDir := srcDir - var modulePath string for { dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) if os.IsNotExist(err) { - if currentDir == "." || currentDir == "/" { + if currentDir == "/" { break } currentDir = filepath.Dir(currentDir) @@ -568,23 +570,18 @@ func parsePackageImport(source, srcDir string) (string, error) { } else if err != nil { return "", err } - modulePath = modfile.ModulePath(dat) - break - } - if modulePath != "" { + modulePath := modfile.ModulePath(dat) return filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir)), nil } } - if moduleMode != "on" { - goPath := os.Getenv("GOPATH") - if goPath == "" { - return "", fmt.Errorf("GOPATH is not set") - } - sourceRoot := filepath.Join(goPath, "src") + "/" - if !strings.HasPrefix(srcDir, sourceRoot) { - return "", fmt.Errorf("%s is outside GOPATH %s", srcDir, goPath) - } - return strings.TrimPrefix(srcDir, sourceRoot), nil + // fall back to GOPATH mode + goPath := os.Getenv("GOPATH") + if goPath == "" { + return "", fmt.Errorf("GOPATH is not set") + } + sourceRoot := filepath.Join(goPath, "src") + "/" + if !strings.HasPrefix(srcDir, sourceRoot) { + return "", ErrOutsideGoPath } - return "", fmt.Errorf("cannot find package path for %s", srcDir) + return strings.TrimPrefix(srcDir, sourceRoot), nil } diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index dab22fa5..08b51abc 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -4,6 +4,9 @@ import ( "go/ast" "go/parser" "go/token" + "io/ioutil" + "os" + "path/filepath" "testing" ) @@ -113,3 +116,69 @@ func Benchmark_parseFile(b *testing.B) { sourceMode(source) } } + +func TestParsePackageImport(t *testing.T) { + for _, testCase := range []struct { + name string + envs map[string]string + dir string + pkgPath string + err error + }{ + { + name: "go mod default", + envs: map[string]string{"GO111MODULE": ""}, + dir: "testdata/gomod/bar", + pkgPath: "github.com/golang/foo/bar", + }, + { + name: "go mod off", + envs: map[string]string{"GO111MODULE": "off", "GOPATH": "testdata/gopath"}, + dir: "testdata/gopath/src/example.com/foo", + pkgPath: "example.com/foo", + }, + { + name: "outside GOPATH", + envs: map[string]string{"GO111MODULE": "off", "GOPATH": "testdata/gopath"}, + dir: "testdata", + err: ErrOutsideGoPath, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + for key, value := range testCase.envs { + os.Setenv(key, value) + } + pkgPath, err := parsePackageImport(testCase.dir) + if err != testCase.err { + t.Errorf("expect %v, got %v", testCase.err, err) + } + if pkgPath != testCase.pkgPath { + t.Errorf("expect %s, got %s", testCase.pkgPath, pkgPath) + } + }) + } +} + +func TestParsePackageImport_FallbackGoPath(t *testing.T) { + goPath, err := ioutil.TempDir("", "gopath") + if err != nil { + t.Error(err) + } + defer func() { + if err = os.RemoveAll(goPath); err != nil { + t.Error(err) + } + }() + srcDir := filepath.Join(goPath, "src/example.com/foo") + err = os.MkdirAll(srcDir, 0755) + if err != nil { + t.Error(err) + } + os.Setenv("GOPATH", goPath) + os.Setenv("GO111MODULE", "on") + pkgPath, err := parsePackageImport(srcDir) + expected := "example.com/foo" + if pkgPath != expected { + t.Errorf("expect %s, got %s", expected, pkgPath) + } +} diff --git a/mockgen/testdata/gomod/bar/bar.go b/mockgen/testdata/gomod/bar/bar.go new file mode 100644 index 00000000..ddac0faf --- /dev/null +++ b/mockgen/testdata/gomod/bar/bar.go @@ -0,0 +1 @@ +package bar diff --git a/mockgen/testdata/gomod/go.mod b/mockgen/testdata/gomod/go.mod new file mode 100644 index 00000000..d80591b4 --- /dev/null +++ b/mockgen/testdata/gomod/go.mod @@ -0,0 +1 @@ +module github.com/golang/foo