Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In iterators, allocate new msg for each Receive #350

Merged
merged 1 commit into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err
return response, c.conn.CloseResponse()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) {
return c.conn, c.err
}

// ServerStreamForClient is the client's view of a server streaming RPC.
//
// It's returned from [Client].CallServerStream, but doesn't currently have an
// exported constructor function.
type ServerStreamForClient[Res any] struct {
conn StreamingClientConn
msg Res
msg *Res
// Error from client construction. If non-nil, return for all calls.
constructErr error
// Error from conn.Receive().
Expand All @@ -92,15 +98,17 @@ func (s *ServerStreamForClient[Res]) Receive() bool {
if s.constructErr != nil || s.receiveErr != nil {
return false
}
s.receiveErr = s.conn.Receive(&s.msg)
s.msg = new(Res)
s.receiveErr = s.conn.Receive(s.msg)
return s.receiveErr == nil
}

// Msg returns the most recent message unmarshaled by a call to Receive. The
// returned message points to data that will be overwritten by the next call to
// Receive.
// Msg returns the most recent message unmarshaled by a call to Receive.
func (s *ServerStreamForClient[Res]) Msg() *Res {
return &s.msg
if s.msg == nil {
s.msg = new(Res)
}
return s.msg
}

// Err returns the first non-EOF error that was encountered by Receive.
Expand Down Expand Up @@ -140,6 +148,12 @@ func (s *ServerStreamForClient[Res]) Close() error {
return s.conn.CloseResponse()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) {
return s.conn, s.constructErr
}

// BidiStreamForClient is the client's view of a bidirectional streaming RPC.
//
// It's returned from [Client].CallBidiStream, but doesn't currently have an
Expand Down Expand Up @@ -218,3 +232,9 @@ func (b *BidiStreamForClient[Req, Res]) ResponseTrailer() http.Header {
}
return b.conn.ResponseTrailer()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (b *BidiStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) {
return b.conn, b.err
}
43 changes: 39 additions & 4 deletions client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"errors"
"fmt"
"net/http"
"testing"

Expand All @@ -26,12 +27,15 @@ import (
func TestClientStreamForClient_NoPanics(t *testing.T) {
t.Parallel()
initErr := errors.New("client init failure")
cs := &ClientStreamForClient[pingv1.PingRequest, pingv1.PingResponse]{err: initErr}
assert.ErrorIs(t, cs.Send(&pingv1.PingRequest{}), initErr)
verifyHeaders(t, cs.RequestHeader())
res, err := cs.CloseAndReceive()
clientStream := &ClientStreamForClient[pingv1.PingRequest, pingv1.PingResponse]{err: initErr}
assert.ErrorIs(t, clientStream.Send(&pingv1.PingRequest{}), initErr)
verifyHeaders(t, clientStream.RequestHeader())
res, err := clientStream.CloseAndReceive()
assert.Nil(t, res)
assert.ErrorIs(t, err, initErr)
conn, err := clientStream.Conn()
assert.NotNil(t, err)
assert.Nil(t, conn)
}

func TestServerStreamForClient_NoPanics(t *testing.T) {
Expand All @@ -44,6 +48,26 @@ func TestServerStreamForClient_NoPanics(t *testing.T) {
assert.False(t, serverStream.Receive())
verifyHeaders(t, serverStream.ResponseHeader())
verifyHeaders(t, serverStream.ResponseTrailer())
conn, err := serverStream.Conn()
assert.NotNil(t, err)
assert.Nil(t, conn)
}

func TestServerStreamForClient(t *testing.T) {
t.Parallel()
stream := &ServerStreamForClient[pingv1.PingResponse]{conn: &nopStreamingClientConn{}}
// Ensure that each call to Receive allocates a new message. This helps
// vtprotobuf, which doesn't automatically zero messages before unmarshaling
// (see https://github.com/bufbuild/connect-go/issues/345), and it's also
// less error-prone for users.
assert.True(t, stream.Receive())
first := fmt.Sprintf("%p", stream.Msg())
assert.True(t, stream.Receive())
second := fmt.Sprintf("%p", stream.Msg())
assert.NotEqual(t, first, second)
conn, err := stream.Conn()
assert.Nil(t, err)
assert.NotNil(t, conn)
}

func TestBidiStreamForClient_NoPanics(t *testing.T) {
Expand All @@ -59,6 +83,9 @@ func TestBidiStreamForClient_NoPanics(t *testing.T) {
assert.ErrorIs(t, bidiStream.Send(&pingv1.CumSumRequest{}), initErr)
assert.ErrorIs(t, bidiStream.CloseRequest(), initErr)
assert.ErrorIs(t, bidiStream.CloseResponse(), initErr)
conn, err := bidiStream.Conn()
assert.NotNil(t, err)
assert.Nil(t, conn)
}

func verifyHeaders(t *testing.T, headers http.Header) {
Expand All @@ -69,3 +96,11 @@ func verifyHeaders(t *testing.T, headers http.Header) {
headers.Set("a", "b")
headers.Del("a")
}

type nopStreamingClientConn struct {
StreamingClientConn
}

func (c *nopStreamingClientConn) Receive(msg any) error {
return nil
}
32 changes: 26 additions & 6 deletions handler_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
// an exported constructor.
type ClientStream[Req any] struct {
conn StreamingHandlerConn
msg Req
msg *Req
err error
}

Expand All @@ -44,15 +44,17 @@ func (c *ClientStream[Req]) Receive() bool {
if c.err != nil {
return false
}
c.err = c.conn.Receive(&c.msg)
c.msg = new(Req)
c.err = c.conn.Receive(c.msg)
return c.err == nil
}

// Msg returns the most recent message unmarshaled by a call to Receive. The
// returned message points to data that will be overwritten by the next call to
// Receive.
// Msg returns the most recent message unmarshaled by a call to Receive.
func (c *ClientStream[Req]) Msg() *Req {
return &c.msg
if c.msg == nil {
c.msg = new(Req)
}
return c.msg
}

// Err returns the first non-EOF error that was encountered by Receive.
Expand All @@ -63,6 +65,12 @@ func (c *ClientStream[Req]) Err() error {
return c.err
}

// Conn exposes the underlying StreamingHandlerConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (c *ClientStream[Req]) Conn() StreamingHandlerConn {
return c.conn
}

// ServerStream is the handler's view of a server streaming RPC.
//
// It's constructed as part of [Handler] invocation, but doesn't currently have
Expand All @@ -89,6 +97,12 @@ func (s *ServerStream[Res]) Send(msg *Res) error {
return s.conn.Send(msg)
}

// Conn exposes the underlying StreamingHandlerConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (s *ServerStream[Res]) Conn() StreamingHandlerConn {
return s.conn
}

// BidiStream is the handler's view of a bidirectional streaming RPC.
//
// It's constructed as part of [Handler] invocation, but doesn't currently have
Expand Down Expand Up @@ -129,3 +143,9 @@ func (b *BidiStream[Req, Res]) ResponseTrailer() http.Header {
func (b *BidiStream[Req, Res]) Send(msg *Res) error {
return b.conn.Send(msg)
}

// Conn exposes the underlying StreamingHandlerConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (b *BidiStream[Req, Res]) Conn() StreamingHandlerConn {
return b.conn
}
41 changes: 41 additions & 0 deletions handler_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"fmt"
"testing"

"github.com/bufbuild/connect-go/internal/assert"
pingv1 "github.com/bufbuild/connect-go/internal/gen/connect/ping/v1"
)

func TestClientStream(t *testing.T) {
t.Parallel()
stream := &ClientStream[pingv1.PingRequest]{conn: &nopStreamingHandlerConn{}}
assert.True(t, stream.Receive())
first := fmt.Sprintf("%p", stream.Msg())
assert.True(t, stream.Receive())
second := fmt.Sprintf("%p", stream.Msg())
assert.NotEqual(t, first, second)
}

type nopStreamingHandlerConn struct {
StreamingHandlerConn
}

func (nopStreamingHandlerConn) Receive(msg any) error {
return nil
}