From 52877cf17b39b826dd927cd1f60a58a34566adc0 Mon Sep 17 00:00:00 2001 From: Linniem Date: Thu, 3 Nov 2022 22:47:30 +0800 Subject: [PATCH] Add SendHeaders to BidiStreamForClient --- client_stream.go | 5 ++++ connect.go | 5 ++-- connect_ext_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++- protocol_connect.go | 8 ++++++ protocol_grpc.go | 4 +++ 5 files changed, 79 insertions(+), 3 deletions(-) diff --git a/client_stream.go b/client_stream.go index 6c07c5f0..1716d08d 100644 --- a/client_stream.go +++ b/client_stream.go @@ -184,6 +184,11 @@ func (b *BidiStreamForClient[_, _]) Peer() Peer { return b.conn.Peer() } +// SendHeaders sends the request headers. +func (b *BidiStreamForClient[_, _]) SendHeaders() { + b.conn.SendHeaders() +} + // RequestHeader returns the request headers. Headers are sent with the first // call to Send. func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header { diff --git a/connect.go b/connect.go index 283ca9f7..a5b4fc3e 100644 --- a/connect.go +++ b/connect.go @@ -103,9 +103,10 @@ type StreamingClientConn interface { Spec() Spec Peer() Peer - // Send, RequestHeader, and CloseRequest may race with each other, but must - // be safe to call concurrently with all other methods. + // Send, SendHeaders, RequestHeader, and CloseRequest may race with each other, + // but must be safe to call concurrently with all other methods. Send(any) error + SendHeaders() RequestHeader() http.Header CloseRequest() error diff --git a/connect_ext_test.go b/connect_ext_test.go index 894e45dc..4c603548 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1343,6 +1343,56 @@ func TestClientWithSendMaxBytes(t *testing.T) { }) } +func TestBidiStreamForClientSendHeaders(t *testing.T) { + t.Parallel() + run := func(t *testing.T, opts ...connect.ClientOption) { + t.Helper() + headersSent := make(chan struct{}) + pingServer := &pluggablePingServer{ + cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + close(headersSent) + return nil + }, + } + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + connect.WithClientOptions(opts...), + connect.WithInterceptors(&assertPeerInterceptor{t}), + ) + stream := client.CumSum(context.Background()) + t.Cleanup(func() { + assert.Nil(t, stream.CloseRequest()) + assert.Nil(t, stream.CloseResponse()) + }) + stream.SendHeaders() + select { + case <-time.After(time.Second): + t.Error("timed out to get request headers") + case <-headersSent: + } + } + t.Run("connect", func(t *testing.T) { + t.Parallel() + run(t) + }) + t.Run("grpc", func(t *testing.T) { + t.Parallel() + run(t, connect.WithGRPC()) + }) + t.Run("grpcweb", func(t *testing.T) { + t.Parallel() + run(t, connect.WithGRPCWeb()) + }) +} + func gzipCompressedSize(tb testing.TB, message proto.Message) int { tb.Helper() uncompressed, err := proto.Marshal(message) @@ -1376,7 +1426,8 @@ func (c failCodec) Unmarshal(data []byte, message any) error { type pluggablePingServer struct { pingv1connect.UnimplementedPingServiceHandler - ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) + ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) + cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error } func (p *pluggablePingServer) Ping( @@ -1386,6 +1437,13 @@ func (p *pluggablePingServer) Ping( return p.ping(ctx, request) } +func (p *pluggablePingServer) CumSum( + ctx context.Context, + stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], +) error { + return p.cumSum(ctx, stream) +} + func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { tb.Helper() if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { diff --git a/protocol_connect.go b/protocol_connect.go index ec0f0a10..d9fe69e1 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -357,6 +357,10 @@ func (cc *connectUnaryClientConn) Send(msg any) error { return nil // must be a literal nil: nil *Error is a non-nil error } +func (cc *connectUnaryClientConn) SendHeaders() { + cc.duplexCall.ensureRequestMade() +} + func (cc *connectUnaryClientConn) RequestHeader() http.Header { return cc.duplexCall.Header() } @@ -456,6 +460,10 @@ func (cc *connectStreamingClientConn) Send(msg any) error { return nil // must be a literal nil: nil *Error is a non-nil error } +func (cc *connectStreamingClientConn) SendHeaders() { + cc.duplexCall.ensureRequestMade() +} + func (cc *connectStreamingClientConn) RequestHeader() http.Header { return cc.duplexCall.Header() } diff --git a/protocol_grpc.go b/protocol_grpc.go index bd2c14e1..0fa069bb 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -325,6 +325,10 @@ func (cc *grpcClientConn) Send(msg any) error { return nil // must be a literal nil: nil *Error is a non-nil error } +func (cc *grpcClientConn) SendHeaders() { + cc.duplexCall.ensureRequestMade() +} + func (cc *grpcClientConn) RequestHeader() http.Header { return cc.duplexCall.Header() }