Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime: Add outgoing trailer matching #3725

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marsh
doForwardTrailers := requestAcceptsTrailers(r)

if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}

Expand All @@ -152,7 +152,7 @@ func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marsh
}

if doForwardTrailers {
handleForwardResponseTrailer(w, md)
handleForwardResponseTrailer(w, mux, md)
}
}

Expand Down
23 changes: 11 additions & 12 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package runtime
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
Expand Down Expand Up @@ -109,18 +108,20 @@ func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, m
}
}

func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
func handleForwardResponseTrailerHeader(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k := range md.TrailerMD {
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
w.Header().Add("Trailer", tKey)
if h, ok := mux.outgoingTrailerMatcher(k); ok {
w.Header().Add("Trailer", textproto.CanonicalMIMEHeaderKey(h))
}
}
}

func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
func handleForwardResponseTrailer(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.TrailerMD {
tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
for _, v := range vs {
w.Header().Add(tKey, v)
if h, ok := mux.outgoingTrailerMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
Expand Down Expand Up @@ -148,12 +149,10 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
doForwardTrailers := requestAcceptsTrailers(req)

if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}

handleForwardResponseTrailerHeader(w, md)

contentType := marshaler.ContentType(resp)
w.Header().Set("Content-Type", contentType)

Expand All @@ -179,7 +178,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
}

if doForwardTrailers {
handleForwardResponseTrailer(w, md)
handleForwardResponseTrailer(w, mux, md)
}
}

Expand Down
166 changes: 166 additions & 0 deletions runtime/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
pb "github.com/grpc-ecosystem/grpc-gateway/v2/runtime/internal/examplepb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
Expand Down Expand Up @@ -318,3 +320,167 @@ func TestForwardResponseMessage(t *testing.T) {
})
}
}

func TestOutgoingHeaderMatcher(t *testing.T) {
t.Parallel()
msg := &pb.SimpleMessage{Id: "foo"}
for _, tc := range []struct {
name string
md runtime.ServerMetadata
headers http.Header
matcher runtime.HeaderMatcherFunc
}{
{
name: "default matcher",
md: runtime.ServerMetadata{
HeaderMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Grpc-Metadata-Foo": []string{"bar"},
"Grpc-Metadata-Baz": []string{"qux"},
},
},
{
name: "custom matcher",
md: runtime.ServerMetadata{
HeaderMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Custom-Foo": []string{"bar"},
},
matcher: func(key string) (string, bool) {
switch key {
case "foo":
return "custom-foo", true
default:
return "", false
}
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingHeaderMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
defer w.Body.Close()
if w.StatusCode != http.StatusOK {
t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
}

if !reflect.DeepEqual(w.Header, tc.headers) {
t.Fatalf("Header %v want %v", w.Header, tc.headers)
}
})
}
}

func TestOutgoingTrailerMatcher(t *testing.T) {
t.Parallel()
msg := &pb.SimpleMessage{Id: "foo"}
for _, tc := range []struct {
name string
md runtime.ServerMetadata
caller http.Header
headers http.Header
trailer http.Header
matcher runtime.HeaderMatcherFunc
}{
{
name: "default matcher, caller accepts",
md: runtime.ServerMetadata{
TrailerMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
caller: http.Header{
"Te": []string{"trailers"},
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Trailer": []string{"Grpc-Trailer-Foo,Grpc-Trailer-Baz"},
},
trailer: http.Header{
"Grpc-Trailer-Foo": []string{"bar"},
"Grpc-Trailer-Baz": []string{"qux"},
},
},
{
name: "default matcher, caller rejects",
md: runtime.ServerMetadata{
TrailerMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
},
},
{
name: "custom matcher",
md: runtime.ServerMetadata{
TrailerMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
caller: http.Header{
"Te": []string{"trailers"},
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Trailer": []string{"Custom-Trailer-Foo"},
},
trailer: http.Header{
"Custom-Trailer-Foo": []string{"bar"},
},
matcher: func(key string) (string, bool) {
switch key {
case "foo":
return "custom-trailer-foo", true
default:
return "", false
}
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.Header = tc.caller
resp := httptest.NewRecorder()

runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingTrailerMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
_, _ = io.Copy(io.Discard, w.Body)
defer w.Body.Close()
if w.StatusCode != http.StatusOK {
t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
}

if !reflect.DeepEqual(w.Trailer, tc.trailer) {
t.Fatalf("Trailer %v want %v", w.Trailer, tc.trailer)
}
})
}
}
32 changes: 26 additions & 6 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type ServeMux struct {
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
outgoingTrailerMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
errorHandler ErrorHandlerFunc
streamErrorHandler StreamErrorHandlerFunc
Expand Down Expand Up @@ -114,10 +115,18 @@ func DefaultHeaderMatcher(key string) (string, bool) {
return "", false
}

func defaultOutgoingHeaderMatcher(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}

func defaultOutgoingTrailerMatcher(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true
}

// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header.
func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
for _, header := range fn.matchedMalformedHeaders() {
grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
Expand Down Expand Up @@ -147,13 +156,24 @@ func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
//
// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return modified header.
// matcher should return the modified header.
func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingHeaderMatcher = fn
}
}

// WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return the modified header.
func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingTrailerMatcher = fn
}
}

// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
Expand Down Expand Up @@ -273,11 +293,11 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {
if serveMux.incomingHeaderMatcher == nil {
serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
}

if serveMux.outgoingHeaderMatcher == nil {
serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}
serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
}
if serveMux.outgoingTrailerMatcher == nil {
serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
}

return serveMux
Expand Down
Loading