diff --git a/connect.go b/connect.go index 3149aada..a973e505 100644 --- a/connect.go +++ b/connect.go @@ -354,34 +354,87 @@ type handlerConnCloser interface { Close(error) error } +// receiveConn represents the shared methods of both StreamingClientConn and StreamingHandlerConn +// that the below helper functions use for implementing the rules around a "unary" stream, that +// is expected to have exactly one message (or zero messages followed by a non-EOF error). +type receiveConn interface { + Spec() Spec + Receive(any) error +} + +// hasHTTPMethod is implemented by streaming connections that support HTTP methods other than +// POST. +type hasHTTPMethod interface { + getHTTPMethod() string +} + // receiveUnaryResponse unmarshals a message from a StreamingClientConn, then // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple // messages. func receiveUnaryResponse[T any](conn StreamingClientConn, initializer maybeInitializer) (*Response[T], error) { + msg, err := receiveUnaryMessage[T](conn, initializer, "response") + if err != nil { + return nil, err + } + return &Response[T]{ + Msg: msg, + header: conn.ResponseHeader(), + trailer: conn.ResponseTrailer(), + }, nil +} + +// receiveUnaryRequest unmarshals a message from a StreamingClientConn, then +// envelopes the message and attaches headers and other request properties. It +// attempts to consume the request stream and isn't appropriate when receiving +// multiple messages. +func receiveUnaryRequest[T any](conn StreamingHandlerConn, initializer maybeInitializer) (*Request[T], error) { + msg, err := receiveUnaryMessage[T](conn, initializer, "request") + if err != nil { + return nil, err + } + method := http.MethodPost + if hasRequestMethod, ok := conn.(hasHTTPMethod); ok { + method = hasRequestMethod.getHTTPMethod() + } + return &Request[T]{ + Msg: msg, + spec: conn.Spec(), + peer: conn.Peer(), + header: conn.RequestHeader(), + method: method, + }, nil +} + +func receiveUnaryMessage[T any](conn receiveConn, initializer maybeInitializer, what string) (*T, error) { var msg T if err := initializer.maybe(conn.Spec(), &msg); err != nil { return nil, err } + // Possibly counter-intuitive, but the gRPC specs about error codes state that both clients + // and servers should return "unimplemented" when they encounter a cardinality violation: where + // the number of messages in the stream is wrong. Search for "cardinality violation" in the + // following docs: + // https://grpc.github.io/grpc/core/md_doc_statuscodes.html if err := conn.Receive(&msg); err != nil { + if errors.Is(err, io.EOF) { + err = NewError(CodeUnimplemented, fmt.Errorf("unary %s has zero messages", what)) + } return nil, err } - // In a well-formed stream, the response message may be followed by a block - // of in-stream trailers or HTTP trailers. To ensure that we receive the - // trailers, try to read another message from the stream. - // TODO: optimise unary calls to avoid this extra receive. + // In a well-formed stream, the one message must be the only content in the body. + // To verify that it is well-formed, try to read another message from the stream. + // TODO: optimise this second receive: ideally do it w/out allocation, w/out + // fully reading next message (if one is present), and w/out trying to + // actually unmarshal the bytes) var msg2 T if err := initializer.maybe(conn.Spec(), &msg2); err != nil { return nil, err } if err := conn.Receive(&msg2); err == nil { - return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages")) + return nil, NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) } else if err != nil && !errors.Is(err, io.EOF) { return nil, err } - return &Response[T]{ - Msg: &msg, - header: conn.ResponseHeader(), - trailer: conn.ResponseTrailer(), - }, nil + return &msg, nil } diff --git a/connect_ext_test.go b/connect_ext_test.go index 86bdeeff..5f2076d2 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1689,6 +1689,7 @@ func TestStreamForServer(t *testing.T) { client := newPingClient(t, &pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.True(t, stream.Receive()) + // We end up sending two response messages, but only one is expected. assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) return connect.NewResponse(&pingv1.SumResponse{}), nil }, @@ -1697,7 +1698,7 @@ func TestStreamForServer(t *testing.T) { assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() assert.NotNil(t, err) - assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) + assert.Equal(t, connect.CodeOf(err), connect.CodeUnimplemented) assert.Nil(t, res) }) } diff --git a/handler.go b/handler.go index 77724bdf..1d573291 100644 --- a/handler.go +++ b/handler.go @@ -63,24 +63,10 @@ func NewUnaryHandler[Req, Res any]( } // Given a stream, how should we call the unary function? implementation := func(ctx context.Context, conn StreamingHandlerConn) error { - var msg Req - if err := config.Initializer.maybe(conn.Spec(), &msg); err != nil { - return err - } - if err := conn.Receive(&msg); err != nil { + request, err := receiveUnaryRequest[Req](conn, config.Initializer) + if err != nil { return err } - method := http.MethodPost - if hasRequestMethod, ok := conn.(interface{ getHTTPMethod() string }); ok { - method = hasRequestMethod.getHTTPMethod() - } - request := &Request[Req]{ - Msg: &msg, - spec: conn.Spec(), - peer: conn.Peer(), - header: conn.RequestHeader(), - method: method, - } response, err := untyped(ctx, request) if err != nil { return err @@ -140,24 +126,11 @@ func NewServerStreamHandler[Req, Res any]( return newStreamHandler( config, func(ctx context.Context, conn StreamingHandlerConn) error { - var msg Req - if err := config.Initializer.maybe(conn.Spec(), &msg); err != nil { - return err - } - if err := conn.Receive(&msg); err != nil { + req, err := receiveUnaryRequest[Req](conn, config.Initializer) + if err != nil { return err } - return implementation( - ctx, - &Request[Req]{ - Msg: &msg, - spec: conn.Spec(), - peer: conn.Peer(), - header: conn.RequestHeader(), - method: http.MethodPost, - }, - &ServerStream[Res]{conn: conn}, - ) + return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) } diff --git a/internal/conformance/known-failing.txt b/internal/conformance/known-failing.txt index 83bcb82f..c92535b2 100644 --- a/internal/conformance/known-failing.txt +++ b/internal/conformance/known-failing.txt @@ -25,3 +25,15 @@ HTTP to Connect Code Mapping/**/payload-too-large HTTP to Connect Code Mapping/**/precondition-failed HTTP to Connect Code Mapping/**/request-header-fields-too-large HTTP to Connect Code Mapping/**/request-timeout + +# The current v1.0.0-rc3 of conformance suite has expectations for these +# conditions that were based on the behavior of grpc-go (which returns an +# "unknown" error), with the incorrect idea that was authoritative (and, +# honestly, that code makes sense). However, the actual correct behavior, +# per the specification for gRPC error codes, is for these cardinality +# violations to instead return "unimplemented": +# https://grpc.github.io/grpc/core/md_doc_statuscodes.html +# This library returns the correct code, which (for now) is interpreted +# as a failure by the conformance suite. +**/unary-ok-but-no-response +**/unary-multiple-responses \ No newline at end of file