diff --git a/gen.go b/gen.go index 8cb50c7..ed2bf24 100644 --- a/gen.go +++ b/gen.go @@ -3,15 +3,24 @@ package typegen import ( "fmt" "io" + "math/big" "reflect" "strings" "text/template" + + "github.com/ipfs/go-cid" ) const MaxLength = 8192 const ByteArrayMaxLen = 2 << 20 +var ( + cidType = reflect.TypeOf(cid.Cid{}) + bigIntType = reflect.TypeOf(big.Int{}) + deferredType = reflect.TypeOf(Deferred{}) +) + func doTemplate(w io.Writer, info interface{}, templ string) error { t := template.Must(template.New(""). Funcs(template.FuncMap{ @@ -28,10 +37,19 @@ func doTemplate(w io.Writer, info interface{}, templ string) error { return t.Execute(w, info) } -func PrintHeaderAndUtilityMethods(w io.Writer, pkg string) error { +func PrintHeaderAndUtilityMethods(w io.Writer, pkg string, typeInfos []*GenTypeInfo) error { + var imports []Import + for _, gti := range typeInfos { + imports = append(imports, gti.Imports()...) + } + + imports = append(imports, defaultImports...) + imports = dedupImports(imports) + data := struct { Package string - }{pkg} + Imports []Import + }{pkg, imports} return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. package {{ .Package }} @@ -39,8 +57,9 @@ package {{ .Package }} import ( "fmt" "io" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" + +{{ range .Imports }}{{ .Name }} "{{ .PkgPath }}" +{{ end }} ) @@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string { case reflect.Map: return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem()) default: - name := t.String() - if t.PkgPath() == "github.com/whyrusleeping/cbor-gen" { - name = "cbg." + strings.TrimPrefix(name, "typegen.") - } else { - name = strings.TrimPrefix(name, pkg+".") + pkgPath := t.PkgPath() + if pkgPath == "" { + // It's a built-in. + return t.String() + } else if pkgPath == pkg { + return t.Name() } - return name + return fmt.Sprintf("%s.%s", resolvePkgName(pkgPath, t.String()), t.Name()) } } @@ -100,6 +120,25 @@ type GenTypeInfo struct { Fields []Field } +func (gti *GenTypeInfo) Imports() []Import { + var imports []Import + for _, f := range gti.Fields { + switch f.Type.Kind() { + case reflect.Struct: + if !f.Pointer && f.Type != bigIntType { + continue + } + if f.Type == cidType { + continue + } + case reflect.Bool: + continue + } + imports = append(imports, ImportsForType(f.Pkg, f.Type)...) + } + return imports +} + func (gti *GenTypeInfo) NeedsScratch() bool { for _, f := range gti.Fields { switch f.Type.Kind() { @@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool { return true case reflect.Struct: - fname := f.Type.PkgPath() + "." + f.Type.Name() - switch fname { - case "math/big.Int": - return true - case "github.com/ipfs/go-cid.Cid": + if f.Type == bigIntType || f.Type == cidType { return true } // nope @@ -132,9 +167,11 @@ func nameIsExported(name string) bool { return strings.ToUpper(name[0:1]) == name[0:1] } -func ParseTypeInfo(pkg string, i interface{}) (*GenTypeInfo, error) { +func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) { t := reflect.TypeOf(i) + pkg := t.PkgPath() + out := GenTypeInfo{ Name: t.Name(), } @@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error { `) } func emitCborMarshalStructField(w io.Writer, f Field) error { - fname := f.Type.PkgPath() + "." + f.Type.Name() - switch fname { - case "math/big.Int": + switch f.Type { + case bigIntType: return doTemplate(w, f, ` { if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil { @@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error { } `) - case "github.com/ipfs/go-cid.Cid": + case cidType: return doTemplate(w, f, ` {{ if .Pointer }} if {{ .Name }} == nil { @@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { default: return fmt.Errorf("do not yet support slices of %s yet", e.Kind()) case reflect.Struct: - fname := e.PkgPath() + "." + e.Name() - switch fname { - case "github.com/ipfs/go-cid.Cid": + switch e { + case cidType: err := doTemplate(w, f, ` if err := cbg.WriteCidBuf(scratch, w, v); err != nil { return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err) @@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error { } func emitCborUnmarshalStructField(w io.Writer, f Field) error { - fname := f.Type.PkgPath() + "." + f.Type.Name() - - switch fname { - case "math/big.Int": + switch f.Type { + case bigIntType: return doTemplate(w, f, ` maj, extra, err = {{ ReadHeader "br" }} if err != nil { @@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { {{ .Name }} = big.NewInt(0) } `) - case "github.com/ipfs/go-cid.Cid": + case cidType: return doTemplate(w, f, ` { {{ if .Pointer }} @@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { {{ end }} } `) - case "github.com/whyrusleeping/cbor-gen.Deferred": + case deferredType: return doTemplate(w, f, ` { {{ if .Pointer }} @@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { } // Generates 'tuple representation' cbor encoders for the given type -func GenTupleEncodersForType(inpkg string, i interface{}, w io.Writer) error { - gti, err := ParseTypeInfo(inpkg, i) - if err != nil { - return err - } - +func GenTupleEncodersForType(gti *GenTypeInfo, w io.Writer) error { if err := emitCborMarshalStructTuple(w, gti); err != nil { return err } @@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { } // Generates 'tuple representation' cbor encoders for the given type -func GenMapEncodersForType(inpkg string, i interface{}, w io.Writer) error { - gti, err := ParseTypeInfo(inpkg, i) - if err != nil { - return err - } - +func GenMapEncodersForType(gti *GenTypeInfo, w io.Writer) error { if err := emitCborMarshalStructMap(w, gti); err != nil { return err } diff --git a/package.go b/package.go new file mode 100644 index 0000000..1943e09 --- /dev/null +++ b/package.go @@ -0,0 +1,99 @@ +package typegen + +import ( + "fmt" + "reflect" + "sort" + "strings" + "sync" +) + +var ( + knownPackageNamesMu sync.Mutex + pkgNameToPkgPath = make(map[string]string) + pkgPathToPkgName = make(map[string]string) + + defaultImports = []Import{ + {Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"}, + {Name: "xerrors", PkgPath: "golang.org/x/xerrors"}, + } +) + +func init() { + for _, imp := range defaultImports { + if was, conflict := pkgNameToPkgPath[imp.Name]; conflict { + panic(fmt.Sprintf("reused pkg name %s for %s and %s", imp.Name, imp.PkgPath, was)) + } + if _, conflict := pkgPathToPkgName[imp.Name]; conflict { + panic(fmt.Sprintf("duplicate default import %s", imp.PkgPath)) + } + pkgNameToPkgPath[imp.Name] = imp.PkgPath + pkgPathToPkgName[imp.PkgPath] = imp.Name + } +} + +func resolvePkgName(path, typeName string) string { + parts := strings.Split(typeName, ".") + if len(parts) != 2 { + panic(fmt.Sprintf("expected type to have a package name: %s", typeName)) + } + defaultName := parts[0] + + knownPackageNamesMu.Lock() + defer knownPackageNamesMu.Unlock() + + // Check for a known name and use it. + if name, ok := pkgPathToPkgName[path]; ok { + return name + } + + // Allocate a name. + for i := 0; ; i++ { + tryName := defaultName + if i > 0 { + tryName = fmt.Sprintf("%s%d", defaultName, i) + } + if _, taken := pkgNameToPkgPath[tryName]; !taken { + pkgNameToPkgPath[tryName] = path + pkgPathToPkgName[path] = tryName + return tryName + } + } + +} + +type Import struct { + Name, PkgPath string +} + +func ImportsForType(currPkg string, t reflect.Type) []Import { + switch t.Kind() { + case reflect.Array, reflect.Slice, reflect.Ptr: + return ImportsForType(currPkg, t.Elem()) + case reflect.Map: + return dedupImports(append(ImportsForType(currPkg, t.Key()), ImportsForType(currPkg, t.Elem())...)) + default: + path := t.PkgPath() + if path == "" || path == currPkg { + // built-in or in current package. + return nil + } + + return []Import{{PkgPath: path, Name: resolvePkgName(path, t.String())}} + } +} + +func dedupImports(imps []Import) []Import { + impSet := make(map[string]string, len(imps)) + for _, imp := range imps { + impSet[imp.PkgPath] = imp.Name + } + deduped := make([]Import, 0, len(imps)) + for pkg, name := range impSet { + deduped = append(deduped, Import{Name: name, PkgPath: pkg}) + } + sort.Slice(deduped, func(i, j int) bool { + return deduped[i].PkgPath < deduped[j].PkgPath + }) + return deduped +} diff --git a/writefile.go b/writefile.go index 18de6e6..6c9a63b 100644 --- a/writefile.go +++ b/writefile.go @@ -4,7 +4,6 @@ import ( "bytes" "go/format" "os" - "os/exec" "golang.org/x/xerrors" ) @@ -12,12 +11,21 @@ import ( func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error { buf := new(bytes.Buffer) - if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil { + typeInfos := make([]*GenTypeInfo, len(types)) + for i, t := range types { + gti, err := ParseTypeInfo(t) + if err != nil { + return xerrors.Errorf("failed to parse type info: %w", err) + } + typeInfos[i] = gti + } + + if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil { return xerrors.Errorf("failed to write header: %w", err) } - for _, t := range types { - if err := GenTupleEncodersForType(pkg, t, buf); err != nil { + for _, t := range typeInfos { + if err := GenTupleEncodersForType(t, buf); err != nil { return xerrors.Errorf("failed to generate encoders: %w", err) } } @@ -39,22 +47,27 @@ func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error { } _ = fi.Close() - if err := exec.Command("goimports", "-w", fname).Run(); err != nil { - return err - } - return nil } func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error { buf := new(bytes.Buffer) - if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil { + typeInfos := make([]*GenTypeInfo, len(types)) + for i, t := range types { + gti, err := ParseTypeInfo(t) + if err != nil { + return xerrors.Errorf("failed to parse type info: %w", err) + } + typeInfos[i] = gti + } + + if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil { return xerrors.Errorf("failed to write header: %w", err) } - for _, t := range types { - if err := GenMapEncodersForType(pkg, t, buf); err != nil { + for _, t := range typeInfos { + if err := GenMapEncodersForType(t, buf); err != nil { return xerrors.Errorf("failed to generate encoders: %w", err) } } @@ -76,9 +89,5 @@ func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error { } _ = fi.Close() - if err := exec.Command("goimports", "-w", fname).Run(); err != nil { - return err - } - return nil }