Skip to content

Commit

Permalink
Merge pull request #1124 from MakMukhi/rst_stream_issue
Browse files Browse the repository at this point in the history
Upon observing timeout on rpc context, the client should send a RST_S…
  • Loading branch information
MakMukhi authored Mar 14, 2017
2 parents 0713829 + 5535384 commit cdee119
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 81 deletions.
79 changes: 0 additions & 79 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2447,85 +2447,6 @@ func testFailedServerStreaming(t *testing.T, e env) {
}
}

// checkTimeoutErrorServer is a gRPC server checks context timeout error in FullDuplexCall().
// It is only used in TestStreamingRPCTimeoutServerError.
type checkTimeoutErrorServer struct {
t *testing.T
done chan struct{}
testpb.TestServiceServer
}

func (s *checkTimeoutErrorServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
defer close(s.done)
for {
_, err := stream.Recv()
if err != nil {
if grpc.Code(err) != codes.DeadlineExceeded {
s.t.Errorf("stream.Recv() = _, %v, want error code %s", err, codes.DeadlineExceeded)
}
return err
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{
Body: []byte{'0'},
},
}); err != nil {
if grpc.Code(err) != codes.DeadlineExceeded {
s.t.Errorf("stream.Send(_) = %v, want error code %s", err, codes.DeadlineExceeded)
}
return err
}
}
}

func TestStreamingRPCTimeoutServerError(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testStreamingRPCTimeoutServerError(t, e)
}
}

// testStreamingRPCTimeoutServerError tests the server side behavior.
// When context timeout happens on client side, server should get deadline exceeded error.
func testStreamingRPCTimeoutServerError(t *testing.T, e env) {
te := newTest(t, e)
serverDone := make(chan struct{})
te.startServer(&checkTimeoutErrorServer{t: t, done: serverDone})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)

req := &testpb.StreamingOutputCallRequest{}
for duration := 50 * time.Millisecond; ; duration *= 2 {
ctx, _ := context.WithTimeout(context.Background(), duration)
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
if grpc.Code(err) == codes.DeadlineExceeded {
// Redo test with double timeout.
continue
}
if err != nil {
t.Errorf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
return
}
for {
err := stream.Send(req)
if err != nil {
break
}
_, err = stream.Recv()
if err != nil {
break
}
}

// Wait for context timeout on server before closing connection
// to make sure the server will get timeout error.
<-serverDone
break
}
}

// concurrentSendServer is a TestServiceServer whose
// StreamingOutputCall makes ten serial Send calls, sending payloads
// "0".."9", inclusive. TestServerStreamingConcurrent verifies they
Expand Down
8 changes: 6 additions & 2 deletions transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,17 +533,19 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// after having acquired the writableChan to send RST_STREAM out (look at
// the controller() routine).
var rstStream bool
var rstError http2.ErrCode
defer func() {
// In case, the client doesn't have to send RST_STREAM to server
// we can safely add back to streamsQuota pool now.
if !rstStream {
t.streamsQuota.add(1)
return
}
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
t.controlBuf.put(&resetStream{s.id, rstError})
}()
s.mu.Lock()
rstStream = s.rstStream
rstError = s.rstError
if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 {
t.controlBuf.put(&windowUpdate{0, n})
Expand All @@ -559,8 +561,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
}
s.state = streamDone
s.mu.Unlock()
if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded {
if _, ok := err.(StreamError); ok {
rstStream = true
rstError = http2.ErrCodeCancel
}
}

Expand Down Expand Up @@ -807,6 +810,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.statusCode = codes.Internal
s.statusDesc = err.Error()
s.rstStream = true
s.rstError = http2.ErrCodeFlowControl
close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
Expand Down
3 changes: 3 additions & 0 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"sync"

"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
Expand Down Expand Up @@ -217,6 +218,8 @@ type Stream struct {
// rstStream indicates whether a RST_STREAM frame needs to be sent
// to the server to signify that this stream is closing.
rstStream bool
// rstError is the error that needs to be sent along with the RST_STREAM frame.
rstError http2.ErrCode
}

// RecvCompress returns the compression algorithm applied to the inbound
Expand Down

0 comments on commit cdee119

Please sign in to comment.