From bc83d1799e36fcb74a30dc46ee15ffe0677f78a3 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Thu, 17 Nov 2022 21:13:21 -0800 Subject: [PATCH] Add RPC protocol to Peer (#394) Currently, observability interceptors must parse the `Content-Type` header to determine which RPC protocol is in use. This isn't awful, but it's so easy for `connect-go` to expose directly that we might as well. I chose to use our own strings rather than repurposing OpenTelemetry's semantic conventions here. The OTel strings are less nice in this context, and the packages are all unstable and subject to change. @joshcarp, LMK if this seems like a bad tradeoff to you. --- client_ext_test.go | 9 +++++++-- connect.go | 25 ++++++++++++++++--------- connect_ext_test.go | 15 +++++++++++++++ protocol.go | 8 ++++++++ protocol_connect.go | 7 +++++-- protocol_grpc.go | 16 +++++++++++++--- 6 files changed, 64 insertions(+), 16 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 1497b55b..6493ffda 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -97,7 +97,8 @@ func TestClientPeer(t *testing.T) { _, closeErr := clientStream.CloseAndReceive() assert.Nil(t, closeErr) }) - assert.NotNil(t, clientStream.Peer().Addr) + assert.NotZero(t, clientStream.Peer().Addr) + assert.NotZero(t, clientStream.Peer().Protocol) err = clientStream.Send(&pingv1.SumRequest{}) assert.Nil(t, err) // server streaming @@ -112,7 +113,8 @@ func TestClientPeer(t *testing.T) { assert.Nil(t, bidiStream.CloseRequest()) assert.Nil(t, bidiStream.CloseResponse()) }) - assert.NotNil(t, bidiStream.Peer().Addr) + assert.NotZero(t, bidiStream.Peer().Addr) + assert.NotZero(t, bidiStream.Peer().Protocol) err = bidiStream.Send(&pingv1.CumSumRequest{}) assert.Nil(t, err) } @@ -138,6 +140,7 @@ type assertPeerInterceptor struct { func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { assert.NotZero(a.tb, req.Peer().Addr) + assert.NotZero(a.tb, req.Peer().Protocol) return next(ctx, req) } } @@ -146,6 +149,7 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { conn := next(ctx, spec) assert.NotZero(a.tb, conn.Peer().Addr) + assert.NotZero(a.tb, conn.Peer().Protocol) return conn } } @@ -153,6 +157,7 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { assert.NotZero(a.tb, conn.Peer().Addr) + assert.NotZero(a.tb, conn.Peer().Protocol) return next(ctx, conn) } } diff --git a/connect.go b/connect.go index 7562766e..7569d8eb 100644 --- a/connect.go +++ b/connect.go @@ -253,19 +253,26 @@ type Spec struct { IsClient bool // otherwise we're in a handler } -// Peer describes the other party to an RPC. When accessed client-side, Addr -// contains the host or host:port from the server's URL. When accessed -// server-side, Addr contains the client's address in IP:port format. +// Peer describes the other party to an RPC. +// +// When accessed client-side, Addr contains the host or host:port from the +// server's URL. When accessed server-side, Addr contains the client's address +// in IP:port format. +// +// On both the client and the server, Protocol is the RPC protocol in use. +// Currently, it's either [ProtocolConnect], [ProtocolGRPC], or +// [ProtocolGRPCWeb], but additional protocols may be added in the future. type Peer struct { - Addr string + Addr string + Protocol string } -func newPeerFromURL(s string) Peer { - u, err := url.Parse(s) - if err != nil { - return Peer{} +func newPeerFromURL(urlString, protocol string) Peer { + peer := Peer{Protocol: protocol} + if u, err := url.Parse(urlString); err == nil { + peer.Addr = u.Host } - return Peer{Addr: u.Host} + return peer } // handlerConnCloser extends HandlerConn with a method for handlers to diff --git a/connect_ext_test.go b/connect_ext_test.go index 1e160191..fc49b519 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1497,6 +1497,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if request.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } response := connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.Number, @@ -1515,6 +1518,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if request.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage)) err.Meta().Set(handlerHeader, headerValue) err.Meta().Set(handlerTrailer, trailerValue) @@ -1533,6 +1539,9 @@ func (p pingServer) Sum( if stream.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if stream.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } var sum int64 for stream.Receive() { sum += stream.Msg().Number @@ -1557,6 +1566,9 @@ func (p pingServer) CountUp( if request.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if request.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } if request.Msg.Number <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", @@ -1586,6 +1598,9 @@ func (p pingServer) CumSum( if stream.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) } + if stream.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } stream.ResponseHeader().Set(handlerHeader, headerValue) stream.ResponseTrailer().Set(handlerTrailer, trailerValue) for { diff --git a/protocol.go b/protocol.go index 8f0e7060..0fe92387 100644 --- a/protocol.go +++ b/protocol.go @@ -26,6 +26,14 @@ import ( "strings" ) +// The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by +// [Peer.Protocol]). Additional protocols may be added in the future. +const ( + ProtocolConnect = "connect" + ProtocolGRPC = "grpc" + ProtocolGRPCWeb = "grpcweb" +) + const ( headerContentType = "Content-Type" headerUserAgent = "User-Agent" diff --git a/protocol_connect.go b/protocol_connect.go index 069d4e35..13d32e23 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -151,7 +151,10 @@ func (h *connectHandler) NewConn( codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil var conn handlerConnCloser - peer := Peer{Addr: request.RemoteAddr} + peer := Peer{ + Addr: request.RemoteAddr, + Protocol: ProtocolConnect, + } if h.Spec.StreamType == StreamTypeUnary { conn = &connectUnaryHandlerConn{ spec: h.Spec, @@ -221,7 +224,7 @@ type connectClient struct { } func (c *connectClient) Peer() Peer { - return newPeerFromURL(c.URL) + return newPeerFromURL(c.URL, ProtocolConnect) } func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.Header) { diff --git a/protocol_grpc.go b/protocol_grpc.go index 7f67eb74..55832b5e 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -161,9 +161,16 @@ func (g *grpcHandler) NewConn( codecName := grpcCodecFromContentType(g.web, request.Header.Get(headerContentType)) codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil + protocolName := ProtocolGRPC + if g.web { + protocolName = ProtocolGRPCWeb + } conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{ - spec: g.Spec, - peer: Peer{Addr: request.RemoteAddr}, + spec: g.Spec, + peer: Peer{ + Addr: request.RemoteAddr, + Protocol: protocolName, + }, web: g.web, bufferPool: g.BufferPool, protobuf: g.Codecs.Protobuf(), // for errors @@ -207,7 +214,10 @@ type grpcClient struct { } func (g *grpcClient) Peer() Peer { - return newPeerFromURL(g.URL) + if g.web { + return newPeerFromURL(g.URL, ProtocolGRPCWeb) + } + return newPeerFromURL(g.URL, ProtocolGRPC) } func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) {