diff --git a/cmd/swag/main.go b/cmd/swag/main.go index c14226f24..b2242423e 100644 --- a/cmd/swag/main.go +++ b/cmd/swag/main.go @@ -23,6 +23,7 @@ const ( parseInternalFlag = "parseInternal" generatedTimeFlag = "generatedTime" parseDepthFlag = "parseDepth" + instanceNameFlag = "instanceName" ) var initFlags = []cli.Flag{ @@ -87,6 +88,11 @@ var initFlags = []cli.Flag{ Value: 100, Usage: "Dependency parse depth", }, + &cli.StringFlag{ + Name: instanceNameFlag, + Value: "", + Usage: "This parameter can be used to name different swagger document instances. It is optional.", + }, } func initAction(c *cli.Context) error { @@ -111,6 +117,7 @@ func initAction(c *cli.Context) error { GeneratedTime: c.Bool(generatedTimeFlag), CodeExampleFilesDir: c.String(codeExampleFilesFlag), ParseDepth: c.Int(parseDepthFlag), + InstanceName: c.String(instanceNameFlag), }) } diff --git a/gen/gen.go b/gen/gen.go index 01f611ef1..f405b3006 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -74,10 +74,18 @@ type Config struct { // ParseDepth dependency parse depth ParseDepth int + + // InstanceName is used to get distinct names for different swagger documents in the + // same project. The default value is "swagger". + InstanceName string } // Build builds swagger json file for given searchDir and mainAPIFile. Returns json func (g *Gen) Build(config *Config) error { + if config.InstanceName == "" { + config.InstanceName = swag.Name + } + searchDirs := strings.Split(config.SearchDir, ",") for _, searchDir := range searchDirs { if _, err := os.Stat(searchDir); os.IsNotExist(err) { @@ -233,6 +241,7 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa Title string Description string Version string + InstanceName string }{ Timestamp: time.Now(), GeneratedTime: config.GeneratedTime, @@ -244,6 +253,7 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa Title: swagger.Info.Title, Description: swagger.Info.Description, Version: swagger.Info.Version, + InstanceName: config.InstanceName, }) if err != nil { return err @@ -323,6 +333,6 @@ func (s *s) ReadDoc() string { } func init() { - swag.Register(swag.Name, &s{}) + swag.Register({{ printf "%q" .InstanceName }}, &s{}) } ` diff --git a/gen/gen_test.go b/gen/gen_test.go index 555fd8dd3..6fb8944c7 100644 --- a/gen/gen_test.go +++ b/gen/gen_test.go @@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "plugin" + "strings" "testing" "github.com/go-openapi/spec" @@ -39,6 +40,53 @@ func TestGen_Build(t *testing.T) { } } +func TestGen_BuildInstanceName(t *testing.T) { + searchDir := "../testdata/simple" + + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + PropNamingStrategy: "", + } + assert.NoError(t, New().Build(config)) + + goSourceFile := filepath.Join(config.OutputDir, "docs.go") + + // Validate default registration name + expectedCode, err := ioutil.ReadFile(goSourceFile) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(expectedCode), "swag.Register(\"swagger\", &s{})") { + t.Fatal(errors.New("generated go code does not contain the correct default registration sequence")) + } + + // Custom name + config.InstanceName = "custom" + assert.NoError(t, New().Build(config)) + expectedCode, err = ioutil.ReadFile(goSourceFile) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(expectedCode), "swag.Register(\"custom\", &s{})") { + t.Fatal(errors.New("generated go code does not contain the correct registration sequence")) + } + + // cleanup + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Fatal(err) + } + _ = os.Remove(expectedFile) + } +} + func TestGen_BuildSnakecase(t *testing.T) { searchDir := "../testdata/simple2" config := &Config{ diff --git a/swagger.go b/swagger.go index 9adb1fcc1..a01d6dfdb 100644 --- a/swagger.go +++ b/swagger.go @@ -2,6 +2,7 @@ package swag import ( "errors" + "fmt" "sync" ) @@ -10,7 +11,7 @@ const Name = "swagger" var ( swaggerMu sync.RWMutex - swag Swagger + swags map[string]Swagger ) // Swagger is a interface to read swagger document. @@ -26,17 +27,35 @@ func Register(name string, swagger Swagger) { panic("swagger is nil") } - if swag != nil { + if swags == nil { + swags = make(map[string]Swagger) + } + + if _, ok := swags[name]; ok { panic("Register called twice for swag: " + name) } - swag = swagger + swags[name] = swagger } -// ReadDoc reads swagger document. -func ReadDoc() (string, error) { - if swag != nil { - return swag.ReadDoc(), nil +// ReadDoc reads swagger document. An optional name parameter can be passed to read a specific document. +// The default name is "swagger". +func ReadDoc(optionalName ...string) (string, error) { + swaggerMu.RLock() + defer swaggerMu.RUnlock() + + if swags == nil { + return "", errors.New("no swag has yet been registered") + } + + name := Name + if len(optionalName) != 0 && optionalName[0] != "" { + name = optionalName[0] + } + + swag, ok := swags[name] + if !ok { + return "", fmt.Errorf("no swag named \"%s\" was registered", name) } - return "", errors.New("not yet registered swag") + return swag.ReadDoc(), nil } diff --git a/swagger_test.go b/swagger_test.go index bc44488bf..6483ef8ad 100644 --- a/swagger_test.go +++ b/swagger_test.go @@ -162,12 +162,36 @@ func TestRegister(t *testing.T) { assert.Equal(t, doc, d) } +func TestRegisterByName(t *testing.T) { + setup() + Register("another_name", &s{}) + d, _ := ReadDoc("another_name") + assert.Equal(t, doc, d) +} + +func TestRegisterMultiple(t *testing.T) { + setup() + Register(Name, &s{}) + Register("another_name", &s{}) + d1, _ := ReadDoc(Name) + d2, _ := ReadDoc("another_name") + assert.Equal(t, doc, d1) + assert.Equal(t, doc, d2) +} + func TestReadDocBeforeRegistered(t *testing.T) { setup() _, err := ReadDoc() assert.Error(t, err) } +func TestReadDocWithInvalidName(t *testing.T) { + setup() + Register(Name, &s{}) + _, err := ReadDoc("invalid") + assert.Error(t, err) +} + func TestNilRegister(t *testing.T) { setup() var swagger Swagger @@ -185,5 +209,5 @@ func TestCalledTwicelRegister(t *testing.T) { } func setup() { - swag = nil + swags = nil }