Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix import handling #38

Merged
merged 1 commit into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ package typegen
import (
"fmt"
"io"
"math/big"
"reflect"
"strings"
"text/template"

"github.com/ipfs/go-cid"
)

const MaxLength = 8192

const ByteArrayMaxLen = 2 << 20

var (
cidType = reflect.TypeOf(cid.Cid{})
bigIntType = reflect.TypeOf(big.Int{})
deferredType = reflect.TypeOf(Deferred{})
)

func doTemplate(w io.Writer, info interface{}, templ string) error {
t := template.Must(template.New("").
Funcs(template.FuncMap{
Expand All @@ -28,19 +37,29 @@ func doTemplate(w io.Writer, info interface{}, templ string) error {
return t.Execute(w, info)
}

func PrintHeaderAndUtilityMethods(w io.Writer, pkg string) error {
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string, typeInfos []*GenTypeInfo) error {
var imports []Import
for _, gti := range typeInfos {
imports = append(imports, gti.Imports()...)
}

imports = append(imports, defaultImports...)
imports = dedupImports(imports)

data := struct {
Package string
}{pkg}
Imports []Import
}{pkg, imports}
return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT.

package {{ .Package }}

import (
"fmt"
"io"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"

{{ range .Imports }}{{ .Name }} "{{ .PkgPath }}"
{{ end }}
)


Expand Down Expand Up @@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string {
case reflect.Map:
return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem())
default:
name := t.String()
if t.PkgPath() == "github.com/whyrusleeping/cbor-gen" {
name = "cbg." + strings.TrimPrefix(name, "typegen.")
} else {
name = strings.TrimPrefix(name, pkg+".")
pkgPath := t.PkgPath()
if pkgPath == "" {
// It's a built-in.
return t.String()
} else if pkgPath == pkg {
return t.Name()
}
return name
return fmt.Sprintf("%s.%s", resolvePkgName(pkgPath, t.String()), t.Name())
}
}

Expand All @@ -100,6 +120,25 @@ type GenTypeInfo struct {
Fields []Field
}

func (gti *GenTypeInfo) Imports() []Import {
var imports []Import
for _, f := range gti.Fields {
switch f.Type.Kind() {
case reflect.Struct:
if !f.Pointer && f.Type != bigIntType {
continue
}
if f.Type == cidType {
continue
}
case reflect.Bool:
continue
}
imports = append(imports, ImportsForType(f.Pkg, f.Type)...)
}
return imports
}

func (gti *GenTypeInfo) NeedsScratch() bool {
for _, f := range gti.Fields {
switch f.Type.Kind() {
Expand All @@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool {
return true

case reflect.Struct:
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
return true
case "github.com/ipfs/go-cid.Cid":
if f.Type == bigIntType || f.Type == cidType {
return true
}
// nope
Expand All @@ -132,9 +167,11 @@ func nameIsExported(name string) bool {
return strings.ToUpper(name[0:1]) == name[0:1]
}

func ParseTypeInfo(pkg string, i interface{}) (*GenTypeInfo, error) {
func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) {
t := reflect.TypeOf(i)

pkg := t.PkgPath()

out := GenTypeInfo{
Name: t.Name(),
}
Expand Down Expand Up @@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error {
`)
}
func emitCborMarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
{
if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil {
Expand All @@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error {
}
`)

case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{{ if .Pointer }}
if {{ .Name }} == nil {
Expand Down Expand Up @@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
default:
return fmt.Errorf("do not yet support slices of %s yet", e.Kind())
case reflect.Struct:
fname := e.PkgPath() + "." + e.Name()
switch fname {
case "github.com/ipfs/go-cid.Cid":
switch e {
case cidType:
err := doTemplate(w, f, `
if err := cbg.WriteCidBuf(scratch, w, v); err != nil {
return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err)
Expand Down Expand Up @@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error {
}

func emitCborUnmarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()

switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
maj, extra, err = {{ ReadHeader "br" }}
if err != nil {
Expand Down Expand Up @@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ .Name }} = big.NewInt(0)
}
`)
case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
Expand All @@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ end }}
}
`)
case "github.com/whyrusleeping/cbor-gen.Deferred":
case deferredType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
Expand Down Expand Up @@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}

// Generates 'tuple representation' cbor encoders for the given type
func GenTupleEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}

func GenTupleEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructTuple(w, gti); err != nil {
return err
}
Expand Down Expand Up @@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}

// Generates 'tuple representation' cbor encoders for the given type
func GenMapEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}

func GenMapEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructMap(w, gti); err != nil {
return err
}
Expand Down
99 changes: 99 additions & 0 deletions package.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package typegen

import (
"fmt"
"reflect"
"sort"
"strings"
"sync"
)

var (
knownPackageNamesMu sync.Mutex
pkgNameToPkgPath = make(map[string]string)
pkgPathToPkgName = make(map[string]string)

defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
}
)

func init() {
for _, imp := range defaultImports {
if was, conflict := pkgNameToPkgPath[imp.Name]; conflict {
panic(fmt.Sprintf("reused pkg name %s for %s and %s", imp.Name, imp.PkgPath, was))
}
if _, conflict := pkgPathToPkgName[imp.Name]; conflict {
panic(fmt.Sprintf("duplicate default import %s", imp.PkgPath))
}
pkgNameToPkgPath[imp.Name] = imp.PkgPath
pkgPathToPkgName[imp.PkgPath] = imp.Name
}
}

func resolvePkgName(path, typeName string) string {
parts := strings.Split(typeName, ".")
if len(parts) != 2 {
panic(fmt.Sprintf("expected type to have a package name: %s", typeName))
}
defaultName := parts[0]

knownPackageNamesMu.Lock()
defer knownPackageNamesMu.Unlock()

// Check for a known name and use it.
if name, ok := pkgPathToPkgName[path]; ok {
return name
}

// Allocate a name.
for i := 0; ; i++ {
tryName := defaultName
if i > 0 {
tryName = fmt.Sprintf("%s%d", defaultName, i)
}
if _, taken := pkgNameToPkgPath[tryName]; !taken {
pkgNameToPkgPath[tryName] = path
pkgPathToPkgName[path] = tryName
return tryName
}
}

}

type Import struct {
Name, PkgPath string
}

func ImportsForType(currPkg string, t reflect.Type) []Import {
switch t.Kind() {
case reflect.Array, reflect.Slice, reflect.Ptr:
return ImportsForType(currPkg, t.Elem())
case reflect.Map:
return dedupImports(append(ImportsForType(currPkg, t.Key()), ImportsForType(currPkg, t.Elem())...))
default:
path := t.PkgPath()
if path == "" || path == currPkg {
// built-in or in current package.
return nil
}

return []Import{{PkgPath: path, Name: resolvePkgName(path, t.String())}}
}
}

func dedupImports(imps []Import) []Import {
impSet := make(map[string]string, len(imps))
for _, imp := range imps {
impSet[imp.PkgPath] = imp.Name
}
deduped := make([]Import, 0, len(imps))
for pkg, name := range impSet {
deduped = append(deduped, Import{Name: name, PkgPath: pkg})
}
sort.Slice(deduped, func(i, j int) bool {
return deduped[i].PkgPath < deduped[j].PkgPath
})
return deduped
}
Loading