Skip to content

Commit

Permalink
Merge pull request #38 from Stebalien/feat/imports
Browse files Browse the repository at this point in the history
Fix import handling
  • Loading branch information
whyrusleeping authored Aug 12, 2020
2 parents 4fed709 + f9912a4 commit 958ddff
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 56 deletions.
105 changes: 64 additions & 41 deletions gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ package typegen
import (
"fmt"
"io"
"math/big"
"reflect"
"strings"
"text/template"

"github.com/ipfs/go-cid"
)

const MaxLength = 8192

const ByteArrayMaxLen = 2 << 20

var (
cidType = reflect.TypeOf(cid.Cid{})
bigIntType = reflect.TypeOf(big.Int{})
deferredType = reflect.TypeOf(Deferred{})
)

func doTemplate(w io.Writer, info interface{}, templ string) error {
t := template.Must(template.New("").
Funcs(template.FuncMap{
Expand All @@ -28,19 +37,29 @@ func doTemplate(w io.Writer, info interface{}, templ string) error {
return t.Execute(w, info)
}

func PrintHeaderAndUtilityMethods(w io.Writer, pkg string) error {
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string, typeInfos []*GenTypeInfo) error {
var imports []Import
for _, gti := range typeInfos {
imports = append(imports, gti.Imports()...)
}

imports = append(imports, defaultImports...)
imports = dedupImports(imports)

data := struct {
Package string
}{pkg}
Imports []Import
}{pkg, imports}
return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT.
package {{ .Package }}
import (
"fmt"
"io"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
{{ range .Imports }}{{ .Name }} "{{ .PkgPath }}"
{{ end }}
)
Expand Down Expand Up @@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string {
case reflect.Map:
return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem())
default:
name := t.String()
if t.PkgPath() == "github.com/whyrusleeping/cbor-gen" {
name = "cbg." + strings.TrimPrefix(name, "typegen.")
} else {
name = strings.TrimPrefix(name, pkg+".")
pkgPath := t.PkgPath()
if pkgPath == "" {
// It's a built-in.
return t.String()
} else if pkgPath == pkg {
return t.Name()
}
return name
return fmt.Sprintf("%s.%s", resolvePkgName(pkgPath, t.String()), t.Name())
}
}

Expand All @@ -100,6 +120,25 @@ type GenTypeInfo struct {
Fields []Field
}

func (gti *GenTypeInfo) Imports() []Import {
var imports []Import
for _, f := range gti.Fields {
switch f.Type.Kind() {
case reflect.Struct:
if !f.Pointer && f.Type != bigIntType {
continue
}
if f.Type == cidType {
continue
}
case reflect.Bool:
continue
}
imports = append(imports, ImportsForType(f.Pkg, f.Type)...)
}
return imports
}

func (gti *GenTypeInfo) NeedsScratch() bool {
for _, f := range gti.Fields {
switch f.Type.Kind() {
Expand All @@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool {
return true

case reflect.Struct:
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
return true
case "github.com/ipfs/go-cid.Cid":
if f.Type == bigIntType || f.Type == cidType {
return true
}
// nope
Expand All @@ -132,9 +167,11 @@ func nameIsExported(name string) bool {
return strings.ToUpper(name[0:1]) == name[0:1]
}

func ParseTypeInfo(pkg string, i interface{}) (*GenTypeInfo, error) {
func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) {
t := reflect.TypeOf(i)

pkg := t.PkgPath()

out := GenTypeInfo{
Name: t.Name(),
}
Expand Down Expand Up @@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error {
`)
}
func emitCborMarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
{
if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil {
Expand All @@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error {
}
`)

case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{{ if .Pointer }}
if {{ .Name }} == nil {
Expand Down Expand Up @@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
default:
return fmt.Errorf("do not yet support slices of %s yet", e.Kind())
case reflect.Struct:
fname := e.PkgPath() + "." + e.Name()
switch fname {
case "github.com/ipfs/go-cid.Cid":
switch e {
case cidType:
err := doTemplate(w, f, `
if err := cbg.WriteCidBuf(scratch, w, v); err != nil {
return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err)
Expand Down Expand Up @@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error {
}

func emitCborUnmarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()

switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
maj, extra, err = {{ ReadHeader "br" }}
if err != nil {
Expand Down Expand Up @@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ .Name }} = big.NewInt(0)
}
`)
case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
Expand All @@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ end }}
}
`)
case "github.com/whyrusleeping/cbor-gen.Deferred":
case deferredType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
Expand Down Expand Up @@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}

// Generates 'tuple representation' cbor encoders for the given type
func GenTupleEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}

func GenTupleEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructTuple(w, gti); err != nil {
return err
}
Expand Down Expand Up @@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}

// Generates 'tuple representation' cbor encoders for the given type
func GenMapEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}

func GenMapEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructMap(w, gti); err != nil {
return err
}
Expand Down
99 changes: 99 additions & 0 deletions package.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package typegen

import (
"fmt"
"reflect"
"sort"
"strings"
"sync"
)

var (
knownPackageNamesMu sync.Mutex
pkgNameToPkgPath = make(map[string]string)
pkgPathToPkgName = make(map[string]string)

defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
}
)

func init() {
for _, imp := range defaultImports {
if was, conflict := pkgNameToPkgPath[imp.Name]; conflict {
panic(fmt.Sprintf("reused pkg name %s for %s and %s", imp.Name, imp.PkgPath, was))
}
if _, conflict := pkgPathToPkgName[imp.Name]; conflict {
panic(fmt.Sprintf("duplicate default import %s", imp.PkgPath))
}
pkgNameToPkgPath[imp.Name] = imp.PkgPath
pkgPathToPkgName[imp.PkgPath] = imp.Name
}
}

func resolvePkgName(path, typeName string) string {
parts := strings.Split(typeName, ".")
if len(parts) != 2 {
panic(fmt.Sprintf("expected type to have a package name: %s", typeName))
}
defaultName := parts[0]

knownPackageNamesMu.Lock()
defer knownPackageNamesMu.Unlock()

// Check for a known name and use it.
if name, ok := pkgPathToPkgName[path]; ok {
return name
}

// Allocate a name.
for i := 0; ; i++ {
tryName := defaultName
if i > 0 {
tryName = fmt.Sprintf("%s%d", defaultName, i)
}
if _, taken := pkgNameToPkgPath[tryName]; !taken {
pkgNameToPkgPath[tryName] = path
pkgPathToPkgName[path] = tryName
return tryName
}
}

}

type Import struct {
Name, PkgPath string
}

func ImportsForType(currPkg string, t reflect.Type) []Import {
switch t.Kind() {
case reflect.Array, reflect.Slice, reflect.Ptr:
return ImportsForType(currPkg, t.Elem())
case reflect.Map:
return dedupImports(append(ImportsForType(currPkg, t.Key()), ImportsForType(currPkg, t.Elem())...))
default:
path := t.PkgPath()
if path == "" || path == currPkg {
// built-in or in current package.
return nil
}

return []Import{{PkgPath: path, Name: resolvePkgName(path, t.String())}}
}
}

func dedupImports(imps []Import) []Import {
impSet := make(map[string]string, len(imps))
for _, imp := range imps {
impSet[imp.PkgPath] = imp.Name
}
deduped := make([]Import, 0, len(imps))
for pkg, name := range impSet {
deduped = append(deduped, Import{Name: name, PkgPath: pkg})
}
sort.Slice(deduped, func(i, j int) bool {
return deduped[i].PkgPath < deduped[j].PkgPath
})
return deduped
}
Loading

0 comments on commit 958ddff

Please sign in to comment.