Skip to content

Commit

Permalink
Wrap errors with context cancellation codes (#659)
Browse files Browse the repository at this point in the history
This PR wraps errors with the appropriate connect code of Cancelled or
DeadlineExceeded if the context error is not nil.

Improves error handling for some well known error cases that do not
surface context.Cancelled errors. For example HTTP2 "client disconnect"
string errors are now raised with a Cancelled code not an Unknown. This
lets handlers check the error code for better handling and reporting of
errors.

Fix for #645
  • Loading branch information
emcfarlane authored Feb 16, 2024
1 parent 1f132f4 commit 064c61e
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 18 deletions.
135 changes: 135 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ import (
"compress/flate"
"compress/gzip"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"runtime"
"strings"
Expand Down Expand Up @@ -2307,6 +2309,139 @@ func TestStreamUnexpectedEOF(t *testing.T) {
}
}

// TestClientDisconnect tests that the handler receives a CodeCanceled error when
// the client abruptly disconnects.
func TestClientDisconnect(t *testing.T) {
t.Parallel()
type httpRoundTripFunc func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper
http1RoundTripper := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper {
transport := server.TransportHTTP1()
dialContext := transport.DialContext
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr)
if err != nil {
close(onError)
return nil, err
}
*clientConn = conn // Capture the client connection.
return conn, nil
}
return transport
}
http2RoundTripper := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper {
transport := server.Transport()
dialContext := transport.DialTLSContext
transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr, cfg)
if err != nil {
close(onError)
return nil, err
}
*clientConn = conn // Capture the client connection.
return conn, nil
}
return transport
}
testTransportClosure := func(t *testing.T, captureTransport httpRoundTripFunc) { //nolint:thelper
t.Run("handler_reads", func(t *testing.T) {
var (
handlerReceiveErr error
handlerContextErr error
gotRequest = make(chan struct{})
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
close(gotRequest)
for stream.Receive() {
// Do nothing
}
handlerReceiveErr = stream.Err()
handlerContextErr = ctx.Err()
close(gotResponse)
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
var clientConn net.Conn
transport := captureTransport(server, &clientConn, gotRequest)
serverClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(serverClient, server.URL())
stream := client.Sum(context.Background())
// Send header.
assert.Nil(t, stream.Send(nil))
<-gotRequest
// Client abruptly disconnects.
if !assert.NotNil(t, clientConn) {
return
}
assert.Nil(t, clientConn.Close())
_, err := stream.CloseAndReceive()
assert.NotNil(t, err)
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
t.Run("handler_writes", func(t *testing.T) {
var (
handlerReceiveErr error
handlerContextErr error
gotRequest = make(chan struct{})
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
close(gotRequest)
var err error
for err == nil {
err = stream.Send(&pingv1.CountUpResponse{})
}
handlerReceiveErr = err
handlerContextErr = ctx.Err()
close(gotResponse)
return nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
var clientConn net.Conn
transport := captureTransport(server, &clientConn, gotRequest)
serverClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(serverClient, server.URL())
stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{}))
if !assert.Nil(t, err) {
return
}
<-gotRequest
// Client abruptly disconnects.
if !assert.NotNil(t, clientConn) {
return
}
assert.Nil(t, clientConn.Close())
for stream.Receive() {
// Do nothing
}
assert.NotNil(t, stream.Err())
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
}
t.Run("http1", func(t *testing.T) {
t.Parallel()
testTransportClosure(t, http1RoundTripper)
})
t.Run("http2", func(t *testing.T) {
t.Parallel()
testTransportClosure(t, http2RoundTripper)
})
}

func TestTrailersOnlyErrors(t *testing.T) {
t.Parallel()

Expand Down
21 changes: 9 additions & 12 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
Expand Down Expand Up @@ -117,6 +118,7 @@ func (e *envelope) Len() int {
}

type envelopeWriter struct {
ctx context.Context //nolint:containedctx
sender messageSender
codec Codec
compressMinBytes int
Expand Down Expand Up @@ -208,7 +210,7 @@ func (w *envelopeWriter) marshal(message any) *Error {

func (w *envelopeWriter) write(env *envelope) *Error {
if _, err := w.sender.Send(env); err != nil {
err = wrapIfContextError(err)
err = wrapIfContextDone(w.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand All @@ -218,6 +220,7 @@ func (w *envelopeWriter) write(env *envelope) *Error {
}

type envelopeReader struct {
ctx context.Context //nolint:containedctx
reader io.Reader
codec Codec
last envelope
Expand Down Expand Up @@ -312,15 +315,12 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// add any alarming text about protocol errors, though.
return NewError(CodeUnknown, err)
}
err = wrapIfContextError(err)
// Something else has gone wrong - the stream didn't end cleanly.
err = wrapIfMaxBytesError(err, "read 5 byte message prefix")
err = wrapIfContextDone(r.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
// Something else has gone wrong - the stream didn't end cleanly.
return errorf(
CodeInvalidArgument,
"protocol error: incomplete envelope: %w", err,
Expand All @@ -338,10 +338,6 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
if errors.Is(err, io.EOF) {
// We've gotten fewer bytes than we expected, so the stream has ended
// unexpectedly.
Expand All @@ -352,7 +348,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
readN,
)
}
err = wrapIfContextError(err)
err = wrapIfMaxBytesError(err, "read %d byte message", size)
err = wrapIfContextDone(r.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down
2 changes: 2 additions & 0 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"bytes"
"context"
"io"
"testing"

Expand Down Expand Up @@ -44,6 +45,7 @@ func TestEnvelope(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
ctx: context.Background(),
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
Expand Down
31 changes: 29 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,25 @@ func wrapIfContextError(err error) error {
return err
}

// wrapIfContextDone wraps errors with CodeCanceled or CodeDeadlineExceeded
// if the context is done. It leaves already-wrapped errors unchanged.
func wrapIfContextDone(ctx context.Context, err error) error {
if err == nil {
return nil
}
err = wrapIfContextError(err)
if _, ok := asError(err); ok {
return err
}
ctxErr := ctx.Err()
if errors.Is(ctxErr, context.Canceled) {
return NewError(CodeCanceled, err)
} else if errors.Is(ctxErr, context.DeadlineExceeded) {
return NewError(CodeDeadlineExceeded, err)
}
return err
}

// wrapIfLikelyH2CNotConfiguredError adds a wrapping error that has a message
// telling the caller that they likely need to use h2c but are using a raw http.Client{}.
//
Expand Down Expand Up @@ -414,10 +433,18 @@ func wrapIfRSTError(err error) error {
}
}

func asMaxBytesError(err error, tmpl string, args ...any) *Error {
// wrapIfMaxBytesError wraps errors returned reading from a http.MaxBytesHandler
// whose limit has been exceeded.
func wrapIfMaxBytesError(err error, tmpl string, args ...any) error {
if err == nil {
return nil
}
if _, ok := asError(err); ok {
return err
}
var maxBytesErr *http.MaxBytesError
if ok := errors.As(err, &maxBytesErr); !ok {
return nil
return err
}
prefix := fmt.Sprintf(tmpl, args...)
return errorf(CodeResourceExhausted, "%s: exceeded %d byte http.MaxBytesReader limit", prefix, maxBytesErr.Limit)
Expand Down
17 changes: 13 additions & 4 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ func (h *connectHandler) NewConn(
responseWriter http.ResponseWriter,
request *http.Request,
) (handlerConnCloser, bool) {
ctx := request.Context()
query := request.URL.Query()
// We need to parse metadata before entering the interceptor stack; we'll
// send the error to the client later on.
Expand Down Expand Up @@ -255,6 +256,7 @@ func (h *connectHandler) NewConn(
request: request,
responseWriter: responseWriter,
marshaler: connectUnaryMarshaler{
ctx: ctx,
sender: writeSender{writer: responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
Expand All @@ -265,6 +267,7 @@ func (h *connectHandler) NewConn(
sendMaxBytes: h.SendMaxBytes,
},
unmarshaler: connectUnaryUnmarshaler{
ctx: ctx,
reader: requestBody,
codec: codec,
compressionPool: h.CompressionPools.Get(requestCompression),
Expand All @@ -281,6 +284,7 @@ func (h *connectHandler) NewConn(
responseWriter: responseWriter,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
ctx: ctx,
sender: writeSender{responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
Expand All @@ -291,6 +295,7 @@ func (h *connectHandler) NewConn(
},
unmarshaler: connectStreamingUnmarshaler{
envelopeReader: envelopeReader{
ctx: ctx,
reader: requestBody,
codec: codec,
compressionPool: h.CompressionPools.Get(requestCompression),
Expand Down Expand Up @@ -376,6 +381,7 @@ func (c *connectClient) NewConn(
bufferPool: c.BufferPool,
marshaler: connectUnaryRequestMarshaler{
connectUnaryMarshaler: connectUnaryMarshaler{
ctx: ctx,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
Expand All @@ -387,6 +393,7 @@ func (c *connectClient) NewConn(
},
},
unmarshaler: connectUnaryUnmarshaler{
ctx: ctx,
reader: duplexCall,
codec: c.Codec,
bufferPool: c.BufferPool,
Expand Down Expand Up @@ -416,6 +423,7 @@ func (c *connectClient) NewConn(
codec: c.Codec,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
ctx: ctx,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
Expand All @@ -426,6 +434,7 @@ func (c *connectClient) NewConn(
},
unmarshaler: connectStreamingUnmarshaler{
envelopeReader: envelopeReader{
ctx: ctx,
reader: duplexCall,
codec: c.Codec,
bufferPool: c.BufferPool,
Expand Down Expand Up @@ -912,6 +921,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error {
}

type connectUnaryMarshaler struct {
ctx context.Context //nolint:containedctx
sender messageSender
codec Codec
compressMinBytes int
Expand Down Expand Up @@ -1077,6 +1087,7 @@ func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error {
}

type connectUnaryUnmarshaler struct {
ctx context.Context //nolint:containedctx
reader io.Reader
codec Codec
compressionPool *compressionPool
Expand All @@ -1103,13 +1114,11 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by
// ReadFrom ignores io.EOF, so any error here is real.
bytesRead, err := data.ReadFrom(reader)
if err != nil {
err = wrapIfContextError(err)
err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead)
err = wrapIfContextDone(u.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil {
return readMaxBytesErr
}
return errorf(CodeUnknown, "read message: %w", err)
}
if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
Expand Down
Loading

0 comments on commit 064c61e

Please sign in to comment.