diff --git a/test/end2end_test.go b/test/end2end_test.go index 37964919945e..98d590e9ea7f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2447,85 +2447,6 @@ func testFailedServerStreaming(t *testing.T, e env) { } } -// checkTimeoutErrorServer is a gRPC server checks context timeout error in FullDuplexCall(). -// It is only used in TestStreamingRPCTimeoutServerError. -type checkTimeoutErrorServer struct { - t *testing.T - done chan struct{} - testpb.TestServiceServer -} - -func (s *checkTimeoutErrorServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { - defer close(s.done) - for { - _, err := stream.Recv() - if err != nil { - if grpc.Code(err) != codes.DeadlineExceeded { - s.t.Errorf("stream.Recv() = _, %v, want error code %s", err, codes.DeadlineExceeded) - } - return err - } - if err := stream.Send(&testpb.StreamingOutputCallResponse{ - Payload: &testpb.Payload{ - Body: []byte{'0'}, - }, - }); err != nil { - if grpc.Code(err) != codes.DeadlineExceeded { - s.t.Errorf("stream.Send(_) = %v, want error code %s", err, codes.DeadlineExceeded) - } - return err - } - } -} - -func TestStreamingRPCTimeoutServerError(t *testing.T) { - defer leakCheck(t)() - for _, e := range listTestEnv() { - testStreamingRPCTimeoutServerError(t, e) - } -} - -// testStreamingRPCTimeoutServerError tests the server side behavior. -// When context timeout happens on client side, server should get deadline exceeded error. -func testStreamingRPCTimeoutServerError(t *testing.T, e env) { - te := newTest(t, e) - serverDone := make(chan struct{}) - te.startServer(&checkTimeoutErrorServer{t: t, done: serverDone}) - defer te.tearDown() - - cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - - req := &testpb.StreamingOutputCallRequest{} - for duration := 50 * time.Millisecond; ; duration *= 2 { - ctx, _ := context.WithTimeout(context.Background(), duration) - stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) - if grpc.Code(err) == codes.DeadlineExceeded { - // Redo test with double timeout. - continue - } - if err != nil { - t.Errorf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) - return - } - for { - err := stream.Send(req) - if err != nil { - break - } - _, err = stream.Recv() - if err != nil { - break - } - } - - // Wait for context timeout on server before closing connection - // to make sure the server will get timeout error. - <-serverDone - break - } -} - // concurrentSendServer is a TestServiceServer whose // StreamingOutputCall makes ten serial Send calls, sending payloads // "0".."9", inclusive. TestServerStreamingConcurrent verifies they diff --git a/transport/http2_client.go b/transport/http2_client.go index c02ee1601e2f..627a590a0dd9 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -533,6 +533,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // after having acquired the writableChan to send RST_STREAM out (look at // the controller() routine). var rstStream bool + var rstError http2.ErrCode defer func() { // In case, the client doesn't have to send RST_STREAM to server // we can safely add back to streamsQuota pool now. @@ -540,10 +541,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) { t.streamsQuota.add(1) return } - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel}) + t.controlBuf.put(&resetStream{s.id, rstError}) }() s.mu.Lock() rstStream = s.rstStream + rstError = s.rstError if q := s.fc.resetPendingData(); q > 0 { if n := t.fc.onRead(q); n > 0 { t.controlBuf.put(&windowUpdate{0, n}) @@ -559,8 +561,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) { } s.state = streamDone s.mu.Unlock() - if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded { + if _, ok := err.(StreamError); ok { rstStream = true + rstError = http2.ErrCodeCancel } } @@ -807,6 +810,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { s.statusCode = codes.Internal s.statusDesc = err.Error() s.rstStream = true + s.rstError = http2.ErrCodeFlowControl close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) diff --git a/transport/transport.go b/transport/transport.go index fed69089b038..beb0a520a573 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -45,6 +45,7 @@ import ( "sync" "golang.org/x/net/context" + "golang.org/x/net/http2" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" @@ -217,6 +218,8 @@ type Stream struct { // rstStream indicates whether a RST_STREAM frame needs to be sent // to the server to signify that this stream is closing. rstStream bool + // rstError is the error that needs to be sent along with the RST_STREAM frame. + rstError http2.ErrCode } // RecvCompress returns the compression algorithm applied to the inbound