From 79917534e74c087d3aedbd002bfcd8033d20de7a Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Sat, 5 Nov 2022 19:02:21 -0400 Subject: [PATCH] Make client.Send(nil) send request headers Adding to the StreamingHandlerConn interface isn't backward-compatible, but we need a way for clients to send request headers without a body. This commit removes the previous commit's `SendHeaders` method, but makes `Send(nil)` accomplish the same thing. --- client_stream.go | 14 ++++++++------ connect.go | 5 ++--- connect_ext_test.go | 4 ++-- envelope.go | 9 +++++++++ handler_stream.go | 6 ++++++ protocol_connect.go | 11 +++-------- protocol_grpc.go | 4 ---- 7 files changed, 30 insertions(+), 23 deletions(-) diff --git a/client_stream.go b/client_stream.go index 1716d08d..c7557afa 100644 --- a/client_stream.go +++ b/client_stream.go @@ -59,6 +59,9 @@ func (c *ClientStreamForClient[Req, Res]) Send(request *Req) error { if c.err != nil { return c.err } + if request == nil { + return c.conn.Send(nil) + } return c.conn.Send(request) } @@ -184,11 +187,6 @@ func (b *BidiStreamForClient[_, _]) Peer() Peer { return b.conn.Peer() } -// SendHeaders sends the request headers. -func (b *BidiStreamForClient[_, _]) SendHeaders() { - b.conn.SendHeaders() -} - // RequestHeader returns the request headers. Headers are sent with the first // call to Send. func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header { @@ -199,7 +197,8 @@ func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header { } // Send a message to the server. The first call to Send also sends the request -// headers. +// headers. To send just the request headers, without a body, call Send with a +// nil pointer. // // If the server returns an error, Send returns an error that wraps [io.EOF]. // Clients should check for EOF using the standard library's [errors.Is] and @@ -208,6 +207,9 @@ func (b *BidiStreamForClient[Req, Res]) Send(msg *Req) error { if b.err != nil { return b.err } + if msg == nil { + return b.conn.Send(nil) + } return b.conn.Send(msg) } diff --git a/connect.go b/connect.go index a5b4fc3e..283ca9f7 100644 --- a/connect.go +++ b/connect.go @@ -103,10 +103,9 @@ type StreamingClientConn interface { Spec() Spec Peer() Peer - // Send, SendHeaders, RequestHeader, and CloseRequest may race with each other, - // but must be safe to call concurrently with all other methods. + // Send, RequestHeader, and CloseRequest may race with each other, but must + // be safe to call concurrently with all other methods. Send(any) error - SendHeaders() RequestHeader() http.Header CloseRequest() error diff --git a/connect_ext_test.go b/connect_ext_test.go index 4c603548..ff5313b4 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1343,7 +1343,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { }) } -func TestBidiStreamForClientSendHeaders(t *testing.T) { +func TestBidiStreamServerSendsFirstMessage(t *testing.T) { t.Parallel() run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() @@ -1372,7 +1372,7 @@ func TestBidiStreamForClientSendHeaders(t *testing.T) { assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) }) - stream.SendHeaders() + assert.Nil(t, stream.Send(nil)) select { case <-time.After(time.Second): t.Error("timed out to get request headers") diff --git a/envelope.go b/envelope.go index 0f259817..ecde9570 100644 --- a/envelope.go +++ b/envelope.go @@ -58,6 +58,15 @@ type envelopeWriter struct { } func (w *envelopeWriter) Marshal(message any) *Error { + if message == nil { + if _, err := w.writer.Write(nil); err != nil { + if connectErr, ok := asError(err); ok { + return connectErr + } + return NewError(CodeUnknown, err) + } + return nil + } raw, err := w.codec.Marshal(message) if err != nil { return errorf(CodeInternal, "marshal message: %w", err) diff --git a/handler_stream.go b/handler_stream.go index eec9d8f1..5eb7a953 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -104,6 +104,9 @@ func (s *ServerStream[Res]) ResponseTrailer() http.Header { // Send a message to the client. The first call to Send also sends the response // headers. func (s *ServerStream[Res]) Send(msg *Res) error { + if msg == nil { + return s.conn.Send(nil) + } return s.conn.Send(msg) } @@ -161,6 +164,9 @@ func (b *BidiStream[Req, Res]) ResponseTrailer() http.Header { // Send a message to the client. The first call to Send also sends the response // headers. func (b *BidiStream[Req, Res]) Send(msg *Res) error { + if msg == nil { + return b.conn.Send(nil) + } return b.conn.Send(msg) } diff --git a/protocol_connect.go b/protocol_connect.go index d9fe69e1..069d4e35 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -357,10 +357,6 @@ func (cc *connectUnaryClientConn) Send(msg any) error { return nil // must be a literal nil: nil *Error is a non-nil error } -func (cc *connectUnaryClientConn) SendHeaders() { - cc.duplexCall.ensureRequestMade() -} - func (cc *connectUnaryClientConn) RequestHeader() http.Header { return cc.duplexCall.Header() } @@ -460,10 +456,6 @@ func (cc *connectStreamingClientConn) Send(msg any) error { return nil // must be a literal nil: nil *Error is a non-nil error } -func (cc *connectStreamingClientConn) SendHeaders() { - cc.duplexCall.ensureRequestMade() -} - func (cc *connectStreamingClientConn) RequestHeader() http.Header { return cc.duplexCall.Header() } @@ -753,6 +745,9 @@ type connectUnaryMarshaler struct { } func (m *connectUnaryMarshaler) Marshal(message any) *Error { + if message == nil { + return m.write(nil) + } data, err := m.codec.Marshal(message) if err != nil { return errorf(CodeInternal, "marshal message: %w", err) diff --git a/protocol_grpc.go b/protocol_grpc.go index 0fa069bb..bd2c14e1 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -325,10 +325,6 @@ func (cc *grpcClientConn) Send(msg any) error { return nil // must be a literal nil: nil *Error is a non-nil error } -func (cc *grpcClientConn) SendHeaders() { - cc.duplexCall.ensureRequestMade() -} - func (cc *grpcClientConn) RequestHeader() http.Header { return cc.duplexCall.Header() }