diff --git a/mockgen/gob.go b/mockgen/gob.go new file mode 100644 index 0000000..b5ab066 --- /dev/null +++ b/mockgen/gob.go @@ -0,0 +1,21 @@ +package main + +import ( + "encoding/gob" + "os" + + "go.uber.org/mock/mockgen/model" +) + +func gobMode(path string) (*model.Package, error) { + in, err := os.Open(path) + if err != nil { + return nil, err + } + defer in.Close() + var pkg model.Package + if err := gob.NewDecoder(in).Decode(&pkg); err != nil { + return nil, err + } + return &pkg, nil +} diff --git a/mockgen/gob_test.go b/mockgen/gob_test.go new file mode 100644 index 0000000..e1e5ec1 --- /dev/null +++ b/mockgen/gob_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "encoding/gob" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGobMode(t *testing.T) { + + // Encode a package to a temporary gob. + parser := packageModeParser{} + want, err := parser.parsePackage( + "go.uber.org/mock/mockgen/internal/tests/package_mode" /* package name */, + []string{ "Human", "Earth" } /* ifaces */, + ) + path := filepath.Join(t.TempDir(), "model.gob") + outfile, err := os.Create(path) + require.NoError(t, err) + require.NoError(t, gob.NewEncoder(outfile).Encode(want)) + outfile.Close() + + // Ensure gobMode loads it correctly. + got, err := gobMode(path) + require.NoError(t, err) + assert.Equal(t, want, got) +} diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 75c831a..b5365de 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -69,6 +69,7 @@ var ( imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") excludeInterfaces = flag.String("exclude_interfaces", "", "(source mode) Comma-separated names of interfaces to be excluded") + modelGob = flag.String("model_gob", "", "Skip package/source loading entirely and use the gob encoded model.Package at the given path") debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") showVersion = flag.Bool("version", false, "Print version.") @@ -88,7 +89,9 @@ func main() { var pkg *model.Package var err error var packageName string - if *source != "" { + if *modelGob != "" { + pkg, err = gobMode(*modelGob) + } else if *source != "" { pkg, err = sourceMode(*source) } else { if flag.NArg() != 2 {