Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements around aliased types #3940

Merged
merged 8 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/go/cmd/api_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package cmd

import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strings"
Expand All @@ -19,7 +19,7 @@ func CreateAPIView(pkgDir, outputDir string) error {
}
filename := filepath.Join(outputDir, review.Name+".json")
file, _ := json.MarshalIndent(review, "", " ")
err = ioutil.WriteFile(filename, file, 0644)
err = os.WriteFile(filename, file, 0644)
if err != nil {
return err
}
Expand Down
40 changes: 35 additions & 5 deletions src/go/cmd/api_view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package cmd

import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"testing"
Expand All @@ -23,7 +22,7 @@ func TestFuncDecl(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/testfuncdecl.json")
file, err := os.ReadFile("./output/testfuncdecl.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -51,7 +50,7 @@ func TestInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/testinterface.json")
file, err := os.ReadFile("./output/testinterface.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -79,7 +78,7 @@ func TestStruct(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/teststruct.json")
file, err := os.ReadFile("./output/teststruct.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -107,7 +106,7 @@ func TestConst(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/testconst.json")
file, err := os.ReadFile("./output/testconst.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -265,3 +264,34 @@ func TestRecursiveAliasDefinitions(t *testing.T) {
})
}
}

func TestAliasDiagnostics(t *testing.T) {
review, err := createReview(filepath.Clean("testdata/test_alias_diagnostics"))
require.NoError(t, err)
require.Equal(t, "Go", review.Language)
require.Equal(t, "test_alias_diagnostics", review.Name)
require.Equal(t, 6, len(review.Diagnostics))
for _, diagnostic := range review.Diagnostics {
if diagnostic.TargetID == "test_alias_diagnostics.WidgetValue" {
require.Equal(t, DiagnosticLevelInfo, diagnostic.Level)
require.Equal(t, aliasFor+"internal.WidgetValue", diagnostic.Text)
} else {
require.Equal(t, "test_alias_diagnostics.Widget", diagnostic.TargetID)
switch diagnostic.Level {
case DiagnosticLevelInfo:
require.Equal(t, aliasFor+"internal.Widget", diagnostic.Text)
case DiagnosticLevelError:
switch txt := diagnostic.Text; txt {
case missingAliasFor + "WidgetProperties":
case missingAliasFor + "WidgetPropertiesP":
case missingAliasFor + "WidgetThings":
case missingAliasFor + "WidgetThingsP":
default:
t.Fatalf("unexpected diagnostic text %s", txt)
}
default:
t.Fatalf("unexpected diagnostic level %d", diagnostic.Level)
}
}
}
}
18 changes: 13 additions & 5 deletions src/go/cmd/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,9 @@ func (c *content) filterDeclarations(typ string, decls map[string]Declaration, t
return results
}

// searchForMethods takes the name of the receiver and looks for Funcs that are methods on that receiver.
func (c *content) searchForMethods(s string, tokenList *[]Token) {
// findMethods takes the name of the receiver and looks for Funcs that are methods on that receiver.
func (c *content) findMethods(s string) map[string]Func {
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
methods := map[string]Func{}
methodNames := []string{}
for key, fn := range c.Funcs {
name := fn.Name()
if unicode.IsLower(rune(name[0])) {
Expand All @@ -339,14 +338,23 @@ func (c *content) searchForMethods(s string, tokenList *[]Token) {
}
if s == n || "*"+s == n {
methods[key] = fn
methodNames = append(methodNames, key)
delete(c.Funcs, key)
}
}
return methods
}

// searchForMethods takes the name of the receiver and looks for Funcs that are methods on that receiver.
func (c *content) searchForMethods(s string, tokenList *[]Token) {
methods := c.findMethods(s)
methodNames := []string{}
for key := range methods {
methodNames = append(methodNames, key)
}
sort.Strings(methodNames)
for _, name := range methodNames {
fn := methods[name]
*tokenList = append(*tokenList, fn.MakeTokens()...)
delete(c.Funcs, name)
}
}

Expand Down
75 changes: 75 additions & 0 deletions src/go/cmd/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ func NewModule(dir string) (*Module, error) {
m := Module{Name: filepath.Base(dir), packages: map[string]*Pkg{}}

baseImportPath := path.Dir(mf.Module.Mod.Path) + "/"
if baseImportPath == "./" {
// this is a relative path in the tests, so remove this prefix.
// if not, then the package name added below won't match the imported packages.
baseImportPath = ""
}
err = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if d.IsDir() {
if !indexTestdata && strings.Contains(path, "testdata") {
Expand Down Expand Up @@ -120,15 +125,39 @@ func NewModule(dir string) (*Module, error) {
t = p.c.addInterface(*def.p, alias, p.Name(), n)
case *ast.StructType:
t = p.c.addStruct(*def.p, alias, p.Name(), def.n)
hoistMethodsForType(source, alias, p)
// ensure that all struct field types that are structs are also aliased from this package
for _, field := range n.Fields.List {
fieldTypeName := unwrapStructFieldTypeName(field)
if fieldTypeName == "" {
// we can ignore this field
continue
}

// ensure that our package exports this type
if _, ok := p.typeAliases[fieldTypeName]; ok {
// found an alias
continue
}

// no alias, add a diagnostic
p.diagnostics = append(p.diagnostics, Diagnostic{
Level: DiagnosticLevelError,
TargetID: t.ID(),
Text: missingAliasFor + fieldTypeName,
})
}
case *ast.Ident:
t = p.c.addSimpleType(*p, alias, p.Name(), def.n.Type.(*ast.Ident).Name)
hoistMethodsForType(source, alias, p)
default:
fmt.Printf("unexpected node type %T\n", def.n.Type)
t = p.c.addSimpleType(*p, alias, p.Name(), originalName)
}
} else {
fmt.Println("found no definition for " + qn)
}

if t != nil {
p.diagnostics = append(p.diagnostics, Diagnostic{
Level: level,
Expand All @@ -141,6 +170,52 @@ func NewModule(dir string) (*Module, error) {
return &m, nil
}

// returns the type name for the specified struct field.
// if the field can be ignored, an empty string is returned.
func unwrapStructFieldTypeName(field *ast.Field) string {
if field.Names != nil && !field.Names[0].IsExported() {
// field isn't exported so skip it
return ""
}

// start with the field expression
exp := field.Type

// if it's an array, get the element expression.
// current codegen doesn't support *[]Type so no need to handle it.
if at, ok := exp.(*ast.ArrayType); ok {
// FieldName []FieldType
// FieldName []*FieldType
exp = at.Elt
}

// from here we either have a pointer-to-type or type
var ident *ast.Ident
if se, ok := exp.(*ast.StarExpr); ok {
// FieldName *FieldType
ident, _ = se.X.(*ast.Ident)
} else {
// FieldName FieldType
ident, _ = exp.(*ast.Ident)
}

// !IsExported() is a hacky way to ignore primitive types
// FieldName bool
if ident == nil || !ident.IsExported() {
return ""
}

// returns FieldType
return ident.Name
}

func hoistMethodsForType(pkg *Pkg, typeName string, target *Pkg) {
methods := pkg.c.findMethods(typeName)
for sig, fn := range methods {
target.c.Funcs[sig] = fn.ForAlias(target.Name())
}
}

func parseModFile(dir string) (*modfile.File, error) {
path := filepath.Join(dir, "go.mod")
content, err := os.ReadFile(path)
Expand Down
4 changes: 2 additions & 2 deletions src/go/cmd/pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand All @@ -19,6 +18,7 @@ import (
// diagnostic messages
const (
aliasFor = "Alias for "
missingAliasFor = "missing alias for nested type "
embedsUnexportedStruct = "Anonymously embeds unexported struct "
sealedInterface = "Applications can't implement this interface"
)
Expand Down Expand Up @@ -200,7 +200,7 @@ func (pkg Pkg) getText(start token.Pos, end token.Pos) string {
p := pkg.fs.Position(start)
// check if the file has been loaded, if not then load it
if _, ok := pkg.files[p.Filename]; !ok {
b, err := ioutil.ReadFile(p.Filename)
b, err := os.ReadFile(p.Filename)
if err != nil {
panic(err)
}
Expand Down
3 changes: 3 additions & 0 deletions src/go/cmd/testdata/test_alias_diagnostics/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module test_alias_diagnostics

go 1.18
20 changes: 20 additions & 0 deletions src/go/cmd/testdata/test_alias_diagnostics/internal/internal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package internal

type Widget struct {
OK bool
Value WidgetValue
MissingScalar WidgetProperties
MissingScalarP *WidgetPropertiesP
MissingSlice []WidgetThings
MissingSliceP []*WidgetThingsP
}

type WidgetProperties struct{}

type WidgetPropertiesP struct{}

type WidgetThings struct{}

type WidgetThingsP struct{}

type WidgetValue struct{}
9 changes: 9 additions & 0 deletions src/go/cmd/testdata/test_alias_diagnostics/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package test_alias_diagnostics

import (
"test_alias_diagnostics/internal"
)

type Widget = internal.Widget

type WidgetValue = internal.WidgetValue
11 changes: 11 additions & 0 deletions src/go/cmd/token_makers.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ func (f Func) ID() string {
return f.id
}

func (f Func) ForAlias(pkg string) Func {
clone := f
// replace everything to the left of - with the new package name
i := strings.Index(clone.id, "-")
if i < 0 {
panic("missing sig separator in id")
}
clone.id = pkg + clone.id[i:]
return clone
}

func (f Func) MakeTokens() []Token {
list := &[]Token{}
if f.embedded {
Expand Down