diff --git a/client_ext_test.go b/client_ext_test.go index e0ba6381..bfe4e399 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -97,7 +97,8 @@ func TestClientPeer(t *testing.T) { _, closeErr := clientStream.CloseAndReceive() assert.Nil(t, closeErr) }) - assert.NotNil(t, clientStream.Peer().Addr) + assert.NotZero(t, clientStream.Peer().Addr) + assert.NotZero(t, clientStream.Peer().Protocol) err = clientStream.Send(&pingv1.SumRequest{}) assert.Nil(t, err) // server streaming @@ -112,7 +113,8 @@ func TestClientPeer(t *testing.T) { assert.Nil(t, bidiStream.CloseRequest()) assert.Nil(t, bidiStream.CloseResponse()) }) - assert.NotNil(t, bidiStream.Peer().Addr) + assert.NotZero(t, bidiStream.Peer().Addr) + assert.NotZero(t, bidiStream.Peer().Protocol) err = bidiStream.Send(&pingv1.CumSumRequest{}) assert.Nil(t, err) } @@ -138,6 +140,7 @@ type assertPeerInterceptor struct { func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { assert.NotZero(a.tb, req.Peer().Addr) + assert.NotZero(a.tb, req.Peer().Protocol) return next(ctx, req) } } @@ -146,6 +149,7 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { conn := next(ctx, spec) assert.NotZero(a.tb, conn.Peer().Addr) + assert.NotZero(a.tb, conn.Peer().Protocol) return conn } } @@ -153,6 +157,7 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { assert.NotZero(a.tb, conn.Peer().Addr) + assert.NotZero(a.tb, conn.Peer().Protocol) return next(ctx, conn) } } diff --git a/connect.go b/connect.go index 283ca9f7..d04408e8 100644 --- a/connect.go +++ b/connect.go @@ -253,19 +253,26 @@ type Spec struct { IsClient bool // otherwise we're in a handler } -// Peer describes the other party to an RPC. When accessed client-side, Addr -// contains the host or host:port from the server's URL. When accessed -// server-side, Addr contains the client's address in IP:port format. +// Peer describes the other party to an RPC. +// +// When accessed client-side, Addr contains the host or host:port from the +// server's URL. When accessed server-side, Addr contains the client's address +// in IP:port format. +// +// On both the client and the server, Protocol is the RPC protocol in use. +// Currently, it's either [ProtocolConnect], [ProtocolGRPC], or +// [ProtocolGRPCWeb], but additional protocols may be added in the future. type Peer struct { - Addr string + Addr string + Protocol string } -func newPeerFromURL(s string) Peer { - u, err := url.Parse(s) - if err != nil { - return Peer{} +func newPeerFromURL(urlString, protocol string) Peer { + peer := Peer{Protocol: protocol} + if u, err := url.Parse(urlString); err == nil { + peer.Addr = u.Host } - return Peer{Addr: u.Host} + return peer } // handlerConnCloser extends HandlerConn with a method for handlers to diff --git a/connect_ext_test.go b/connect_ext_test.go index ff5313b4..5e7f7dc3 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1497,6 +1497,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if request.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } response := connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.Number, @@ -1515,6 +1518,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if request.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage)) err.Meta().Set(handlerHeader, headerValue) err.Meta().Set(handlerTrailer, trailerValue) @@ -1533,6 +1539,9 @@ func (p pingServer) Sum( if stream.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if stream.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } var sum int64 for stream.Receive() { sum += stream.Msg().Number @@ -1557,6 +1566,9 @@ func (p pingServer) CountUp( if request.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if request.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } if request.Msg.Number <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", @@ -1586,6 +1598,9 @@ func (p pingServer) CumSum( if stream.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if stream.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } stream.ResponseHeader().Set(handlerHeader, headerValue) stream.ResponseTrailer().Set(handlerTrailer, trailerValue) for { diff --git a/protocol.go b/protocol.go index 8f0e7060..0fe92387 100644 --- a/protocol.go +++ b/protocol.go @@ -26,6 +26,14 @@ import ( "strings" ) +// The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by +// [Peer.Protocol]). Additional protocols may be added in the future. +const ( + ProtocolConnect = "connect" + ProtocolGRPC = "grpc" + ProtocolGRPCWeb = "grpcweb" +) + const ( headerContentType = "Content-Type" headerUserAgent = "User-Agent" diff --git a/protocol_connect.go b/protocol_connect.go index 069d4e35..13d32e23 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -151,7 +151,10 @@ func (h *connectHandler) NewConn( codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil var conn handlerConnCloser - peer := Peer{Addr: request.RemoteAddr} + peer := Peer{ + Addr: request.RemoteAddr, + Protocol: ProtocolConnect, + } if h.Spec.StreamType == StreamTypeUnary { conn = &connectUnaryHandlerConn{ spec: h.Spec, @@ -221,7 +224,7 @@ type connectClient struct { } func (c *connectClient) Peer() Peer { - return newPeerFromURL(c.URL) + return newPeerFromURL(c.URL, ProtocolConnect) } func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.Header) { diff --git a/protocol_grpc.go b/protocol_grpc.go index bd2c14e1..36f957af 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -161,9 +161,16 @@ func (g *grpcHandler) NewConn( codecName := grpcCodecFromContentType(g.web, request.Header.Get(headerContentType)) codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil + protocolName := ProtocolGRPC + if g.web { + protocolName = ProtocolGRPCWeb + } conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{ - spec: g.Spec, - peer: Peer{Addr: request.RemoteAddr}, + spec: g.Spec, + peer: Peer{ + Addr: request.RemoteAddr, + Protocol: protocolName, + }, web: g.web, bufferPool: g.BufferPool, protobuf: g.Codecs.Protobuf(), // for errors @@ -207,7 +214,10 @@ type grpcClient struct { } func (g *grpcClient) Peer() Peer { - return newPeerFromURL(g.URL) + if g.web { + return newPeerFromURL(g.URL, ProtocolGRPCWeb) + } + return newPeerFromURL(g.URL, ProtocolGRPC) } func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) {