From 932959088011c5b5020dbbf5ebe764f91378b43a Mon Sep 17 00:00:00 2001 From: James Hamlin Date: Sat, 15 Jun 2019 13:28:06 -0700 Subject: [PATCH] Support colon in final path segment, last match wins behavior (behind flags) Signed-off-by: James Hamlin --- .../descriptor/registry.go | 14 ++++ .../gengateway/template.go | 8 +- .../gengateway/template_test.go | 4 +- protoc-gen-grpc-gateway/main.go | 2 + runtime/mux.go | 13 +++- runtime/mux_test.go | 78 ++++++++++++++++--- runtime/pattern.go | 48 ++++++++++-- 7 files changed, 146 insertions(+), 21 deletions(-) diff --git a/protoc-gen-grpc-gateway/descriptor/registry.go b/protoc-gen-grpc-gateway/descriptor/registry.go index 1131ca453f9..2f056364b8f 100644 --- a/protoc-gen-grpc-gateway/descriptor/registry.go +++ b/protoc-gen-grpc-gateway/descriptor/registry.go @@ -67,6 +67,10 @@ type Registry struct { // If false, the default behavior is to concat the last 2 elements of the FQN if they are unique, otherwise concat // all the elements of the FQN without any separator useFQNForSwaggerName bool + + // allowColonFinalSegments determines whether colons are permitted + // in the final segment of a path. + allowColonFinalSegments bool } type repeatedFieldSeparator struct { @@ -422,6 +426,16 @@ func (r *Registry) SetUseFQNForSwaggerName(use bool) { r.useFQNForSwaggerName = use } +// GetAllowColonFinalSegments returns allowColonFinalSegments +func (r *Registry) GetAllowColonFinalSegments() bool { + return r.allowColonFinalSegments +} + +// SetAllowColonFinalSegments sets allowColonFinalSegments +func (r *Registry) SetAllowColonFinalSegments(use bool) { + r.allowColonFinalSegments = use +} + // GetUseFQNForSwaggerName returns useFQNForSwaggerName func (r *Registry) GetUseFQNForSwaggerName() bool { return r.useFQNForSwaggerName diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index a14de379642..d5a4980d65c 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -136,6 +136,7 @@ type trailerParams struct { Services []*descriptor.Service UseRequestContext bool RegisterFuncSuffix string + AssumeColonVerb bool } func applyTemplate(p param, reg *descriptor.Registry) (string, error) { @@ -176,10 +177,15 @@ func applyTemplate(p param, reg *descriptor.Registry) (string, error) { return "", errNoTargetService } + assumeColonVerb := true + if reg != nil { + assumeColonVerb = !reg.GetAllowColonFinalSegments() + } tp := trailerParams{ Services: targetServices, UseRequestContext: p.UseRequestContext, RegisterFuncSuffix: p.RegisterFuncSuffix, + AssumeColonVerb: assumeColonVerb, } if err := trailerTemplate.Execute(w, tp); err != nil { return "", err @@ -517,7 +523,7 @@ func (m response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}) XXX_ResponseBody( var ( {{range $m := $svc.Methods}} {{range $b := $m.Bindings}} - pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}})) + pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}}, runtime.AssumeColonVerbOpt({{$.AssumeColonVerb}}))) {{end}} {{end}} ) diff --git a/protoc-gen-grpc-gateway/gengateway/template_test.go b/protoc-gen-grpc-gateway/gengateway/template_test.go index 5e287a680d1..fbf54e6a69d 100644 --- a/protoc-gen-grpc-gateway/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/gengateway/template_test.go @@ -242,7 +242,7 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } - if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { + if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), "", runtime.AssumeColonVerbOpt(true)))`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } } @@ -394,7 +394,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } - if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { + if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), "", runtime.AssumeColonVerbOpt(true)))`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } } diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 60d9de92e6b..291ba7deb2f 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -33,6 +33,7 @@ var ( allowRepeatedFieldsInBody = flag.Bool("allow_repeated_fields_in_body", false, "allows to use repeated field in `body` and `response_body` field of `google.api.http` annotation option") repeatedPathParamSeparator = flag.String("repeated_path_param_separator", "csv", "configures how repeated fields should be split. Allowed values are `csv`, `pipes`, `ssv` and `tsv`.") allowPatchFeature = flag.Bool("allow_patch_feature", true, "determines whether to use PATCH feature involving update masks (using google.protobuf.FieldMask).") + allowColonFinalSegments = flag.Bool("allow_colon_final_segments", false, "determines whether colons are permitted in the final segment of a path") versionFlag = flag.Bool("version", false, "print the current verison") ) @@ -93,6 +94,7 @@ func main() { reg.SetImportPath(*importPath) reg.SetAllowDeleteBody(*allowDeleteBody) reg.SetAllowRepeatedFieldsInBody(*allowRepeatedFieldsInBody) + reg.SetAllowColonFinalSegments(*allowColonFinalSegments) if err := reg.SetRepeatedPathParamSeparator(*repeatedPathParamSeparator); err != nil { emitError(err) return diff --git a/runtime/mux.go b/runtime/mux.go index 093373a204a..f5843d1a497 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -37,6 +37,7 @@ type ServeMux struct { streamErrorHandler StreamErrorHandlerFunc protoErrorHandler ProtoErrorHandlerFunc disablePathLengthFallback bool + lastMatchWins bool } // ServeMuxOption is an option that can be given to a ServeMux on construction. @@ -133,6 +134,12 @@ func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption { } } +func WithLastMatchWins() ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.lastMatchWins = true + } +} + // NewServeMux returns a new ServeMux whose internal mapping is empty. func NewServeMux(opts ...ServeMuxOption) *ServeMux { serveMux := &ServeMux{ @@ -173,7 +180,11 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux { // Handle associates "h" to the pair of HTTP method and path pattern. func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) { - s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h}) + if s.lastMatchWins { + s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...) + } else { + s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h}) + } } // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path. diff --git a/runtime/mux_test.go b/runtime/mux_test.go index e033091a885..23243d8f3a4 100644 --- a/runtime/mux_test.go +++ b/runtime/mux_test.go @@ -22,7 +22,8 @@ func TestMuxServeHTTP(t *testing.T) { verb string } for _, spec := range []struct { - patterns []stubPattern + patterns []stubPattern + patternOpts []runtime.PatternOpt reqMethod string reqPath string @@ -33,6 +34,7 @@ func TestMuxServeHTTP(t *testing.T) { disablePathLengthFallback bool errHandler runtime.ProtoErrorHandlerFunc + muxOpts []runtime.ServeMuxOption }{ { patterns: nil, @@ -253,11 +255,11 @@ func TestMuxServeHTTP(t *testing.T) { pool: []string{"unimplemented"}, }, }, - reqMethod: "GET", - reqPath: "/foobar", + reqMethod: "GET", + reqPath: "/foobar", respStatus: http.StatusNotFound, respContent: "GET /foobar", - errHandler: unknownPathIs404, + errHandler: unknownPathIs404, }, { // server returning unimplemented results in 'Not Implemented' code @@ -269,14 +271,72 @@ func TestMuxServeHTTP(t *testing.T) { pool: []string{"unimplemented"}, }, }, - reqMethod: "GET", - reqPath: "/unimplemented", + reqMethod: "GET", + reqPath: "/unimplemented", respStatus: http.StatusNotImplemented, respContent: `GET /unimplemented`, - errHandler: unknownPathIs404, + errHandler: unknownPathIs404, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0, int(utilities.OpPush), 0, int(utilities.OpConcatN), 1, int(utilities.OpCapture), 1}, + pool: []string{"foo", "id"}, + }, + }, + patternOpts: []runtime.PatternOpt{runtime.AssumeColonVerbOpt(false)}, + reqMethod: "GET", + reqPath: "/foo/bar", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusOK, + respContent: "GET /foo/{id=*}", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(utilities.OpLitPush), 0, int(utilities.OpPush), 0, int(utilities.OpConcatN), 1, int(utilities.OpCapture), 1}, + pool: []string{"foo", "id"}, + }, + }, + patternOpts: []runtime.PatternOpt{runtime.AssumeColonVerbOpt(false)}, + reqMethod: "GET", + reqPath: "/foo/bar:123", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusOK, + respContent: "GET /foo/{id=*}", + }, + { + patterns: []stubPattern{ + { + method: "POST", + ops: []int{int(utilities.OpLitPush), 0, int(utilities.OpPush), 0, int(utilities.OpConcatN), 1, int(utilities.OpCapture), 1}, + pool: []string{"foo", "id"}, + }, + { + method: "POST", + ops: []int{int(utilities.OpLitPush), 0, int(utilities.OpPush), 0, int(utilities.OpConcatN), 1, int(utilities.OpCapture), 1}, + pool: []string{"foo", "id"}, + verb: "verb", + }, + }, + patternOpts: []runtime.PatternOpt{runtime.AssumeColonVerbOpt(false)}, + reqMethod: "POST", + reqPath: "/foo/bar:verb", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusOK, + respContent: "POST /foo/{id=*}:verb", + muxOpts: []runtime.ServeMuxOption{runtime.WithLastMatchWins()}, }, } { - var opts []runtime.ServeMuxOption + opts := spec.muxOpts if spec.disablePathLengthFallback { opts = append(opts, runtime.WithDisablePathLengthFallback()) } @@ -286,7 +346,7 @@ func TestMuxServeHTTP(t *testing.T) { mux := runtime.NewServeMux(opts...) for _, p := range spec.patterns { func(p stubPattern) { - pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb) + pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb, spec.patternOpts...) if err != nil { t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err) } diff --git a/runtime/pattern.go b/runtime/pattern.go index f16a84ad389..4e3569fc5ac 100644 --- a/runtime/pattern.go +++ b/runtime/pattern.go @@ -35,14 +35,30 @@ type Pattern struct { tailLen int // verb is the VERB part of the path pattern. It is empty if the pattern does not have VERB part. verb string + // assumeColonVerb indicates whether blah + assumeColonVerb bool } +type patternOptions struct { + assumeColonVerb bool +} + +// PatternOpt is an option for creating Patterns. +type PatternOpt func(*patternOptions) + // NewPattern returns a new Pattern from the given definition values. // "ops" is a sequence of op codes. "pool" is a constant pool. // "verb" is the verb part of the pattern. It is empty if the pattern does not have the part. // "version" must be 1 for now. // It returns an error if the given definition is invalid. -func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, error) { +func NewPattern(version int, ops []int, pool []string, verb string, opts ...PatternOpt) (Pattern, error) { + options := patternOptions{ + assumeColonVerb: true, + } + for _, o := range opts { + o(&options) + } + if version != 1 { grpclog.Infof("unsupported version: %d", version) return Pattern{}, ErrInvalidPattern @@ -122,12 +138,13 @@ func NewPattern(version int, ops []int, pool []string, verb string) (Pattern, er typedOps = append(typedOps, op) } return Pattern{ - ops: typedOps, - pool: pool, - vars: vars, - stacksize: maxstack, - tailLen: tailLen, - verb: verb, + ops: typedOps, + pool: pool, + vars: vars, + stacksize: maxstack, + tailLen: tailLen, + verb: verb, + assumeColonVerb: options.assumeColonVerb, }, nil } @@ -144,7 +161,16 @@ func MustPattern(p Pattern, err error) Pattern { // If otherwise, the function returns an error. func (p Pattern) Match(components []string, verb string) (map[string]string, error) { if p.verb != verb { - return nil, ErrNotMatch + if p.assumeColonVerb || p.verb != "" { + return nil, ErrNotMatch + } + if len(components) == 0 { + components = []string{":" + verb} + } else { + components = append([]string{}, components...) + components[len(components)-1] += ":" + verb + } + verb = "" } var pos int @@ -225,3 +251,9 @@ func (p Pattern) String() string { } return "/" + segs } + +func AssumeColonVerbOpt(val bool) PatternOpt { + return PatternOpt(func(o *patternOptions) { + o.assumeColonVerb = val + }) +}