Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RPC protocol to Peer #394

Merged
merged 3 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func TestClientPeer(t *testing.T) {
_, closeErr := clientStream.CloseAndReceive()
assert.Nil(t, closeErr)
})
assert.NotNil(t, clientStream.Peer().Addr)
assert.NotZero(t, clientStream.Peer().Addr)
assert.NotZero(t, clientStream.Peer().Protocol)
err = clientStream.Send(&pingv1.SumRequest{})
assert.Nil(t, err)
// server streaming
Expand All @@ -112,7 +113,8 @@ func TestClientPeer(t *testing.T) {
assert.Nil(t, bidiStream.CloseRequest())
assert.Nil(t, bidiStream.CloseResponse())
})
assert.NotNil(t, bidiStream.Peer().Addr)
assert.NotZero(t, bidiStream.Peer().Addr)
assert.NotZero(t, bidiStream.Peer().Protocol)
err = bidiStream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err)
}
Expand All @@ -138,6 +140,7 @@ type assertPeerInterceptor struct {
func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
assert.NotZero(a.tb, req.Peer().Addr)
assert.NotZero(a.tb, req.Peer().Protocol)
return next(ctx, req)
}
}
Expand All @@ -146,13 +149,15 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
assert.NotZero(a.tb, conn.Peer().Addr)
assert.NotZero(a.tb, conn.Peer().Protocol)
return conn
}
}

func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
assert.NotZero(a.tb, conn.Peer().Addr)
assert.NotZero(a.tb, conn.Peer().Protocol)
return next(ctx, conn)
}
}
25 changes: 16 additions & 9 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,26 @@ type Spec struct {
IsClient bool // otherwise we're in a handler
}

// Peer describes the other party to an RPC. When accessed client-side, Addr
// contains the host or host:port from the server's URL. When accessed
// server-side, Addr contains the client's address in IP:port format.
// Peer describes the other party to an RPC.
//
// When accessed client-side, Addr contains the host or host:port from the
// server's URL. When accessed server-side, Addr contains the client's address
// in IP:port format.
//
// On both the client and the server, Protocol is the RPC protocol in use.
// Currently, it's either [ProtocolConnect], [ProtocolGRPC], or
// [ProtocolGRPCWeb], but additional protocols may be added in the future.
type Peer struct {
Addr string
Addr string
Protocol string
}

func newPeerFromURL(s string) Peer {
u, err := url.Parse(s)
if err != nil {
return Peer{}
func newPeerFromURL(urlString, protocol string) Peer {
peer := Peer{Protocol: protocol}
if u, err := url.Parse(urlString); err == nil {
peer.Addr = u.Host
}
return Peer{Addr: u.Host}
return peer
}

// handlerConnCloser extends HandlerConn with a method for handlers to
Expand Down
15 changes: 15 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Peer().Protocol == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
response := connect.NewResponse(
&pingv1.PingResponse{
Number: request.Msg.Number,
Expand All @@ -1515,6 +1518,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Peer().Protocol == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage))
err.Meta().Set(handlerHeader, headerValue)
err.Meta().Set(handlerTrailer, trailerValue)
Expand All @@ -1533,6 +1539,9 @@ func (p pingServer) Sum(
if stream.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if stream.Peer().Protocol == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
var sum int64
for stream.Receive() {
sum += stream.Msg().Number
Expand All @@ -1557,6 +1566,9 @@ func (p pingServer) CountUp(
if request.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Peer().Protocol == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer protocol"))
}
if request.Msg.Number <= 0 {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf(
"number must be positive: got %v",
Expand Down Expand Up @@ -1586,6 +1598,9 @@ func (p pingServer) CumSum(
if stream.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if stream.Peer().Protocol == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
stream.ResponseHeader().Set(handlerHeader, headerValue)
stream.ResponseTrailer().Set(handlerTrailer, trailerValue)
for {
Expand Down
8 changes: 8 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ import (
"strings"
)

// The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by
// [Peer.Protocol]). Additional protocols may be added in the future.
const (
ProtocolConnect = "connect"
ProtocolGRPC = "grpc"
ProtocolGRPCWeb = "grpcweb"
)

const (
headerContentType = "Content-Type"
headerUserAgent = "User-Agent"
Expand Down
7 changes: 5 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ func (h *connectHandler) NewConn(
codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil

var conn handlerConnCloser
peer := Peer{Addr: request.RemoteAddr}
peer := Peer{
Addr: request.RemoteAddr,
Protocol: ProtocolConnect,
}
if h.Spec.StreamType == StreamTypeUnary {
conn = &connectUnaryHandlerConn{
spec: h.Spec,
Expand Down Expand Up @@ -221,7 +224,7 @@ type connectClient struct {
}

func (c *connectClient) Peer() Peer {
return newPeerFromURL(c.URL)
return newPeerFromURL(c.URL, ProtocolConnect)
}

func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.Header) {
Expand Down
16 changes: 13 additions & 3 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,16 @@ func (g *grpcHandler) NewConn(

codecName := grpcCodecFromContentType(g.web, request.Header.Get(headerContentType))
codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil
protocolName := ProtocolGRPC
if g.web {
protocolName = ProtocolGRPCWeb
}
conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{
spec: g.Spec,
peer: Peer{Addr: request.RemoteAddr},
spec: g.Spec,
peer: Peer{
Addr: request.RemoteAddr,
Protocol: protocolName,
},
web: g.web,
bufferPool: g.BufferPool,
protobuf: g.Codecs.Protobuf(), // for errors
Expand Down Expand Up @@ -207,7 +214,10 @@ type grpcClient struct {
}

func (g *grpcClient) Peer() Peer {
return newPeerFromURL(g.URL)
if g.web {
return newPeerFromURL(g.URL, ProtocolGRPCWeb)
}
return newPeerFromURL(g.URL, ProtocolGRPC)
}

func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) {
Expand Down