Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export utils for import tracking and refs #247

Merged
merged 4 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 25 additions & 33 deletions examples/defaulter-gen/generators/defaulter.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,21 +501,27 @@ func mustEnforceDefault(t *types.Type, depth int, omitEmpty bool) (interface{},
var refRE = regexp.MustCompile(`^ref\((?P<reference>[^"]+)\)$`)
var refREIdentIndex = refRE.SubexpIndex("reference")

// parseAsRef looks for strings that match one of the following:
// ParseSymbolReference looks for strings that match one of the following:
alexzielenski marked this conversation as resolved.
Show resolved Hide resolved
// - ref(Ident)
// - ref(pkgpath.Ident)
// If the input string matches either of these, it will return the (optional)
// pkgpath, the Ident, and true. Otherwise it will return empty strings and
// false.
func parseAsRef(s string) (string, bool) {
func ParseSymbolReference(s, sourcePackage string) (types.Name, bool) {
matches := refRE.FindStringSubmatch(s)
if len(matches) < refREIdentIndex || matches[refREIdentIndex] == "" {
return "", false
return types.Name{}, false
}
return matches[refREIdentIndex], true

contents := matches[refREIdentIndex]
name := types.ParseFullyQualifiedName(contents)
if len(name.Package) == 0 {
name.Package = sourcePackage
}
return name, true
}

func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLines []string) *callNode {
func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLines []string, commentPackage string) *callNode {
defaultMap := extractDefaultTag(commentLines)
var defaultString string
if len(defaultMap) == 1 {
Expand All @@ -531,9 +537,9 @@ func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLin
} else if len(defaultMap) == 0 {
return node
}
var symbolReference string
var symbolReference types.Name
var defaultValue interface{}
if id, ok := parseAsRef(defaultString); ok {
if id, ok := ParseSymbolReference(defaultString, commentPackage); ok {
symbolReference = id
defaultString = ""
} else if err := json.Unmarshal([]byte(defaultString), &defaultValue); err != nil {
Expand Down Expand Up @@ -642,15 +648,15 @@ func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
child.elem = true
}
parent.children = append(parent.children, *child)
} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil {
} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines, t.Elem.Name.Package); member != nil {
member.index = true
parent.children = append(parent.children, *member)
}
case types.Map:
if child := c.build(t.Elem, false); child != nil {
child.key = true
parent.children = append(parent.children, *child)
} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil {
} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines, t.Elem.Name.Package); member != nil {
member.key = true
parent.children = append(parent.children, *member)
}
Expand All @@ -667,9 +673,9 @@ func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
}
if child := c.build(field.Type, false); child != nil {
child.field = name
populateDefaultValue(child, field.Type, field.Tags, field.CommentLines)
populateDefaultValue(child, field.Type, field.Tags, field.CommentLines, field.Type.Name.Package)
parent.children = append(parent.children, *child)
} else if member := populateDefaultValue(nil, field.Type, field.Tags, field.CommentLines); member != nil {
} else if member := populateDefaultValue(nil, field.Type, field.Tags, field.CommentLines, t.Name.Package); member != nil {
member.field = name
parent.children = append(parent.children, *member)
}
Expand All @@ -691,11 +697,6 @@ const (
conversionPackagePath = "k8s.io/apimachinery/pkg/conversion"
)

type symbolTracker interface {
namer.ImportTracker
AddSymbol(types.Name)
}

// genDefaulter produces a file with a autogenerated conversions.
type genDefaulter struct {
generator.DefaultGen
Expand All @@ -704,7 +705,7 @@ type genDefaulter struct {
peerPackages []string
newDefaulters defaulterFuncMap
existingDefaulters defaulterFuncMap
imports symbolTracker
imports namer.ImportTracker
typesForInit []*types.Type
}

Expand All @@ -718,7 +719,7 @@ func NewGenDefaulter(sanitizedName, typesPackage, outputPackage string, existing
peerPackages: peerPkgs,
newDefaulters: newDefaulters,
existingDefaulters: existingDefaulters,
imports: generator.NewImportTracker(),
imports: generator.NewImportTrackerForPackage(outputPackage),
typesForInit: make([]*types.Type, 0),
}
}
Expand Down Expand Up @@ -794,22 +795,13 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr
}
i := 0
callTree.VisitInOrder(func(ancestors []*callNode, current *callNode) {
if len(current.defaultValue.SymbolReference) > 0 {
// If the defaultValue was a reference to a symbol instead of a constant,
// make sure to add it to imports and resolve the name of the symbol
// before generating the defaults.
parsedName := types.ParseFullyQualifiedName(current.defaultValue.SymbolReference)
g.imports.AddSymbol(parsedName)
if ref := &current.defaultValue.SymbolReference; len(ref.Name) > 0 {
// Ensure package for symbol is imported in output generation
g.imports.AddSymbol(*ref)

// Rewrite the fully qualified name using the local package name
// from the imports
localPackage := g.imports.LocalNameOf(parsedName.Package)
if len(localPackage) > 0 {
current.defaultValue.SymbolReference = localPackage + "." + parsedName.Name
} else {
current.defaultValue.SymbolReference = parsedName.Name
}

ref.Package = g.imports.LocalNameOf(ref.Package)
}

if len(current.call) == 0 {
Expand Down Expand Up @@ -915,7 +907,7 @@ type defaultValue struct {
// The name of the symbol relative to the parsed package path
// i.e. k8s.io/pkg.apis.v1.Foo if from another package or simply `Foo`
// if within the same package.
SymbolReference string
SymbolReference types.Name
}

func (d defaultValue) IsEmpty() bool {
Expand All @@ -927,7 +919,7 @@ func (d defaultValue) Resolved() string {
if len(d.InlineConstant) > 0 {
return d.InlineConstant
}
return d.SymbolReference
return d.SymbolReference.String()
}

// CallNodeVisitorFunc is a function for visiting a call tree. ancestors is the list of all parents
Expand Down
23 changes: 21 additions & 2 deletions generator/import_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,34 @@ import (
"k8s.io/gengo/types"
)

func NewImportTracker(typesToAdd ...*types.Type) *namer.DefaultImportTracker {
tracker := namer.NewDefaultImportTracker(types.Name{})
// NewImportTrackerForPackage creates a new import tracker which is aware
// of a generator's output package. The tracker will not add import lines
// when symbols or types are added from the same package, and LocalNameOf
// will return empty string for the output package.
//
// e.g.:
//
// tracker := NewImportTrackerForPackage("bar.com/pkg/foo")
// tracker.AddSymbol(types.Name{"bar.com/pkg/foo.MyType"})
// tracker.AddSymbol(types.Name{"bar.com/pkg/baz.MyType"})
// tracker.AddSymbol(types.Name{"bar.com/pkg/baz/baz.MyType"})
//
// tracker.LocalNameOf("bar.com/pkg/foo") -> ""
// tracker.LocalNameOf("bar.com/pkg/baz") -> "baz"
// tracker.LocalNameOf("bar.com/pkg/baz/baz") -> "bazbaz"
// tracker.ImportLines() -> {`baz "bar.com/pkg/baz"`, `bazbaz "bar.com/pkg/baz/baz"`}
func NewImportTrackerForPackage(local string, typesToAdd ...*types.Type) *namer.DefaultImportTracker {
alexzielenski marked this conversation as resolved.
Show resolved Hide resolved
tracker := namer.NewDefaultImportTracker(types.Name{Package: local})
tracker.IsInvalidType = func(*types.Type) bool { return false }
tracker.LocalName = func(name types.Name) string { return golangTrackerLocalName(&tracker, name) }
tracker.PrintImport = func(path, name string) string { return name + " \"" + path + "\"" }

tracker.AddTypes(typesToAdd...)
return &tracker
}

func NewImportTracker(typesToAdd ...*types.Type) *namer.DefaultImportTracker {
return NewImportTrackerForPackage("", typesToAdd...)
}

func golangTrackerLocalName(tracker namer.ImportTracker, t types.Name) string {
Expand Down
19 changes: 18 additions & 1 deletion generator/import_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
func TestNewImportTracker(t *testing.T) {
tests := []struct {
name string
localPackage string
inputTypes []*types.Type
expectedImports []string
}{
Expand Down Expand Up @@ -63,10 +64,26 @@ func TestNewImportTracker(t *testing.T) {
`_struct "my/reserved/pkg/struct"`,
},
},
{
name: "local-symbol",
localPackage: "bar.com/my/pkg",
inputTypes: []*types.Type{
{Name: types.Name{Package: "bar.com/my/pkg"}},
{Name: types.Name{Package: "bar.com/external/pkg"}},
},
expectedImports: []string{
`pkg "bar.com/external/pkg"`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualImports := NewImportTracker(tt.inputTypes...).ImportLines()
var actualImports []string
if len(tt.localPackage) == 0 {
actualImports = NewImportTracker(tt.inputTypes...).ImportLines()
} else {
actualImports = NewImportTrackerForPackage(tt.localPackage, tt.inputTypes...).ImportLines()
}
if !reflect.DeepEqual(actualImports, tt.expectedImports) {
t.Errorf("ImportLines(%v) = %v, want %v", tt.inputTypes, actualImports, tt.expectedImports)
}
Expand Down
1 change: 1 addition & 0 deletions namer/namer.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ func (ns *NameStrategy) Name(t *types.Type) string {
// import. You can implement yourself or use the one in the generation package.
type ImportTracker interface {
AddType(*types.Type)
AddSymbol(types.Name)
LocalNameOf(packagePath string) string
PathOf(localName string) (string, bool)
ImportLines() []string
Expand Down