Skip to content

Commit

Permalink
Add SendHeaders to BidiStreamForClient
Browse files Browse the repository at this point in the history
  • Loading branch information
Linniem authored and akshayjshah committed Nov 7, 2022
1 parent a366d93 commit 52877cf
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 3 deletions.
5 changes: 5 additions & 0 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ func (b *BidiStreamForClient[_, _]) Peer() Peer {
return b.conn.Peer()
}

// SendHeaders sends the request headers.
func (b *BidiStreamForClient[_, _]) SendHeaders() {
b.conn.SendHeaders()
}

// RequestHeader returns the request headers. Headers are sent with the first
// call to Send.
func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header {
Expand Down
5 changes: 3 additions & 2 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ type StreamingClientConn interface {
Spec() Spec
Peer() Peer

// Send, RequestHeader, and CloseRequest may race with each other, but must
// be safe to call concurrently with all other methods.
// Send, SendHeaders, RequestHeader, and CloseRequest may race with each other,
// but must be safe to call concurrently with all other methods.
Send(any) error
SendHeaders()
RequestHeader() http.Header
CloseRequest() error

Expand Down
60 changes: 59 additions & 1 deletion connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,56 @@ func TestClientWithSendMaxBytes(t *testing.T) {
})
}

func TestBidiStreamForClientSendHeaders(t *testing.T) {
t.Parallel()
run := func(t *testing.T, opts ...connect.ClientOption) {
t.Helper()
headersSent := make(chan struct{})
pingServer := &pluggablePingServer{
cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error {
close(headersSent)
return nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithClientOptions(opts...),
connect.WithInterceptors(&assertPeerInterceptor{t}),
)
stream := client.CumSum(context.Background())
t.Cleanup(func() {
assert.Nil(t, stream.CloseRequest())
assert.Nil(t, stream.CloseResponse())
})
stream.SendHeaders()
select {
case <-time.After(time.Second):
t.Error("timed out to get request headers")
case <-headersSent:
}
}
t.Run("connect", func(t *testing.T) {
t.Parallel()
run(t)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPC())
})
t.Run("grpcweb", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPCWeb())
})
}

func gzipCompressedSize(tb testing.TB, message proto.Message) int {
tb.Helper()
uncompressed, err := proto.Marshal(message)
Expand Down Expand Up @@ -1376,7 +1426,8 @@ func (c failCodec) Unmarshal(data []byte, message any) error {
type pluggablePingServer struct {
pingv1connect.UnimplementedPingServiceHandler

ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error)
ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error)
cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error
}

func (p *pluggablePingServer) Ping(
Expand All @@ -1386,6 +1437,13 @@ func (p *pluggablePingServer) Ping(
return p.ping(ctx, request)
}

func (p *pluggablePingServer) CumSum(
ctx context.Context,
stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse],
) error {
return p.cumSum(ctx, stream)
}

func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) {
tb.Helper()
if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ func (cc *connectUnaryClientConn) Send(msg any) error {
return nil // must be a literal nil: nil *Error is a non-nil error
}

func (cc *connectUnaryClientConn) SendHeaders() {
cc.duplexCall.ensureRequestMade()
}

func (cc *connectUnaryClientConn) RequestHeader() http.Header {
return cc.duplexCall.Header()
}
Expand Down Expand Up @@ -456,6 +460,10 @@ func (cc *connectStreamingClientConn) Send(msg any) error {
return nil // must be a literal nil: nil *Error is a non-nil error
}

func (cc *connectStreamingClientConn) SendHeaders() {
cc.duplexCall.ensureRequestMade()
}

func (cc *connectStreamingClientConn) RequestHeader() http.Header {
return cc.duplexCall.Header()
}
Expand Down
4 changes: 4 additions & 0 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ func (cc *grpcClientConn) Send(msg any) error {
return nil // must be a literal nil: nil *Error is a non-nil error
}

func (cc *grpcClientConn) SendHeaders() {
cc.duplexCall.ensureRequestMade()
}

func (cc *grpcClientConn) RequestHeader() http.Header {
return cc.duplexCall.Header()
}
Expand Down

0 comments on commit 52877cf

Please sign in to comment.