Skip to content

Commit

Permalink
Add RPC protocol to Peer (#394)
Browse files Browse the repository at this point in the history
Currently, observability interceptors must parse the `Content-Type`
header to determine which RPC protocol is in use. This isn't awful, but
it's so easy for `connect-go` to expose directly that we might as well.

I chose to use our own strings rather than repurposing OpenTelemetry's
semantic conventions here. The OTel strings are less nice in this
context, and the packages are all unstable and subject to change.
@joshcarp, LMK if this seems like a bad tradeoff to you.
  • Loading branch information
akshayjshah authored Nov 18, 2022
1 parent 7991753 commit 40a2428
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 16 deletions.
9 changes: 7 additions & 2 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func TestClientPeer(t *testing.T) {
_, closeErr := clientStream.CloseAndReceive()
assert.Nil(t, closeErr)
})
assert.NotNil(t, clientStream.Peer().Addr)
assert.NotZero(t, clientStream.Peer().Addr)
assert.NotZero(t, clientStream.Peer().Protocol)
err = clientStream.Send(&pingv1.SumRequest{})
assert.Nil(t, err)
// server streaming
Expand All @@ -112,7 +113,8 @@ func TestClientPeer(t *testing.T) {
assert.Nil(t, bidiStream.CloseRequest())
assert.Nil(t, bidiStream.CloseResponse())
})
assert.NotNil(t, bidiStream.Peer().Addr)
assert.NotZero(t, bidiStream.Peer().Addr)
assert.NotZero(t, bidiStream.Peer().Protocol)
err = bidiStream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err)
}
Expand All @@ -138,6 +140,7 @@ type assertPeerInterceptor struct {
func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
assert.NotZero(a.tb, req.Peer().Addr)
assert.NotZero(a.tb, req.Peer().Protocol)
return next(ctx, req)
}
}
Expand All @@ -146,13 +149,15 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
assert.NotZero(a.tb, conn.Peer().Addr)
assert.NotZero(a.tb, conn.Peer().Protocol)
return conn
}
}

func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
assert.NotZero(a.tb, conn.Peer().Addr)
assert.NotZero(a.tb, conn.Peer().Protocol)
return next(ctx, conn)
}
}
25 changes: 16 additions & 9 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,26 @@ type Spec struct {
IsClient bool // otherwise we're in a handler
}

// Peer describes the other party to an RPC. When accessed client-side, Addr
// contains the host or host:port from the server's URL. When accessed
// server-side, Addr contains the client's address in IP:port format.
// Peer describes the other party to an RPC.
//
// When accessed client-side, Addr contains the host or host:port from the
// server's URL. When accessed server-side, Addr contains the client's address
// in IP:port format.
//
// On both the client and the server, Protocol is the RPC protocol in use.
// Currently, it's either [ProtocolConnect], [ProtocolGRPC], or
// [ProtocolGRPCWeb], but additional protocols may be added in the future.
type Peer struct {
Addr string
Addr string
Protocol string
}

func newPeerFromURL(s string) Peer {
u, err := url.Parse(s)
if err != nil {
return Peer{}
func newPeerFromURL(urlString, protocol string) Peer {
peer := Peer{Protocol: protocol}
if u, err := url.Parse(urlString); err == nil {
peer.Addr = u.Host
}
return Peer{Addr: u.Host}
return peer
}

// handlerConnCloser extends HandlerConn with a method for handlers to
Expand Down
15 changes: 15 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Peer().Protocol == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
response := connect.NewResponse(
&pingv1.PingResponse{
Number: request.Msg.Number,
Expand All @@ -1515,6 +1518,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Peer().Protocol == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage))
err.Meta().Set(handlerHeader, headerValue)
err.Meta().Set(handlerTrailer, trailerValue)
Expand All @@ -1533,6 +1539,9 @@ func (p pingServer) Sum(
if stream.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if stream.Peer().Protocol == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
var sum int64
for stream.Receive() {
sum += stream.Msg().Number
Expand All @@ -1557,6 +1566,9 @@ func (p pingServer) CountUp(
if request.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Peer().Protocol == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
if request.Msg.Number <= 0 {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf(
"number must be positive: got %v",
Expand Down Expand Up @@ -1586,6 +1598,9 @@ func (p pingServer) CumSum(
if stream.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if stream.Peer().Protocol == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
stream.ResponseHeader().Set(handlerHeader, headerValue)
stream.ResponseTrailer().Set(handlerTrailer, trailerValue)
for {
Expand Down
8 changes: 8 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ import (
"strings"
)

// The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by
// [Peer.Protocol]). Additional protocols may be added in the future.
const (
ProtocolConnect = "connect"
ProtocolGRPC = "grpc"
ProtocolGRPCWeb = "grpcweb"
)

const (
headerContentType = "Content-Type"
headerUserAgent = "User-Agent"
Expand Down
7 changes: 5 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ func (h *connectHandler) NewConn(
codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil

var conn handlerConnCloser
peer := Peer{Addr: request.RemoteAddr}
peer := Peer{
Addr: request.RemoteAddr,
Protocol: ProtocolConnect,
}
if h.Spec.StreamType == StreamTypeUnary {
conn = &connectUnaryHandlerConn{
spec: h.Spec,
Expand Down Expand Up @@ -221,7 +224,7 @@ type connectClient struct {
}

func (c *connectClient) Peer() Peer {
return newPeerFromURL(c.URL)
return newPeerFromURL(c.URL, ProtocolConnect)
}

func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.Header) {
Expand Down
16 changes: 13 additions & 3 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,16 @@ func (g *grpcHandler) NewConn(

codecName := grpcCodecFromContentType(g.web, request.Header.Get(headerContentType))
codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil
protocolName := ProtocolGRPC
if g.web {
protocolName = ProtocolGRPCWeb
}
conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{
spec: g.Spec,
peer: Peer{Addr: request.RemoteAddr},
spec: g.Spec,
peer: Peer{
Addr: request.RemoteAddr,
Protocol: protocolName,
},
web: g.web,
bufferPool: g.BufferPool,
protobuf: g.Codecs.Protobuf(), // for errors
Expand Down Expand Up @@ -207,7 +214,10 @@ type grpcClient struct {
}

func (g *grpcClient) Peer() Peer {
return newPeerFromURL(g.URL)
if g.web {
return newPeerFromURL(g.URL, ProtocolGRPCWeb)
}
return newPeerFromURL(g.URL, ProtocolGRPC)
}

func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) {
Expand Down

0 comments on commit 40a2428

Please sign in to comment.