From 8be329be78227cae2a1207cc3263648e37d58649 Mon Sep 17 00:00:00 2001 From: Alex Zielenski Date: Fri, 18 Aug 2023 13:07:57 -0700 Subject: [PATCH 1/4] expose ParseSymbolReference --- .../defaulter-gen/generators/defaulter.go | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/examples/defaulter-gen/generators/defaulter.go b/examples/defaulter-gen/generators/defaulter.go index 48566aee..b7eeb7a9 100644 --- a/examples/defaulter-gen/generators/defaulter.go +++ b/examples/defaulter-gen/generators/defaulter.go @@ -501,21 +501,27 @@ func mustEnforceDefault(t *types.Type, depth int, omitEmpty bool) (interface{}, var refRE = regexp.MustCompile(`^ref\((?P[^"]+)\)$`) var refREIdentIndex = refRE.SubexpIndex("reference") -// parseAsRef looks for strings that match one of the following: +// ParseSymbolReference looks for strings that match one of the following: // - ref(Ident) // - ref(pkgpath.Ident) // If the input string matches either of these, it will return the (optional) // pkgpath, the Ident, and true. Otherwise it will return empty strings and // false. -func parseAsRef(s string) (string, bool) { +func ParseSymbolReference(s, sourcePackage string) (types.Name, bool) { matches := refRE.FindStringSubmatch(s) if len(matches) < refREIdentIndex || matches[refREIdentIndex] == "" { - return "", false + return types.Name{}, false } - return matches[refREIdentIndex], true + + contents := matches[refREIdentIndex] + name := types.ParseFullyQualifiedName(contents) + if len(name.Package) == 0 { + name.Package = sourcePackage + } + return name, true } -func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLines []string) *callNode { +func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLines []string, commentPackage string) *callNode { defaultMap := extractDefaultTag(commentLines) var defaultString string if len(defaultMap) == 1 { @@ -531,9 +537,9 @@ func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLin } else if len(defaultMap) == 0 { return node } - var symbolReference string + var symbolReference types.Name var defaultValue interface{} - if id, ok := parseAsRef(defaultString); ok { + if id, ok := ParseSymbolReference(defaultString, commentPackage); ok { symbolReference = id defaultString = "" } else if err := json.Unmarshal([]byte(defaultString), &defaultValue); err != nil { @@ -642,7 +648,7 @@ func (c *callTreeForType) build(t *types.Type, root bool) *callNode { child.elem = true } parent.children = append(parent.children, *child) - } else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil { + } else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines, t.Elem.Name.Package); member != nil { member.index = true parent.children = append(parent.children, *member) } @@ -650,7 +656,7 @@ func (c *callTreeForType) build(t *types.Type, root bool) *callNode { if child := c.build(t.Elem, false); child != nil { child.key = true parent.children = append(parent.children, *child) - } else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil { + } else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines, t.Elem.Name.Package); member != nil { member.key = true parent.children = append(parent.children, *member) } @@ -667,9 +673,9 @@ func (c *callTreeForType) build(t *types.Type, root bool) *callNode { } if child := c.build(field.Type, false); child != nil { child.field = name - populateDefaultValue(child, field.Type, field.Tags, field.CommentLines) + populateDefaultValue(child, field.Type, field.Tags, field.CommentLines, field.Type.Name.Package) parent.children = append(parent.children, *child) - } else if member := populateDefaultValue(nil, field.Type, field.Tags, field.CommentLines); member != nil { + } else if member := populateDefaultValue(nil, field.Type, field.Tags, field.CommentLines, t.Name.Package); member != nil { member.field = name parent.children = append(parent.children, *member) } @@ -794,22 +800,20 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr } i := 0 callTree.VisitInOrder(func(ancestors []*callNode, current *callNode) { - if len(current.defaultValue.SymbolReference) > 0 { - // If the defaultValue was a reference to a symbol instead of a constant, - // make sure to add it to imports and resolve the name of the symbol - // before generating the defaults. - parsedName := types.ParseFullyQualifiedName(current.defaultValue.SymbolReference) - g.imports.AddSymbol(parsedName) + if ref := ¤t.defaultValue.SymbolReference; len(ref.Name) > 0 { + // Ensure package for symbol is imported in output generation + g.imports.AddSymbol(*ref) // Rewrite the fully qualified name using the local package name // from the imports - localPackage := g.imports.LocalNameOf(parsedName.Package) - if len(localPackage) > 0 { - current.defaultValue.SymbolReference = localPackage + "." + parsedName.Name + // + // Can't just directly assign g.imports.LocalNameOf since the + // import tracker actually has no knowledge of the current package + if g.isOtherPackage(ref.Package) { + ref.Package = g.imports.LocalNameOf(ref.Package) } else { - current.defaultValue.SymbolReference = parsedName.Name + ref.Package = "" } - } if len(current.call) == 0 { @@ -915,7 +919,7 @@ type defaultValue struct { // The name of the symbol relative to the parsed package path // i.e. k8s.io/pkg.apis.v1.Foo if from another package or simply `Foo` // if within the same package. - SymbolReference string + SymbolReference types.Name } func (d defaultValue) IsEmpty() bool { @@ -927,7 +931,7 @@ func (d defaultValue) Resolved() string { if len(d.InlineConstant) > 0 { return d.InlineConstant } - return d.SymbolReference + return d.SymbolReference.String() } // CallNodeVisitorFunc is a function for visiting a call tree. ancestors is the list of all parents From 4f910be675ea02ec985926c495dfe962fc4c3863 Mon Sep 17 00:00:00 2001 From: Alex Zielenski Date: Fri, 18 Aug 2023 13:40:47 -0700 Subject: [PATCH 2/4] allow local package to be provded with import tracker --- examples/defaulter-gen/generators/defaulter.go | 18 +++--------------- generator/import_tracker.go | 7 +++++-- namer/namer.go | 1 + 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/examples/defaulter-gen/generators/defaulter.go b/examples/defaulter-gen/generators/defaulter.go index b7eeb7a9..5fd6c30d 100644 --- a/examples/defaulter-gen/generators/defaulter.go +++ b/examples/defaulter-gen/generators/defaulter.go @@ -697,11 +697,6 @@ const ( conversionPackagePath = "k8s.io/apimachinery/pkg/conversion" ) -type symbolTracker interface { - namer.ImportTracker - AddSymbol(types.Name) -} - // genDefaulter produces a file with a autogenerated conversions. type genDefaulter struct { generator.DefaultGen @@ -710,7 +705,7 @@ type genDefaulter struct { peerPackages []string newDefaulters defaulterFuncMap existingDefaulters defaulterFuncMap - imports symbolTracker + imports namer.ImportTracker typesForInit []*types.Type } @@ -724,7 +719,7 @@ func NewGenDefaulter(sanitizedName, typesPackage, outputPackage string, existing peerPackages: peerPkgs, newDefaulters: newDefaulters, existingDefaulters: existingDefaulters, - imports: generator.NewImportTracker(), + imports: generator.NewImportTrackerForPackage(outputPackage), typesForInit: make([]*types.Type, 0), } } @@ -806,14 +801,7 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr // Rewrite the fully qualified name using the local package name // from the imports - // - // Can't just directly assign g.imports.LocalNameOf since the - // import tracker actually has no knowledge of the current package - if g.isOtherPackage(ref.Package) { - ref.Package = g.imports.LocalNameOf(ref.Package) - } else { - ref.Package = "" - } + ref.Package = g.imports.LocalNameOf(ref.Package) } if len(current.call) == 0 { diff --git a/generator/import_tracker.go b/generator/import_tracker.go index f7c25a01..b33f1a5a 100644 --- a/generator/import_tracker.go +++ b/generator/import_tracker.go @@ -26,15 +26,18 @@ import ( "k8s.io/gengo/types" ) -func NewImportTracker(typesToAdd ...*types.Type) *namer.DefaultImportTracker { - tracker := namer.NewDefaultImportTracker(types.Name{}) +func NewImportTrackerForPackage(local string, typesToAdd ...*types.Type) *namer.DefaultImportTracker { + tracker := namer.NewDefaultImportTracker(types.Name{Package: local}) tracker.IsInvalidType = func(*types.Type) bool { return false } tracker.LocalName = func(name types.Name) string { return golangTrackerLocalName(&tracker, name) } tracker.PrintImport = func(path, name string) string { return name + " \"" + path + "\"" } tracker.AddTypes(typesToAdd...) return &tracker +} +func NewImportTracker(typesToAdd ...*types.Type) *namer.DefaultImportTracker { + return NewImportTrackerForPackage("", typesToAdd...) } func golangTrackerLocalName(tracker namer.ImportTracker, t types.Name) string { diff --git a/namer/namer.go b/namer/namer.go index 6feb2d0c..a0f1a24a 100644 --- a/namer/namer.go +++ b/namer/namer.go @@ -300,6 +300,7 @@ func (ns *NameStrategy) Name(t *types.Type) string { // import. You can implement yourself or use the one in the generation package. type ImportTracker interface { AddType(*types.Type) + AddSymbol(types.Name) LocalNameOf(packagePath string) string PathOf(localName string) (string, bool) ImportLines() []string From c5fa375b5643cc79ba2268140006828d7aaacea9 Mon Sep 17 00:00:00 2001 From: Alexander Zielenski <351783+alexzielenski@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:01:34 -0700 Subject: [PATCH 3/4] added doc --- generator/import_tracker.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/generator/import_tracker.go b/generator/import_tracker.go index b33f1a5a..04f1ecf2 100644 --- a/generator/import_tracker.go +++ b/generator/import_tracker.go @@ -26,6 +26,22 @@ import ( "k8s.io/gengo/types" ) +// NewImportTrackerForPackage creates a new import tracker which is aware +// of a generator's output package. The tracker will not add import lines +// when symbols or types are added from the same package, and LocalNameOf +// will return empty string for the output package. +// +// e.g.: +// +// tracker := NewImportTrackerForPackage("bar.com/pkg/foo") +// tracker.AddSymbol(types.Name{"bar.com/pkg/foo.MyType"}) +// tracker.AddSymbol(types.Name{"bar.com/pkg/baz.MyType"}) +// tracker.AddSymbol(types.Name{"bar.com/pkg/baz/baz.MyType"}) +// +// tracker.LocalNameOf("bar.com/pkg/foo") -> "" +// tracker.LocalNameOf("bar.com/pkg/baz") -> "baz" +// tracker.LocalNameOf("bar.com/pkg/baz/baz") -> "bazbaz" +// tracker.ImportLines() -> {"bar.com/pkg/baz", `bazbaz "bar.com/pkg/baz/baz"`} func NewImportTrackerForPackage(local string, typesToAdd ...*types.Type) *namer.DefaultImportTracker { tracker := namer.NewDefaultImportTracker(types.Name{Package: local}) tracker.IsInvalidType = func(*types.Type) bool { return false } From 91d5541831de4b3c9ab49373a71477ba0cb0ee39 Mon Sep 17 00:00:00 2001 From: Alexander Zielenski <351783+alexzielenski@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:46:37 -0700 Subject: [PATCH 4/4] add local package test --- generator/import_tracker.go | 2 +- generator/import_tracker_test.go | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/generator/import_tracker.go b/generator/import_tracker.go index 04f1ecf2..99525c40 100644 --- a/generator/import_tracker.go +++ b/generator/import_tracker.go @@ -41,7 +41,7 @@ import ( // tracker.LocalNameOf("bar.com/pkg/foo") -> "" // tracker.LocalNameOf("bar.com/pkg/baz") -> "baz" // tracker.LocalNameOf("bar.com/pkg/baz/baz") -> "bazbaz" -// tracker.ImportLines() -> {"bar.com/pkg/baz", `bazbaz "bar.com/pkg/baz/baz"`} +// tracker.ImportLines() -> {`baz "bar.com/pkg/baz"`, `bazbaz "bar.com/pkg/baz/baz"`} func NewImportTrackerForPackage(local string, typesToAdd ...*types.Type) *namer.DefaultImportTracker { tracker := namer.NewDefaultImportTracker(types.Name{Package: local}) tracker.IsInvalidType = func(*types.Type) bool { return false } diff --git a/generator/import_tracker_test.go b/generator/import_tracker_test.go index d9a88197..7a7bda8b 100644 --- a/generator/import_tracker_test.go +++ b/generator/import_tracker_test.go @@ -26,6 +26,7 @@ import ( func TestNewImportTracker(t *testing.T) { tests := []struct { name string + localPackage string inputTypes []*types.Type expectedImports []string }{ @@ -63,10 +64,26 @@ func TestNewImportTracker(t *testing.T) { `_struct "my/reserved/pkg/struct"`, }, }, + { + name: "local-symbol", + localPackage: "bar.com/my/pkg", + inputTypes: []*types.Type{ + {Name: types.Name{Package: "bar.com/my/pkg"}}, + {Name: types.Name{Package: "bar.com/external/pkg"}}, + }, + expectedImports: []string{ + `pkg "bar.com/external/pkg"`, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - actualImports := NewImportTracker(tt.inputTypes...).ImportLines() + var actualImports []string + if len(tt.localPackage) == 0 { + actualImports = NewImportTracker(tt.inputTypes...).ImportLines() + } else { + actualImports = NewImportTrackerForPackage(tt.localPackage, tt.inputTypes...).ImportLines() + } if !reflect.DeepEqual(actualImports, tt.expectedImports) { t.Errorf("ImportLines(%v) = %v, want %v", tt.inputTypes, actualImports, tt.expectedImports) }