Skip to content

Commit

Permalink
Ensure streaming ResponseWriters implement Flush (#406)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayjshah authored Nov 30, 2022
1 parent 241768d commit 9a4d409
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 6 deletions.
61 changes: 61 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1774,6 +1774,64 @@ func TestFailCompression(t *testing.T) {
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal)
}

func TestUnflushableResponseWriter(t *testing.T) {
t.Parallel()
assertIsFlusherErr := func(t *testing.T, err error) {
t.Helper()
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal, assert.Sprintf("got %v", err))
assert.True(
t,
strings.HasSuffix(err.Error(), "unflushableWriter does not implement http.Flusher"),
assert.Sprintf("error doesn't reference http.Flusher: %s", err.Error()),
)
}
mux := http.NewServeMux()
path, handler := pingv1connect.NewPingServiceHandler(pingServer{})
wrapped := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(&unflushableWriter{w}, r)
})
mux.Handle(path, wrapped)
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

tests := []struct {
name string
options []connect.ClientOption
}{
{"connect", nil},
{"grpc", []connect.ClientOption{connect.WithGRPC()}},
{"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, tt.options...)
stream, err := pingclient.CountUp(
context.Background(),
connect.NewRequest(&pingv1.CountUpRequest{Number: 5}),
)
if err != nil {
assertIsFlusherErr(t, err)
return
}
assert.False(t, stream.Receive())
assertIsFlusherErr(t, stream.Err())
})
}
}

type unflushableWriter struct {
w http.ResponseWriter
}

func (w *unflushableWriter) Header() http.Header { return w.w.Header() }
func (w *unflushableWriter) Write(b []byte) (int, error) { return w.w.Write(b) }
func (w *unflushableWriter) WriteHeader(code int) { w.w.WriteHeader(code) }

func gzipCompressedSize(tb testing.TB, message proto.Message) int {
tb.Helper()
uncompressed, err := proto.Marshal(message)
Expand Down Expand Up @@ -2069,6 +2127,9 @@ func (l *trimTrailerWriter) Flush() {
}

func (l *trimTrailerWriter) removeTrailers() {
for _, v := range l.w.Header().Values("Trailer") {
l.w.Header().Del(v)
}
l.w.Header().Del("Trailer")
for k := range l.w.Header() {
if strings.HasPrefix(k, http.TrailerPrefix) {
Expand Down
11 changes: 11 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,17 @@ func negotiateCompression( //nolint:nonamedreturns
return requestCompression, responseCompression, nil
}

// checkServerStreamsCanFlush ensures that bidi and server streaming handlers
// have received an http.ResponseWriter that implements http.Flusher, since
// they must flush data after sending each message.
func checkServerStreamsCanFlush(spec Spec, responseWriter http.ResponseWriter) *Error {
requiresFlusher := (spec.StreamType & StreamTypeServer) == StreamTypeServer
if _, flushable := responseWriter.(http.Flusher); requiresFlusher && !flushable {
return NewError(CodeInternal, fmt.Errorf("%T does not implement http.Flusher", responseWriter))
}
return nil
}

func flushResponseWriter(w http.ResponseWriter) {
if f, ok := w.(http.Flusher); ok {
f.Flush()
Expand Down
6 changes: 4 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ func (h *connectHandler) NewConn(
contentEncoding,
acceptEncoding,
)
if failed == nil {
failed = checkServerStreamsCanFlush(h.Spec, responseWriter)
}

// Write any remaining headers here:
// (1) any writes to the stream will implicitly send the headers, so we
Expand Down Expand Up @@ -209,8 +212,7 @@ func (h *connectHandler) NewConn(
}
}
conn = wrapHandlerConnWithCodedErrors(conn)
// We can't return failed as-is: a nil *Error is non-nil when returned as an
// error interface.

if failed != nil {
// Negotiation failed, so we can't establish a stream.
_ = conn.Close(failed)
Expand Down
27 changes: 23 additions & 4 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ func (g *grpcHandler) NewConn(
request.Header.Get(grpcHeaderCompression),
request.Header.Get(grpcHeaderAcceptCompression),
)
if failed == nil {
failed = checkServerStreamsCanFlush(g.Spec, responseWriter)
}

// Write any remaining headers here:
// (1) any writes to the stream will implicitly send the headers, so we
Expand Down Expand Up @@ -516,10 +519,26 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) {
// we're sending a "trailers-only" response, we must send trailing metadata
// as HTTP trailers. (If we had frame-level control of the HTTP/2 layer, we
// could send trailers-only responses as a single HEADER frame and no DATA
// frames, but net/http doesn't expose APIs that low-level.) In net/http's
// ResponseWriter API, we send HTTP trailers by writing to the headers map
// with a special prefix. This prefixing is an implementation detail, so we
// should hide it and _not_ mutate the user-visible headers.
// frames, but net/http doesn't expose APIs that low-level.)
if !hc.wroteToBody {
// This block works around a bug in x/net/http2. Until Go 1.20, trailers
// written using http.TrailerPrefix were only sent if either (1) there's
// data in the body, or (2) the innermost http.ResponseWriter is flushed.
// To ensure that we always send a valid gRPC response, even if the user
// has wrapped the response writer in net/http middleware that doesn't
// implement http.Flusher, we must pre-declare our HTTP trailers. We can
// remove this when Go 1.21 ships and we drop support for Go 1.19.
for key, values := range mergedTrailers {
hc.responseWriter.Header().Add("Trailer", key)
for _, value := range values {
hc.responseWriter.Header().Add(key, value)
}
}
return nil
}
// In net/http's ResponseWriter API, we send HTTP trailers by writing to the
// headers map with a special prefix. This prefixing is an implementation
// detail, so we should hide it and _not_ mutate the user-visible headers.
//
// Note that this is _very_ finicky and difficult to test with net/http,
// since correctness depends on low-level framing details. Breaking this
Expand Down

0 comments on commit 9a4d409

Please sign in to comment.