Skip to content

Commit

Permalink
Only write Content-Length if the runtime.WithWriteContentLength() opt…
Browse files Browse the repository at this point in the history
…ion is specified (#5151)
  • Loading branch information
joshgarnett authored Jan 21, 2025
1 parent e1364b5 commit 5dfd063
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 3 deletions.
2 changes: 1 addition & 1 deletion runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
return
}

if !doForwardTrailers {
if !doForwardTrailers && mux.writeContentLength {
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
}

Expand Down
83 changes: 81 additions & 2 deletions runtime/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,77 @@ func TestForwardResponseMessage(t *testing.T) {
}

func TestOutgoingHeaderMatcher(t *testing.T) {
t.Parallel()
msg := &pb.SimpleMessage{Id: "foo"}
for _, tc := range []struct {
name string
md runtime.ServerMetadata
headers http.Header
matcher runtime.HeaderMatcherFunc
}{
{
name: "default matcher",
md: runtime.ServerMetadata{
HeaderMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Grpc-Metadata-Foo": []string{"bar"},
"Grpc-Metadata-Baz": []string{"qux"},
},
},
{
name: "custom matcher",
md: runtime.ServerMetadata{
HeaderMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Custom-Foo": []string{"bar"},
},
matcher: func(key string) (string, bool) {
switch key {
case "foo":
return "custom-foo", true
default:
return "", false
}
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

mux := runtime.NewServeMux(
runtime.WithOutgoingHeaderMatcher(tc.matcher),
)
runtime.ForwardResponseMessage(ctx, mux, &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
defer w.Body.Close()
if w.StatusCode != http.StatusOK {
t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
}

if !reflect.DeepEqual(w.Header, tc.headers) {
t.Fatalf("Header %v want %v", w.Header, tc.headers)
}
})
}
}

func TestOutgoingHeaderMatcherWithContentLength(t *testing.T) {
t.Parallel()
msg := &pb.SimpleMessage{Id: "foo"}
for _, tc := range []struct {
Expand Down Expand Up @@ -431,7 +502,11 @@ func TestOutgoingHeaderMatcher(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingHeaderMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)
mux := runtime.NewServeMux(
runtime.WithOutgoingHeaderMatcher(tc.matcher),
runtime.WithWriteContentLength(),
)
runtime.ForwardResponseMessage(ctx, mux, &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
defer w.Body.Close()
Expand Down Expand Up @@ -529,7 +604,11 @@ func TestOutgoingTrailerMatcher(t *testing.T) {
req.Header = tc.caller
resp := httptest.NewRecorder()

runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingTrailerMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)
mux := runtime.NewServeMux(
runtime.WithOutgoingTrailerMatcher(tc.matcher),
runtime.WithWriteContentLength(),
)
runtime.ForwardResponseMessage(ctx, mux, &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
_, _ = io.Copy(io.Discard, w.Body)
Expand Down
8 changes: 8 additions & 0 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type ServeMux struct {
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
writeContentLength bool
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand Down Expand Up @@ -258,6 +259,13 @@ func WithDisablePathLengthFallback() ServeMuxOption {
}
}

// WithWriteContentLength returns a ServeMuxOption to enable writing content length on non-streaming responses
func WithWriteContentLength() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.writeContentLength = true
}
}

// WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
// When called the handler will forward the request to the upstream grpc service health check (defined in the
// gRPC Health Checking Protocol).
Expand Down

0 comments on commit 5dfd063

Please sign in to comment.