Skip to content

Commit

Permalink
Added WithDisablePathLengthFallback option (to fix issue grpc-ecosyst…
Browse files Browse the repository at this point in the history
  • Loading branch information
Uladzimir Trehubenka authored and johanbrandhorst committed Jan 22, 2019
1 parent 5a9b22a commit 00289e6
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 12 deletions.
30 changes: 19 additions & 11 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[str
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
protoErrorHandler ProtoErrorHandlerFunc
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
protoErrorHandler ProtoErrorHandlerFunc
disablePathLengthFallback bool
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand Down Expand Up @@ -102,6 +103,13 @@ func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
}
}

// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
func WithDisablePathLengthFallback() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.disablePathLengthFallback = true
}
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
Expand Down Expand Up @@ -177,7 +185,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
components[l-1], verb = c[:idx], c[idx+1:]
}

if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) {
if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
r.Method = strings.ToUpper(override)
if err := r.ParseForm(); err != nil {
if s.protoErrorHandler != nil {
Expand Down Expand Up @@ -211,7 +219,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
if isPathLengthFallback(r) {
if s.isPathLengthFallback(r) {
if err := r.ParseForm(); err != nil {
if s.protoErrorHandler != nil {
_, outboundMarshaler := MarshalerForRequest(s, r)
Expand Down Expand Up @@ -250,8 +258,8 @@ func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.Resp
return s.forwardResponseOptions
}

func isPathLengthFallback(r *http.Request) bool {
return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
}

type handler struct {
Expand Down
47 changes: 46 additions & 1 deletion runtime/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func TestMuxServeHTTP(t *testing.T) {

respStatus int
respContent string

disablePathLengthFallback bool
}{
{
patterns: nil,
Expand Down Expand Up @@ -122,6 +124,45 @@ func TestMuxServeHTTP(t *testing.T) {
respStatus: http.StatusOK,
respContent: "GET /foo",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
respStatus: http.StatusMethodNotAllowed,
respContent: "Method Not Allowed\n",
disablePathLengthFallback: true,
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
{
method: "POST",
ops: []int{int(utilities.OpLitPush), 0},
pool: []string{"foo"},
},
},
reqMethod: "POST",
reqPath: "/foo",
headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
respStatus: http.StatusOK,
respContent: "POST /foo",
disablePathLengthFallback: true,
},
{
patterns: []stubPattern{
{
Expand Down Expand Up @@ -199,7 +240,11 @@ func TestMuxServeHTTP(t *testing.T) {
respContent: "GET /foo/{id=*}:verb",
},
} {
mux := runtime.NewServeMux()
var opts []runtime.ServeMuxOption
if spec.disablePathLengthFallback {
opts = append(opts, runtime.WithDisablePathLengthFallback())
}
mux := runtime.NewServeMux(opts...)
for _, p := range spec.patterns {
func(p stubPattern) {
pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb)
Expand Down

0 comments on commit 00289e6

Please sign in to comment.