Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix when embedded is interface with type params #99

Merged
merged 9 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec {
return result
}

func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (param genericParam, methods methodsList, err error) {
func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput, checkInterface bool) (param genericParam, methods methodsList, err error) {
param.Name, err = pr.PrintType(t)
if err != nil {
return
Expand All @@ -471,13 +471,13 @@ func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (p
return

case *ast.Ident:
methods, err = processIdent(v, input)
methods, err = processIdent(v, input, checkInterface)
return
}
return
}

func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (genericParam genericParam, embeddedMethods methodsList, err error) {
func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput, checkInterface bool) (genericParam genericParam, embeddedMethods methodsList, err error) {
var x ast.Expr
var hasGenericsParams bool
var genericParams genericParams
Expand All @@ -486,8 +486,12 @@ func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (gene
case *ast.IndexExpr:
x = v.X
hasGenericsParams = true

genericParam, _, err = processEmbedded(v.Index, pr, input)
// Don't check if embedded interface's generic params are also interfaces, e.g. given the interface:
// type SomeInterface {
// EmbeddedGenericInterface[Bar]
// }
// we won't be checking if Bar is also an interface
genericParam, _, err = processEmbedded(v.Index, pr, input, false)
if err != nil {
return
}
Expand All @@ -501,7 +505,12 @@ func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (gene

if v.Indices != nil {
for _, index := range v.Indices {
genericParam, _, err = processEmbedded(index, pr, input)
// Don't check if embedded interface's generic params are also interfaces, e.g. given the interface:
// type SomeInterface {
// EmbeddedGenericInterface[Bar]
// }
// we won't be checking if Bar is also an interface
genericParam, _, err = processEmbedded(index, pr, input, false)
if err != nil {
return
}
Expand All @@ -515,7 +524,7 @@ func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (gene
}

input.genericParams = genericParams
genericParam, embeddedMethods, err = getEmbeddedMethods(x, pr, input)
genericParam, embeddedMethods, err = getEmbeddedMethods(x, pr, input, checkInterface)
if err != nil {
return
}
Expand Down Expand Up @@ -551,7 +560,7 @@ func processInterface(it *ast.InterfaceType, targetInput targetProcessInput) (me
}

default:
_, embeddedMethods, err = processEmbedded(v, pr, targetInput)
_, embeddedMethods, err = processEmbedded(v, pr, targetInput, true)
}

if err != nil {
Expand Down Expand Up @@ -618,19 +627,23 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) {

var errNotAnInterface = errors.New("embedded type is not an interface")

func processIdent(i *ast.Ident, input targetProcessInput) (methodsList, error) {
func processIdent(i *ast.Ident, input targetProcessInput, checkInterface bool) (methodsList, error) {
var embeddedInterface *ast.InterfaceType
var genericsTypes genericTypes
for _, t := range input.types {
if t.Name.Name == i.Name {
var ok bool
embeddedInterface, ok = t.Type.(*ast.InterfaceType)
if !ok {
return nil, errors.Wrap(errNotAnInterface, t.Name.Name)
if ok {
genericsTypes = buildGenericTypesFromSpec(t, input.types, input.typesPrefix)
break
}

if !checkInterface {
break
}

genericsTypes = buildGenericTypesFromSpec(t, input.types, input.typesPrefix)
break
return nil, errors.Wrap(errNotAnInterface, t.Name.Name)
}
}

Expand Down
20 changes: 17 additions & 3 deletions generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ func Test_findImportPathForName(t *testing.T) {

func Test_processIdent(t *testing.T) {
type args struct {
i *ast.Ident
input targetProcessInput
i *ast.Ident
input targetProcessInput
toCheckForInterface bool
}
tests := []struct {
name string
Expand All @@ -129,19 +130,32 @@ func Test_processIdent(t *testing.T) {
input: targetProcessInput{
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.StructType{}}},
},
toCheckForInterface: true,
},
wantErr: true,
inspectErr: func(err error, t *testing.T) {
assert.Equal(t, errNotAnInterface, errors.Cause(err))
},
},
{
name: "not an interface but no need to check",
args: args{
i: &ast.Ident{Name: "name"},
input: targetProcessInput{
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.StructType{}}},
},
toCheckForInterface: false,
},
wantErr: false,
},
{
name: "embedded interface found",
args: args{
i: &ast.Ident{Name: "name"},
input: targetProcessInput{
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.InterfaceType{}}},
},
toCheckForInterface: true,
},
wantErr: false,
},
Expand All @@ -152,7 +166,7 @@ func Test_processIdent(t *testing.T) {
mc := minimock.NewController(t)
defer mc.Wait(time.Second)

got1, err := processIdent(tt.args.i, tt.args.input)
got1, err := processIdent(tt.args.i, tt.args.input, tt.args.toCheckForInterface)

assert.Equal(t, tt.want1, got1, "processIdent returned unexpected result")

Expand Down
41 changes: 41 additions & 0 deletions printer/printer.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func (p *Printer) PrintType(node ast.Node) (string, error) {
return p.printStruct(t)
case *ast.Ident:
return p.printIdent(t)
case *ast.IndexExpr:
return p.printGeneric(t)
case *ast.IndexListExpr:
return p.printGenericList(t)
}

err := printer.Fprint(p.buf, p.fs, node)
Expand Down Expand Up @@ -151,6 +155,43 @@ func (p *Printer) printIdent(i *ast.Ident) (string, error) {
return p.buf.String(), err
}

func (p *Printer) printGeneric(pt *ast.IndexExpr) (string, error) {
t, err := p.PrintType(pt.X)
if err != nil {
return "", err
}

generic, err := p.PrintType(pt.Index)
if err != nil {
return "", err
}

return t + "[" + generic + "]", nil
}

func (p *Printer) printGenericList(pt *ast.IndexListExpr) (string, error) {
t, err := p.PrintType(pt.X)
if err != nil {
return "", err
}

baseStr := t + "["
for i, expr := range pt.Indices {
generic, err := p.PrintType(expr)
if err != nil {
return "", err
}

if i == len(pt.Indices)-1 {
baseStr = baseStr + generic + "]"
} else {
baseStr = baseStr + generic + ", "
}
}

return baseStr, nil
}

func (p *Printer) printPointer(pt *ast.StarExpr) (string, error) {
pointerTo, err := p.PrintType(pt.X)
if err != nil {
Expand Down
Loading
Loading