Skip to content

Commit

Permalink
Improvements around aliased types (#3940)
Browse files Browse the repository at this point in the history
* Improvements around aliased types

Issue a diagnostic error if an exported, aliased type contains nested
types that aren't exported from the same package.
Include methods for aliased types in the source package.
Removed references to ioutil as it's been deprecated.

* skip unexported fields

* handle arrays

* fix happy line

* little more refinement

* refine per feedback

* array with pointer-to-type

* add tests
  • Loading branch information
jhendrixMSFT authored Aug 16, 2022
1 parent 7daac93 commit 4a83eff
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/go/cmd/api_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package cmd

import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strings"
Expand All @@ -19,7 +19,7 @@ func CreateAPIView(pkgDir, outputDir string) error {
}
filename := filepath.Join(outputDir, review.Name+".json")
file, _ := json.MarshalIndent(review, "", " ")
err = ioutil.WriteFile(filename, file, 0644)
err = os.WriteFile(filename, file, 0644)
if err != nil {
return err
}
Expand Down
40 changes: 35 additions & 5 deletions src/go/cmd/api_view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package cmd

import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"testing"
Expand All @@ -23,7 +22,7 @@ func TestFuncDecl(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/testfuncdecl.json")
file, err := os.ReadFile("./output/testfuncdecl.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -51,7 +50,7 @@ func TestInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/testinterface.json")
file, err := os.ReadFile("./output/testinterface.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -79,7 +78,7 @@ func TestStruct(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/teststruct.json")
file, err := os.ReadFile("./output/teststruct.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -107,7 +106,7 @@ func TestConst(t *testing.T) {
if err != nil {
t.Fatal(err)
}
file, err := ioutil.ReadFile("./output/testconst.json")
file, err := os.ReadFile("./output/testconst.json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -265,3 +264,34 @@ func TestRecursiveAliasDefinitions(t *testing.T) {
})
}
}

func TestAliasDiagnostics(t *testing.T) {
review, err := createReview(filepath.Clean("testdata/test_alias_diagnostics"))
require.NoError(t, err)
require.Equal(t, "Go", review.Language)
require.Equal(t, "test_alias_diagnostics", review.Name)
require.Equal(t, 6, len(review.Diagnostics))
for _, diagnostic := range review.Diagnostics {
if diagnostic.TargetID == "test_alias_diagnostics.WidgetValue" {
require.Equal(t, DiagnosticLevelInfo, diagnostic.Level)
require.Equal(t, aliasFor+"internal.WidgetValue", diagnostic.Text)
} else {
require.Equal(t, "test_alias_diagnostics.Widget", diagnostic.TargetID)
switch diagnostic.Level {
case DiagnosticLevelInfo:
require.Equal(t, aliasFor+"internal.Widget", diagnostic.Text)
case DiagnosticLevelError:
switch txt := diagnostic.Text; txt {
case missingAliasFor + "WidgetProperties":
case missingAliasFor + "WidgetPropertiesP":
case missingAliasFor + "WidgetThings":
case missingAliasFor + "WidgetThingsP":
default:
t.Fatalf("unexpected diagnostic text %s", txt)
}
default:
t.Fatalf("unexpected diagnostic level %d", diagnostic.Level)
}
}
}
}
18 changes: 13 additions & 5 deletions src/go/cmd/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,9 @@ func (c *content) filterDeclarations(typ string, decls map[string]Declaration, t
return results
}

// searchForMethods takes the name of the receiver and looks for Funcs that are methods on that receiver.
func (c *content) searchForMethods(s string, tokenList *[]Token) {
// findMethods takes the name of the receiver and looks for Funcs that are methods on that receiver.
func (c *content) findMethods(s string) map[string]Func {
methods := map[string]Func{}
methodNames := []string{}
for key, fn := range c.Funcs {
name := fn.Name()
if unicode.IsLower(rune(name[0])) {
Expand All @@ -339,14 +338,23 @@ func (c *content) searchForMethods(s string, tokenList *[]Token) {
}
if s == n || "*"+s == n {
methods[key] = fn
methodNames = append(methodNames, key)
delete(c.Funcs, key)
}
}
return methods
}

// searchForMethods takes the name of the receiver and looks for Funcs that are methods on that receiver.
func (c *content) searchForMethods(s string, tokenList *[]Token) {
methods := c.findMethods(s)
methodNames := []string{}
for key := range methods {
methodNames = append(methodNames, key)
}
sort.Strings(methodNames)
for _, name := range methodNames {
fn := methods[name]
*tokenList = append(*tokenList, fn.MakeTokens()...)
delete(c.Funcs, name)
}
}

Expand Down
75 changes: 75 additions & 0 deletions src/go/cmd/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ func NewModule(dir string) (*Module, error) {
m := Module{Name: filepath.Base(dir), packages: map[string]*Pkg{}}

baseImportPath := path.Dir(mf.Module.Mod.Path) + "/"
if baseImportPath == "./" {
// this is a relative path in the tests, so remove this prefix.
// if not, then the package name added below won't match the imported packages.
baseImportPath = ""
}
err = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if d.IsDir() {
if !indexTestdata && strings.Contains(path, "testdata") {
Expand Down Expand Up @@ -120,15 +125,39 @@ func NewModule(dir string) (*Module, error) {
t = p.c.addInterface(*def.p, alias, p.Name(), n)
case *ast.StructType:
t = p.c.addStruct(*def.p, alias, p.Name(), def.n)
hoistMethodsForType(source, alias, p)
// ensure that all struct field types that are structs are also aliased from this package
for _, field := range n.Fields.List {
fieldTypeName := unwrapStructFieldTypeName(field)
if fieldTypeName == "" {
// we can ignore this field
continue
}

// ensure that our package exports this type
if _, ok := p.typeAliases[fieldTypeName]; ok {
// found an alias
continue
}

// no alias, add a diagnostic
p.diagnostics = append(p.diagnostics, Diagnostic{
Level: DiagnosticLevelError,
TargetID: t.ID(),
Text: missingAliasFor + fieldTypeName,
})
}
case *ast.Ident:
t = p.c.addSimpleType(*p, alias, p.Name(), def.n.Type.(*ast.Ident).Name)
hoistMethodsForType(source, alias, p)
default:
fmt.Printf("unexpected node type %T\n", def.n.Type)
t = p.c.addSimpleType(*p, alias, p.Name(), originalName)
}
} else {
fmt.Println("found no definition for " + qn)
}

if t != nil {
p.diagnostics = append(p.diagnostics, Diagnostic{
Level: level,
Expand All @@ -141,6 +170,52 @@ func NewModule(dir string) (*Module, error) {
return &m, nil
}

// returns the type name for the specified struct field.
// if the field can be ignored, an empty string is returned.
func unwrapStructFieldTypeName(field *ast.Field) string {
if field.Names != nil && !field.Names[0].IsExported() {
// field isn't exported so skip it
return ""
}

// start with the field expression
exp := field.Type

// if it's an array, get the element expression.
// current codegen doesn't support *[]Type so no need to handle it.
if at, ok := exp.(*ast.ArrayType); ok {
// FieldName []FieldType
// FieldName []*FieldType
exp = at.Elt
}

// from here we either have a pointer-to-type or type
var ident *ast.Ident
if se, ok := exp.(*ast.StarExpr); ok {
// FieldName *FieldType
ident, _ = se.X.(*ast.Ident)
} else {
// FieldName FieldType
ident, _ = exp.(*ast.Ident)
}

// !IsExported() is a hacky way to ignore primitive types
// FieldName bool
if ident == nil || !ident.IsExported() {
return ""
}

// returns FieldType
return ident.Name
}

func hoistMethodsForType(pkg *Pkg, typeName string, target *Pkg) {
methods := pkg.c.findMethods(typeName)
for sig, fn := range methods {
target.c.Funcs[sig] = fn.ForAlias(target.Name())
}
}

func parseModFile(dir string) (*modfile.File, error) {
path := filepath.Join(dir, "go.mod")
content, err := os.ReadFile(path)
Expand Down
4 changes: 2 additions & 2 deletions src/go/cmd/pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand All @@ -19,6 +18,7 @@ import (
// diagnostic messages
const (
aliasFor = "Alias for "
missingAliasFor = "missing alias for nested type "
embedsUnexportedStruct = "Anonymously embeds unexported struct "
sealedInterface = "Applications can't implement this interface"
)
Expand Down Expand Up @@ -200,7 +200,7 @@ func (pkg Pkg) getText(start token.Pos, end token.Pos) string {
p := pkg.fs.Position(start)
// check if the file has been loaded, if not then load it
if _, ok := pkg.files[p.Filename]; !ok {
b, err := ioutil.ReadFile(p.Filename)
b, err := os.ReadFile(p.Filename)
if err != nil {
panic(err)
}
Expand Down
3 changes: 3 additions & 0 deletions src/go/cmd/testdata/test_alias_diagnostics/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module test_alias_diagnostics

go 1.18
20 changes: 20 additions & 0 deletions src/go/cmd/testdata/test_alias_diagnostics/internal/internal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package internal

type Widget struct {
OK bool
Value WidgetValue
MissingScalar WidgetProperties
MissingScalarP *WidgetPropertiesP
MissingSlice []WidgetThings
MissingSliceP []*WidgetThingsP
}

type WidgetProperties struct{}

type WidgetPropertiesP struct{}

type WidgetThings struct{}

type WidgetThingsP struct{}

type WidgetValue struct{}
9 changes: 9 additions & 0 deletions src/go/cmd/testdata/test_alias_diagnostics/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package test_alias_diagnostics

import (
"test_alias_diagnostics/internal"
)

type Widget = internal.Widget

type WidgetValue = internal.WidgetValue
11 changes: 11 additions & 0 deletions src/go/cmd/token_makers.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ func (f Func) ID() string {
return f.id
}

func (f Func) ForAlias(pkg string) Func {
clone := f
// replace everything to the left of - with the new package name
i := strings.Index(clone.id, "-")
if i < 0 {
panic("missing sig separator in id")
}
clone.id = pkg + clone.id[i:]
return clone
}

func (f Func) MakeTokens() []Token {
list := &[]Token{}
if f.embedded {
Expand Down

0 comments on commit 4a83eff

Please sign in to comment.