Skip to content

Commit

Permalink
runtime: Add outgoing trailer matching (#3725)
Browse files Browse the repository at this point in the history
* runtime: Add outgoing trailer matching

* runtime: Fix spelling

* runtime: Use simple defaults syntax
  • Loading branch information
adriansmares authored Nov 10, 2023
1 parent 1c1e884 commit 132c8be
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 20 deletions.
4 changes: 2 additions & 2 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marsh
doForwardTrailers := requestAcceptsTrailers(r)

if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}

Expand All @@ -152,7 +152,7 @@ func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marsh
}

if doForwardTrailers {
handleForwardResponseTrailer(w, md)
handleForwardResponseTrailer(w, mux, md)
}
}

Expand Down
23 changes: 11 additions & 12 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package runtime
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
Expand Down Expand Up @@ -109,18 +108,20 @@ func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, m
}
}

func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
func handleForwardResponseTrailerHeader(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k := range md.TrailerMD {
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
w.Header().Add("Trailer", tKey)
if h, ok := mux.outgoingTrailerMatcher(k); ok {
w.Header().Add("Trailer", textproto.CanonicalMIMEHeaderKey(h))
}
}
}

func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
func handleForwardResponseTrailer(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.TrailerMD {
tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
for _, v := range vs {
w.Header().Add(tKey, v)
if h, ok := mux.outgoingTrailerMatcher(k); ok {
for _, v := range vs {
w.Header().Add(h, v)
}
}
}
}
Expand Down Expand Up @@ -148,12 +149,10 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
doForwardTrailers := requestAcceptsTrailers(req)

if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
handleForwardResponseTrailerHeader(w, mux, md)
w.Header().Set("Transfer-Encoding", "chunked")
}

handleForwardResponseTrailerHeader(w, md)

contentType := marshaler.ContentType(resp)
w.Header().Set("Content-Type", contentType)

Expand All @@ -179,7 +178,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
}

if doForwardTrailers {
handleForwardResponseTrailer(w, md)
handleForwardResponseTrailer(w, mux, md)
}
}

Expand Down
166 changes: 166 additions & 0 deletions runtime/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
pb "github.com/grpc-ecosystem/grpc-gateway/v2/runtime/internal/examplepb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
Expand Down Expand Up @@ -318,3 +320,167 @@ func TestForwardResponseMessage(t *testing.T) {
})
}
}

func TestOutgoingHeaderMatcher(t *testing.T) {
t.Parallel()
msg := &pb.SimpleMessage{Id: "foo"}
for _, tc := range []struct {
name string
md runtime.ServerMetadata
headers http.Header
matcher runtime.HeaderMatcherFunc
}{
{
name: "default matcher",
md: runtime.ServerMetadata{
HeaderMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Grpc-Metadata-Foo": []string{"bar"},
"Grpc-Metadata-Baz": []string{"qux"},
},
},
{
name: "custom matcher",
md: runtime.ServerMetadata{
HeaderMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Custom-Foo": []string{"bar"},
},
matcher: func(key string) (string, bool) {
switch key {
case "foo":
return "custom-foo", true
default:
return "", false
}
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingHeaderMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
defer w.Body.Close()
if w.StatusCode != http.StatusOK {
t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
}

if !reflect.DeepEqual(w.Header, tc.headers) {
t.Fatalf("Header %v want %v", w.Header, tc.headers)
}
})
}
}

func TestOutgoingTrailerMatcher(t *testing.T) {
t.Parallel()
msg := &pb.SimpleMessage{Id: "foo"}
for _, tc := range []struct {
name string
md runtime.ServerMetadata
caller http.Header
headers http.Header
trailer http.Header
matcher runtime.HeaderMatcherFunc
}{
{
name: "default matcher, caller accepts",
md: runtime.ServerMetadata{
TrailerMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
caller: http.Header{
"Te": []string{"trailers"},
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Trailer": []string{"Grpc-Trailer-Foo,Grpc-Trailer-Baz"},
},
trailer: http.Header{
"Grpc-Trailer-Foo": []string{"bar"},
"Grpc-Trailer-Baz": []string{"qux"},
},
},
{
name: "default matcher, caller rejects",
md: runtime.ServerMetadata{
TrailerMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
headers: http.Header{
"Content-Type": []string{"application/json"},
},
},
{
name: "custom matcher",
md: runtime.ServerMetadata{
TrailerMD: metadata.Pairs(
"foo", "bar",
"baz", "qux",
),
},
caller: http.Header{
"Te": []string{"trailers"},
},
headers: http.Header{
"Content-Type": []string{"application/json"},
"Trailer": []string{"Custom-Trailer-Foo"},
},
trailer: http.Header{
"Custom-Trailer-Foo": []string{"bar"},
},
matcher: func(key string) (string, bool) {
switch key {
case "foo":
return "custom-trailer-foo", true
default:
return "", false
}
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.Header = tc.caller
resp := httptest.NewRecorder()

runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingTrailerMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)

w := resp.Result()
_, _ = io.Copy(io.Discard, w.Body)
defer w.Body.Close()
if w.StatusCode != http.StatusOK {
t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
}

if !reflect.DeepEqual(w.Trailer, tc.trailer) {
t.Fatalf("Trailer %v want %v", w.Trailer, tc.trailer)
}
})
}
}
32 changes: 26 additions & 6 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type ServeMux struct {
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
outgoingTrailerMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
errorHandler ErrorHandlerFunc
streamErrorHandler StreamErrorHandlerFunc
Expand Down Expand Up @@ -114,10 +115,18 @@ func DefaultHeaderMatcher(key string) (string, bool) {
return "", false
}

func defaultOutgoingHeaderMatcher(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}

func defaultOutgoingTrailerMatcher(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true
}

// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header.
func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
for _, header := range fn.matchedMalformedHeaders() {
grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
Expand Down Expand Up @@ -147,13 +156,24 @@ func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
//
// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return modified header.
// matcher should return the modified header.
func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingHeaderMatcher = fn
}
}

// WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
//
// This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be
// passed to http response returned from gateway. To transform the header before passing to response,
// matcher should return the modified header.
func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
return func(mux *ServeMux) {
mux.outgoingTrailerMatcher = fn
}
}

// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
Expand Down Expand Up @@ -273,11 +293,11 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {
if serveMux.incomingHeaderMatcher == nil {
serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
}

if serveMux.outgoingHeaderMatcher == nil {
serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
}
serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
}
if serveMux.outgoingTrailerMatcher == nil {
serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
}

return serveMux
Expand Down

0 comments on commit 132c8be

Please sign in to comment.