Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hyphenated file + include-as support #423

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions compile/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package compile

import (
"path/filepath"
"strings"

"go.uber.org/thriftrw/ast"
"go.uber.org/thriftrw/idl"
Expand Down Expand Up @@ -240,14 +241,18 @@ func (c compiler) gather(m *Module, prog *ast.Program) error {
}

// include loads the file specified by the given include in the given Module.
//
// The path to the file is relative to the ThriftPath of the given module.
func (c compiler) include(m *Module, include *ast.Include) (*IncludedModule, error) {
if len(include.Name) > 0 {
// TODO(abg): Add support for include-as flag somewhere.
return nil, includeError{
Include: include,
Reason: includeAsDisabledError{},
includeName := include.Name
// include.Name has include-as name
if len(include.Name) == 0 {
includeName = fileBaseName(include.Path)
// if hyphenated file, include-as is necessary
if strings.Contains(includeName, "-") {
return nil, includeError{
Include: include,
Reason: includeAsNeededInHyphenatedFile{},
}
}
}

Expand All @@ -257,5 +262,5 @@ func (c compiler) include(m *Module, include *ast.Include) (*IncludedModule, err
return nil, includeError{Include: include, Reason: err}
}

return &IncludedModule{Name: fileBaseName(include.Path), Module: incM}, nil
return &IncludedModule{Name: includeName, Module: incM}, nil
}
10 changes: 5 additions & 5 deletions compile/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ func (e fileCompileError) Error() string {
return fmt.Sprintf("could not compile file %q: %v", e.Path, e.Reason)
}

// includeAsDisabledError is raised when the user attempts to use the include-as
// syntax without explicitly enabling it.
type includeAsDisabledError struct{}
// includeAsNeededInHyphenatedFile is raised when the user attempts to use the include-as
// syntax when file is hyphenated.
type includeAsNeededInHyphenatedFile struct{}

func (e includeAsDisabledError) Error() string {
return "include-as syntax is currently disabled"
func (e includeAsNeededInHyphenatedFile) Error() string {
return "include-as is needed when file is hyphenated"
}

// includeError is raised when there is an error including another Thrift
Expand Down
2 changes: 1 addition & 1 deletion compile/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func capitalize(s string) string {
return string(unicode.ToUpper(x)) + string(s[i:])
}

// fileBaseName returns the base name of the given file without the extension.
// fileBaseName returns the normalized base name of the given file without the extension.
func fileBaseName(p string) string {
return strings.TrimSuffix(filepath.Base(p), filepath.Ext(p))
}
Expand Down
50 changes: 46 additions & 4 deletions gen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ func Generate(m *compile.Module, o *Options) error {
genBuilder := newGenerateServiceBuilder(importer)

generate := func(m *compile.Module) error {
if err := isDuplicateFileAfterNormalization(m.ThriftPath); err != nil {
return err
}

path, contents, err := generateModule(m, importer, genBuilder, o)
if err != nil {
return generateError{Name: m.ThriftPath, Reason: err}
Expand All @@ -113,7 +117,6 @@ func Generate(m *compile.Module, o *Options) error {
if err := addFile(files, path, contents); err != nil {
return generateError{Name: m.ThriftPath, Reason: err}
}

return nil
}

Expand Down Expand Up @@ -165,6 +168,35 @@ func Generate(m *compile.Module, o *Options) error {
return nil
}

func isDuplicateFileAfterNormalization(f string) error {
absFileName, err := filepath.Abs(f)
if err != nil {
return generateError{Name: f, Reason: fmt.Errorf("File %q does not exist: %v", f,
err)}
}
normalizedFileName := filePathWithUnderscore(absFileName)
if fileNormalized(absFileName, normalizedFileName) && fileExists(normalizedFileName) {
return generateError{Name: f,
Reason: fmt.Errorf("File after normalization %q is colliding with existing file %q in the path, with error: %v",
normalizedFileName, absFileName, err)}
}
return nil
}

func fileNormalized(f1, f2 string) bool {
return !(f1 == f2)
}

// fileExists checks if a file exists and is not a directory before we
// try using it to prevent further errors.
func fileExists(filename string) bool {
info, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return !info.IsDir()
}

// ThriftPackageImporter determines import paths from a Thrift root.
type ThriftPackageImporter interface {
// RelativePackage returns the import path for the top-level package of the
Expand All @@ -186,7 +218,7 @@ type thriftPackageImporter struct {
}

func (i thriftPackageImporter) RelativePackage(file string) (string, error) {
return filepath.Rel(i.ThriftRoot, strings.TrimSuffix(file, ".thrift"))
return filepath.Rel(i.ThriftRoot, strings.TrimSuffix(filePathWithUnderscore(file), ".thrift"))
}

func (i thriftPackageImporter) RelativeThriftFilePath(file string) (string, error) {
Expand Down Expand Up @@ -226,12 +258,14 @@ func generateModule(
builder *generateServiceBuilder,
o *Options,
) (outputFilepath string, contents []byte, err error) {
// converts file from /home/abc/ab-def.thrift to /home/abc/ab_def.thrift for golang code generation
normalizedThriftPath := filePathWithUnderscore(m.ThriftPath)
// packageRelPath is the path relative to outputDir into which we'll be
// writing the package for this Thrift file. For $thriftRoot/foo/bar.thrift,
// packageRelPath is foo/bar, and packageDir is $outputDir/foo/bar. All
// files for bar.thrift will be written to the $outputDir/foo/bar/ tree. The
// package will be importable via $importPrefix/foo/bar.
packageRelPath, err := i.RelativePackage(m.ThriftPath)
packageRelPath, err := i.RelativePackage(normalizedThriftPath)
if err != nil {
return "", nil, err
}
Expand All @@ -248,7 +282,7 @@ func generateModule(

// importPath is the full import path for the top-level package generated
// for this Thrift file.
importPath, err := i.Package(m.ThriftPath)
importPath, err := i.Package(normalizedThriftPath)
if err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -321,3 +355,11 @@ func generateModule(

return outputFilepath, buff.Bytes(), nil
}

func replaceHyphenWithUnderscore(str string) string {
return strings.Replace(str, "-", "_", -1)
}

func filePathWithUnderscore(p string) string {
return filepath.Join(filepath.Dir(p), replaceHyphenWithUnderscore(filepath.Base(p)))
}
45 changes: 45 additions & 0 deletions gen/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,51 @@ func TestGenerateWithRelativePaths(t *testing.T) {
}
}

func TestGenerateWithHyphenPaths(t *testing.T) {
outputDir, err := ioutil.TempDir("", "thriftrw-generate-test")
require.NoError(t, err)
defer os.RemoveAll(outputDir)

thriftRoot, err := os.Getwd()
require.NoError(t, err)

tests := []struct {
filepath string
expectedError bool
}{
{
filepath: "internal/tests/thrift/include_hyphen_files.thrift",
expectedError: false,
},
{
filepath: "internal/tests/thrift/abc-defs.thrift",
expectedError: false,
},
{
filepath: "internal/tests/thrift/nestedfiles_conflict/include_hyphen_files_nest.thrift",
expectedError: true,
},
}

for _, test := range tests {
module, err := compile.Compile(test.filepath)
require.NoError(t, err)

opt := &Options{
OutputDir: outputDir,
PackagePrefix: "go.uber.org/thriftrw/gen",
ThriftRoot: thriftRoot,
}
err = Generate(module, opt)
if test.expectedError {
assert.Error(t, err, "expected code generation with filepath %v to fail", test.filepath)
assert.Contains(t, err.Error(), "is colliding with existing file")
} else {
assert.NoError(t, err)
}
}
}

func TestGenerate(t *testing.T) {
var (
ts compile.TypeSpec = &compile.TypedefSpec{
Expand Down
4 changes: 2 additions & 2 deletions gen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (g *generator) LookupTypeName(t compile.TypeSpec) (string, error) {
"LookupTypeName called with native type (%T) %v", t, t)
}

importPath, err := g.thriftImporter.Package(t.ThriftFile())
importPath, err := g.thriftImporter.Package(filePathWithUnderscore(t.ThriftFile()))
if err != nil {
return "", err
}
Expand All @@ -200,7 +200,7 @@ func (g *generator) LookupTypeName(t compile.TypeSpec) (string, error) {
}

func (g *generator) LookupConstantName(c *compile.Constant) (string, error) {
importPath, err := g.thriftImporter.Package(c.File)
importPath, err := g.thriftImporter.Package(filePathWithUnderscore(c.File))
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion gen/golden_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestCodeIsUpToDate(t *testing.T) {
defer os.RemoveAll(outputDir)

for _, thriftFile := range thriftFiles {
pkgRelPath := strings.TrimSuffix(filepath.Base(thriftFile), ".thrift")
pkgRelPath := strings.TrimSuffix(filePathWithUnderscore(filepath.Base(thriftFile)), ".thrift")
currentPackageDir := filepath.Join("internal/tests", pkgRelPath)
newPackageDir := filepath.Join(outputDir, pkgRelPath)

Expand Down
3 changes: 2 additions & 1 deletion gen/internal/tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ nozap: thrift/nozap.thrift $(THRIFTRW)
$(THRIFTRW) --no-recurse --no-zap $<

%: thrift/%.thrift $(THRIFTRW)
$(THRIFTRW) --no-recurse $<
$(THRIFTRW) $<

Loading