Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure streaming ResponseWriters implement Flush #406

Merged
merged 3 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1774,6 +1774,61 @@ func TestFailCompression(t *testing.T) {
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal)
}

func TestUnflushableResponseWriter(t *testing.T) {
assertIsFlusherErr := func(t *testing.T, err error) {
t.Helper()
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal, assert.Sprintf("got %v", err))
assert.True(
t,
strings.HasSuffix(err.Error(), "unflushableWriter does not implement http.Flusher"),
assert.Sprintf("error doesn't reference http.Flusher: %s", err.Error()),
)
}
mux := http.NewServeMux()
path, handler := pingv1connect.NewPingServiceHandler(pingServer{})
wrapped := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(&unflushableWriter{w}, r)
})
mux.Handle(path, wrapped)
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

tests := []struct {
name string
options []connect.ClientOption
}{
{"connect", nil},
{"grpc", []connect.ClientOption{connect.WithGRPC()}},
{"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, tt.options...)
stream, err := pingclient.CountUp(
context.Background(),
connect.NewRequest(&pingv1.CountUpRequest{Number: 5}),
)
if err != nil {
assertIsFlusherErr(t, err)
return
}
assert.False(t, stream.Receive())
assertIsFlusherErr(t, stream.Err())
})
}
}

type unflushableWriter struct {
w http.ResponseWriter
}

func (w *unflushableWriter) Header() http.Header { return w.w.Header() }
func (w *unflushableWriter) Write(b []byte) (int, error) { return w.w.Write(b) }
func (w *unflushableWriter) WriteHeader(code int) { w.w.WriteHeader(code) }

func gzipCompressedSize(tb testing.TB, message proto.Message) int {
tb.Helper()
uncompressed, err := proto.Marshal(message)
Expand Down
11 changes: 11 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,17 @@ func negotiateCompression( //nolint:nonamedreturns
return requestCompression, responseCompression, nil
}

// checkServerStreamsCanFlush ensures that bidi and server streaming handlers
// have received an http.ResponseWriter that implements http.Flusher, since
// they must flush data after sending each message.
func checkServerStreamsCanFlush(spec Spec, responseWriter http.ResponseWriter) *Error {
requiresFlusher := (spec.StreamType & StreamTypeServer) == StreamTypeServer
if _, flushable := responseWriter.(http.Flusher); requiresFlusher && !flushable {
return NewError(CodeInternal, fmt.Errorf("%T does not implement http.Flusher", responseWriter))
}
return nil
}

func flushResponseWriter(w http.ResponseWriter) {
if f, ok := w.(http.Flusher); ok {
f.Flush()
Expand Down
6 changes: 4 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ func (h *connectHandler) NewConn(
contentEncoding,
acceptEncoding,
)
if failed == nil {
failed = checkServerStreamsCanFlush(h.Spec, responseWriter)
}

// Write any remaining headers here:
// (1) any writes to the stream will implicitly send the headers, so we
Expand Down Expand Up @@ -209,8 +212,7 @@ func (h *connectHandler) NewConn(
}
}
conn = wrapHandlerConnWithCodedErrors(conn)
// We can't return failed as-is: a nil *Error is non-nil when returned as an
// error interface.

if failed != nil {
// Negotiation failed, so we can't establish a stream.
_ = conn.Close(failed)
Expand Down
27 changes: 23 additions & 4 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ func (g *grpcHandler) NewConn(
request.Header.Get(grpcHeaderCompression),
request.Header.Get(grpcHeaderAcceptCompression),
)
if failed == nil {
failed = checkServerStreamsCanFlush(g.Spec, responseWriter)
}

// Write any remaining headers here:
// (1) any writes to the stream will implicitly send the headers, so we
Expand Down Expand Up @@ -516,10 +519,26 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) {
// we're sending a "trailers-only" response, we must send trailing metadata
// as HTTP trailers. (If we had frame-level control of the HTTP/2 layer, we
// could send trailers-only responses as a single HEADER frame and no DATA
// frames, but net/http doesn't expose APIs that low-level.) In net/http's
// ResponseWriter API, we send HTTP trailers by writing to the headers map
// with a special prefix. This prefixing is an implementation detail, so we
// should hide it and _not_ mutate the user-visible headers.
// frames, but net/http doesn't expose APIs that low-level.)
if !hc.wroteToBody {
// This block works around a bug in x/net/http2. Until Go 1.20, trailers
// written using http.TrailerPrefix were only sent if either (1) there's
// data in the body, or (2) the innermost http.ResponseWriter is flushed.
// To ensure that we always send a valid gRPC response, even if the user
// has wrapped the response writer in net/http middleware that doesn't
// implement http.Flusher, we must pre-declare our HTTP trailers. We can
// remove this when Go 1.21 ships and we drop support for Go 1.19.
for key, values := range mergedTrailers {
hc.responseWriter.Header().Add("Trailer", key)
for _, value := range values {
hc.responseWriter.Header().Add(key, value)
}
}
return nil
}
// In net/http's ResponseWriter API, we send HTTP trailers by writing to the
// headers map with a special prefix. This prefixing is an implementation
// detail, so we should hide it and _not_ mutate the user-visible headers.
//
// Note that this is _very_ finicky and difficult to test with net/http,
// since correctness depends on low-level framing details. Breaking this
Expand Down