Skip to content

Commit

Permalink
Merge pull request #38 from mkraft/package-name
Browse files Browse the repository at this point in the history
Removes import paths that match target package name.
  • Loading branch information
vburenin authored May 31, 2019
2 parents c5f707b + 7533e84 commit b7104cb
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 13 deletions.
68 changes: 64 additions & 4 deletions ifacemaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bytes"
"fmt"
"io"
"os"
"testing"
Expand Down Expand Up @@ -68,18 +69,56 @@ func SomeFunction() string {
return "Something"
}`

var src2 = `package maker
import (
"github.com/vburenin/ifacemaker/maker/footest"
)
type TestImpl struct{}
func (s *TestImpl) GetUser(userID string) *footest.User {
return &footest.User{}
}
func (s *TestImpl) CreateUser(user *footest.User) (*footest.User, error) {
return &footest.User{}, nil
}
func (s *TestImpl) fooHelper() string {
return ""
}`

var src3 = `package footest
type User struct {
ID string
Name string
}`

var srcFile = os.TempDir() + "/ifacemaker_src.go"
var srcFile2 = os.TempDir() + "/test_impl.go"
var srcFile3 = os.TempDir() + "/footest/footest.go"

func TestMain(m *testing.M) {
writeTestSourceFile()
dirPath := os.TempDir() + "/footest"
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err := os.Mkdir(dirPath, os.ModePerm)
if err != nil {
panic(fmt.Sprintf("Failed to create directory: %s", err))
}
}
writeTestSourceFile(src, srcFile)
writeTestSourceFile(src2, srcFile2)
writeTestSourceFile(src3, srcFile3)

os.Exit(m.Run())
}

func writeTestSourceFile() {
f, err := os.OpenFile(srcFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
func writeTestSourceFile(src, path string) {
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil {
panic("Failed to open test source file.")
panic(fmt.Sprintf("Failed to open test source file: %s", err))
}
defer f.Close()
_, err = f.WriteString(src)
Expand Down Expand Up @@ -216,6 +255,27 @@ type PersonIface interface {
assert.Equal(t, expected, out)
}

func TestMainDoNotImportPackageName(t *testing.T) {
os.Args = []string{"cmd", "-f", srcFile2, "-s", "TestImpl", "-p", "footest", "-c", "DO NOT EDIT: Auto generated", "-i", "TestInterface", "-d=false"}
out := captureStdout(func() {
main()
})

expected := `// DO NOT EDIT: Auto generated
package footest
// TestInterface ...
type TestInterface interface {
GetUser(userID string) *User
CreateUser(user *User) (*User, error)
}
`

assert.Equal(t, expected, out)
}

// not thread safe
func captureStdout(f func()) string {
old := os.Stdout
Expand Down
11 changes: 6 additions & 5 deletions maker/maker.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func GetReceiverType(fd *ast.FuncDecl) (ast.Expr, error) {
// param or return value as a single string.
// If the FieldList input is nil, it returns
// nil
func FormatFieldList(src []byte, fl *ast.FieldList) []string {
func FormatFieldList(src []byte, fl *ast.FieldList, pkgName string) []string {
if fl == nil {
return nil
}
Expand All @@ -97,6 +97,7 @@ func FormatFieldList(src []byte, fl *ast.FieldList) []string {
names[i] = n.Name
}
t := string(src[l.Type.Pos()-1 : l.Type.End()-1])
t = strings.Replace(t, pkgName+".", "", -1)
if len(names) > 0 {
typeSharingArgs := strings.Join(names, ", ")
parts = append(parts, fmt.Sprintf("%s %s", typeSharingArgs, t))
Expand Down Expand Up @@ -164,7 +165,7 @@ func MakeInterface(comment, pkgName, ifaceName, ifaceComment string, methods []s
// not, the imports not used will be removed later using the
// 'imports' pkg If anything goes wrong, this method will
// fatally stop the execution
func ParseStruct(src []byte, structName string, copyDocs bool, copyTypeDocs bool) (methods []Method, imports []string, typeDoc string) {
func ParseStruct(src []byte, structName string, copyDocs bool, copyTypeDocs bool, pkgName string) (methods []Method, imports []string, typeDoc string) {
fset := token.NewFileSet()
a, err := parser.ParseFile(fset, "", src, parser.ParseComments)
if err != nil {
Expand All @@ -184,8 +185,8 @@ func ParseStruct(src []byte, structName string, copyDocs bool, copyTypeDocs bool
if !fd.Name.IsExported() {
continue
}
params := FormatFieldList(src, fd.Type.Params)
ret := FormatFieldList(src, fd.Type.Results)
params := FormatFieldList(src, fd.Type.Params, pkgName)
ret := FormatFieldList(src, fd.Type.Results, pkgName)
method := fmt.Sprintf("%s(%s) (%s)", fd.Name.String(), strings.Join(params, ", "), strings.Join(ret, ", "))
var docs []string
if fd.Doc != nil && copyDocs {
Expand Down Expand Up @@ -224,7 +225,7 @@ func Make(files []string, structType, comment, pkgName, ifaceName, ifaceComment
if err != nil {
return nil, err
}
methods, imports, parsedTypeDoc := ParseStruct(src, structType, copyDocs, copyTypeDoc)
methods, imports, parsedTypeDoc := ParseStruct(src, structType, copyDocs, copyTypeDoc, pkgName)
for _, m := range methods {
if _, ok := mset[m.Code]; !ok {
allMethods = append(allMethods, m.Lines()...)
Expand Down
8 changes: 4 additions & 4 deletions maker/maker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestLines(t *testing.T) {
}

func TestParseStruct(t *testing.T) {
methods, imports, typeDoc := ParseStruct(src, "Person", true, true)
methods, imports, typeDoc := ParseStruct(src, "Person", true, true, "")

assert.Equal(t, "Name() (string)", methods[0].Code)

Expand Down Expand Up @@ -124,8 +124,8 @@ func TestFormatFieldList(t *testing.T) {
for _, d := range a.Decls {
if a, fd := GetReceiverTypeName(src, d); a == "Person" {
methodName := fd.Name.String()
params := FormatFieldList(src, fd.Type.Params)
results := FormatFieldList(src, fd.Type.Results)
params := FormatFieldList(src, fd.Type.Params, "")
results := FormatFieldList(src, fd.Type.Results, "")

var expectedParams []string
var expectedResults []string
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestFormatFieldList(t *testing.T) {
}

func TestNoCopyTypeDocs(t *testing.T) {
_, _, typeDoc := ParseStruct(src, "Person", true, false)
_, _, typeDoc := ParseStruct(src, "Person", true, false, "")
assert.Equal(t, "", typeDoc)
}

Expand Down

0 comments on commit b7104cb

Please sign in to comment.