Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openapi3: allow variables in schemes in gorillamux router + better server variables validation #337

Merged
merged 2 commits into from
Apr 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions openapi3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ func (server *Server) Validate(c context.Context) (err error) {
// ServerVariable is specified by OpenAPI/Swagger standard version 3.0.
type ServerVariable struct {
ExtensionProps
Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"`
Default interface{} `json:"default,omitempty" yaml:"default,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Enum []string `json:"enum,omitempty" yaml:"enum,omitempty"`
Default string `json:"default,omitempty" yaml:"default,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
}

func (serverVariable *ServerVariable) MarshalJSON() ([]byte, error) {
Expand All @@ -164,17 +164,12 @@ func (serverVariable *ServerVariable) UnmarshalJSON(data []byte) error {
}

func (serverVariable *ServerVariable) Validate(c context.Context) error {
switch serverVariable.Default.(type) {
case float64, string, nil:
default:
return errors.New("value of default must be either a number or a string")
}
for _, item := range serverVariable.Enum {
switch item.(type) {
case float64, string:
default:
return errors.New("all 'enum' items must be either a number or a string")
if serverVariable.Default == "" {
data, err := serverVariable.MarshalJSON()
if err != nil {
return err
}
return fmt.Errorf("field default is required in %s", data)
}
return nil
}
9 changes: 5 additions & 4 deletions openapi3/swagger_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,11 @@ servers:
`{url: "http://{x}.{y}.example.com"}`: errors.New("invalid servers: server has undeclared variables"),
`{url: "http://{x}.y}.example.com"}`: errors.New("invalid servers: server URL has mismatched { and }"),
`{url: "http://{x.example.com"}`: errors.New("invalid servers: server URL has mismatched { and }"),
`{url: "http://{x}.example.com", variables: {x: {default: "www"}}}`: nil,
`{url: "http://{x}.example.com", variables: {x: {enum: ["www"]}}}`: nil,
`{url: "http://www.example.com", variables: {x: {enum: ["www"]}}}`: errors.New("invalid servers: server has undeclared variables"),
`{url: "http://{y}.example.com", variables: {x: {enum: ["www"]}}}`: errors.New("invalid servers: server has undeclared variables"),
`{url: "http://{x}.example.com", variables: {x: {default: "www"}}}`: nil,
`{url: "http://{x}.example.com", variables: {x: {default: "www", enum: ["www"]}}}`: nil,
`{url: "http://{x}.example.com", variables: {x: {enum: ["www"]}}}`: errors.New(`invalid servers: field default is required in {"enum":["www"]}`),
`{url: "http://www.example.com", variables: {x: {enum: ["www"]}}}`: errors.New("invalid servers: server has undeclared variables"),
`{url: "http://{y}.example.com", variables: {x: {enum: ["www"]}}}`: errors.New("invalid servers: server has undeclared variables"),
} {
t.Run(value, func(t *testing.T) {
loader := NewSwaggerLoader()
Expand Down
68 changes: 59 additions & 9 deletions routers/gorillamux/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@ type Router struct {
// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request))
func NewRouter(doc *openapi3.Swagger) (routers.Router, error) {
type srv struct {
scheme, host, base string
server *openapi3.Server
schemes []string
host, base string
server *openapi3.Server
}
servers := make([]srv, 0, len(doc.Servers))
for _, server := range doc.Servers {
u, err := url.Parse(bEncode(server.URL))
serverURL := server.URL
scheme0 := strings.Split(serverURL, "://")[0]
schemes := permutePart(scheme0, server)

u, err := url.Parse(bEncode(strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)))
if err != nil {
return nil, err
}
Expand All @@ -42,10 +47,10 @@ func NewRouter(doc *openapi3.Swagger) (routers.Router, error) {
path = path[:len(path)-1]
}
servers = append(servers, srv{
host: bDecode(u.Host), //u.Hostname()?
base: path,
scheme: bDecode(u.Scheme),
server: server,
host: bDecode(u.Host), //u.Hostname()?
base: path,
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
server: server,
})
}
if len(servers) == 0 {
Expand All @@ -65,8 +70,8 @@ func NewRouter(doc *openapi3.Swagger) (routers.Router, error) {

for _, s := range servers {
muxRoute := muxRouter.Path(s.base + path).Methods(methods...)
if scheme := s.scheme; scheme != "" {
muxRoute.Schemes(scheme)
if schemes := s.schemes; len(schemes) != 0 {
muxRoute.Schemes(schemes...)
}
if host := s.host; host != "" {
muxRoute.Host(host)
Expand Down Expand Up @@ -151,3 +156,48 @@ func bDecode(s string) string {
s = strings.Replace(s, brURL, "}", -1)
return s
}

func permutePart(part0 string, srv *openapi3.Server) []string {
type mapAndSlice struct {
m map[string]struct{}
s []string
}
var2val := make(map[string]mapAndSlice)
max := 0
for name0, v := range srv.Variables {
name := "{" + name0 + "}"
if !strings.Contains(part0, name) {
continue
}
m := map[string]struct{}{v.Default: {}}
for _, value := range v.Enum {
m[value] = struct{}{}
}
if l := len(m); l > max {
max = l
}
s := make([]string, 0, len(m))
for value := range m {
s = append(s, value)
}
var2val[name] = mapAndSlice{m: m, s: s}
}
if len(var2val) == 0 {
return []string{part0}
}

partsMap := make(map[string]struct{}, max*len(var2val))
for i := 0; i < max; i++ {
part := part0
for name, mas := range var2val {
part = strings.Replace(part, name, mas.s[i%len(mas.s)], -1)
}
partsMap[part] = struct{}{}
}
parts := make([]string, 0, len(partsMap))
for part := range partsMap {
parts = append(parts, part)
}
sort.Strings(parts)
return parts
}
22 changes: 19 additions & 3 deletions routers/gorillamux/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ func TestRouter(t *testing.T) {

doc.Servers = []*openapi3.Server{
{URL: "https://www.example.com/api/v1"},
{URL: "https://{d0}.{d1}.com/api/v1/", Variables: map[string]*openapi3.ServerVariable{
"d0": {Default: "www"},
"d1": {Enum: []interface{}{"example"}},
{URL: "{scheme}://{d0}.{d1}.com/api/v1/", Variables: map[string]*openapi3.ServerVariable{
"d0": {Default: "www"},
"d1": {Default: "example", Enum: []string{"example"}},
"scheme": {Default: "https", Enum: []string{"https", "http"}},
}},
}
err = doc.Validate(context.Background())
Expand All @@ -176,6 +177,7 @@ func TestRouter(t *testing.T) {
expect(r, http.MethodGet, "https://domain0.domain1.com/api/v1/hello", helloGET, map[string]string{
"d0": "domain0",
"d1": "domain1",
// "scheme": "https", TODO: https://github.com/gorilla/mux/issues/624
})

{
Expand All @@ -190,3 +192,17 @@ func TestRouter(t *testing.T) {
require.Nil(t, pathParams)
}
}

func TestPermuteScheme(t *testing.T) {
scheme0 := "{sche}{me}"
server := &openapi3.Server{URL: scheme0 + "://{d0}.{d1}.com/api/v1/", Variables: map[string]*openapi3.ServerVariable{
"d0": {Default: "www"},
"d1": {Default: "example", Enum: []string{"example"}},
"sche": {Default: "http"},
"me": {Default: "s", Enum: []string{"", "s"}},
}}
err := server.Validate(context.Background())
require.NoError(t, err)
perms := permutePart(scheme0, server)
require.Equal(t, []string{"http", "https"}, perms)
}
2 changes: 1 addition & 1 deletion routers/legacy/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func TestRouter(t *testing.T) {
{URL: "https://www.example.com/api/v1"},
{URL: "https://{d0}.{d1}.com/api/v1/", Variables: map[string]*openapi3.ServerVariable{
"d0": {Default: "www"},
"d1": {Enum: []interface{}{"example"}},
"d1": {Default: "example", Enum: []string{"example"}},
}},
}
err = doc.Validate(context.Background())
Expand Down