Skip to content

Commit

Permalink
Expose peer information to handlers and clients (#364)
Browse files Browse the repository at this point in the history
* Add Spec to some user-facing streams

Where the generated code replaces Request with a stream, expose the
Spec.

* Expose peer information to servers and clients

On requests and streams, expose the peer's address. For clients, the
address is the host or host:port, parsed from the URL. For servers, the
address is an IP:port pair, provided by `net/http` as
`Request.RemoteAddr`.

In #357, a handful of users asked for this information for logging,
per-IP rate limiting, and a variety of other server-side concerns. It's
also necessary for OpenTelemetry interceptors (#344).

Fixes #357.
  • Loading branch information
akshayjshah authored Sep 23, 2022
1 parent 74e21a3 commit d3c2b89
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 4 deletions.
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
// once at client creation.
unarySpec := config.newSpec(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
conn := protocolClient.NewConn(ctx, unarySpec, request.Header())
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
Expand All @@ -94,9 +94,11 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
unaryFunc = interceptor.WrapUnary(unaryFunc)
}
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
// To make the specification and RPC headers visible to the full interceptor
// chain (as though they were supplied by the caller), we'll add them here.
// To make the specification, peer, and RPC headers visible to the full
// interceptor chain (as though they were supplied by the caller), we'll
// add them here.
request.spec = unarySpec
request.peer = client.protocolClient.Peer()
protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header())
response, err := unaryFunc(ctx, request)
if err != nil {
Expand Down
88 changes: 88 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"connectrpc.com/connect"
Expand Down Expand Up @@ -68,3 +69,90 @@ func TestNewClient_InitFailure(t *testing.T) {
validateExpectedError(t, err)
})
}

func TestClientPeer(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

run := func(t *testing.T, opts ...connect.ClientOption) {
t.Helper()
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithClientOptions(opts...),
connect.WithInterceptors(&assertPeerInterceptor{t}),
)
ctx := context.Background()
// unary
_, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{}))
assert.Nil(t, err)
// client streaming
clientStream := client.Sum(ctx)
t.Cleanup(func() {
_, closeErr := clientStream.CloseAndReceive()
assert.Nil(t, closeErr)
})
assert.NotNil(t, clientStream.Peer().Addr)
err = clientStream.Send(&pingv1.SumRequest{})
assert.Nil(t, err)
// server streaming
serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{}))
t.Cleanup(func() {
assert.Nil(t, serverStream.Close())
})
assert.Nil(t, err)
// bidi streaming
bidiStream := client.CumSum(ctx)
t.Cleanup(func() {
assert.Nil(t, bidiStream.CloseRequest())
assert.Nil(t, bidiStream.CloseResponse())
})
assert.NotNil(t, bidiStream.Peer().Addr)
err = bidiStream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err)
}

t.Run("connect", func(t *testing.T) {
t.Parallel()
run(t)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPC())
})
t.Run("grpcweb", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPCWeb())
})
}

type assertPeerInterceptor struct {
tb testing.TB
}

func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
assert.NotZero(a.tb, req.Peer().Addr)
return next(ctx, req)
}
}

func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
assert.NotZero(a.tb, conn.Peer().Addr)
return conn
}
}

func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
assert.NotZero(a.tb, conn.Peer().Addr)
return next(ctx, conn)
}
}
20 changes: 20 additions & 0 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ type ClientStreamForClient[Req, Res any] struct {
err error
}

// Spec returns the specification for the RPC.
func (c *ClientStreamForClient[_, _]) Spec() Spec {
return c.conn.Spec()
}

// Peer describes the server for the RPC.
func (c *ClientStreamForClient[_, _]) Peer() Peer {
return c.conn.Peer()
}

// RequestHeader returns the request headers. Headers are sent to the server with the
// first call to Send.
func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header {
Expand Down Expand Up @@ -164,6 +174,16 @@ type BidiStreamForClient[Req, Res any] struct {
err error
}

// Spec returns the specification for the RPC.
func (b *BidiStreamForClient[_, _]) Spec() Spec {
return b.conn.Spec()
}

// Peer describes the server for the RPC.
func (b *BidiStreamForClient[_, _]) Peer() Peer {
return b.conn.Peer()
}

// RequestHeader returns the request headers. Headers are sent with the first
// call to Send.
func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header {
Expand Down
27 changes: 26 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"errors"
"io"
"net/http"
"net/url"
)

// Version is the semantic version of the connect module.
Expand Down Expand Up @@ -68,6 +69,7 @@ const (
// StreamingHandlerConn implementations do not need to be safe for concurrent use.
type StreamingHandlerConn interface {
Spec() Spec
Peer() Peer

Receive(any) error
RequestHeader() http.Header
Expand Down Expand Up @@ -97,8 +99,9 @@ type StreamingHandlerConn interface {
// implementations must support limited concurrent use. See the comments on
// each group of methods for details.
type StreamingClientConn interface {
// Spec must be safe to call concurrently with all other methods.
// Spec and Peer must be safe to call concurrently with all other methods.
Spec() Spec
Peer() Peer

// Send, RequestHeader, and CloseRequest may race with each other, but must
// be safe to call concurrently with all other methods.
Expand All @@ -121,6 +124,7 @@ type Request[T any] struct {
Msg *T

spec Spec
peer Peer
header http.Header
}

Expand All @@ -144,6 +148,11 @@ func (r *Request[_]) Spec() Spec {
return r.spec
}

// Peer describes the other party for this RPC.
func (r *Request[_]) Peer() Peer {
return r.peer
}

// Header returns the HTTP headers for this request.
func (r *Request[_]) Header() http.Header {
if r.header == nil {
Expand All @@ -164,6 +173,7 @@ func (r *Request[_]) internalOnly() {}
type AnyRequest interface {
Any() any
Spec() Spec
Peer() Peer
Header() http.Header

internalOnly()
Expand Down Expand Up @@ -243,6 +253,21 @@ type Spec struct {
IsClient bool // otherwise we're in a handler
}

// Peer describes the other party to an RPC. When accessed client-side, Addr
// contains the host or host:port from the server's URL. When accessed
// server-side, Addr contains the client's address in IP:port format.
type Peer struct {
Addr string
}

func newPeerFromURL(s string) Peer {
u, err := url.Parse(s)
if err != nil {
return Peer{}
}
return Peer{Addr: u.Host}
}

// handlerConnCloser extends HandlerConn with a method for handlers to
// terminate the message exchange (and optionally send an error to the client).
type handlerConnCloser interface {
Expand Down
15 changes: 15 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return nil, err
}
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
response := connect.NewResponse(
&pingv1.PingResponse{
Number: request.Msg.Number,
Expand All @@ -1446,6 +1449,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return nil, err
}
if request.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage))
err.Meta().Set(handlerHeader, headerValue)
err.Meta().Set(handlerTrailer, trailerValue)
Expand All @@ -1461,6 +1467,9 @@ func (p pingServer) Sum(
return nil, err
}
}
if stream.Peer().Addr == "" {
return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
var sum int64
for stream.Receive() {
sum += stream.Msg().Number
Expand All @@ -1482,6 +1491,9 @@ func (p pingServer) CountUp(
if err := expectClientHeader(p.checkMetadata, request); err != nil {
return err
}
if request.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
if request.Msg.Number <= 0 {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf(
"number must be positive: got %v",
Expand All @@ -1508,6 +1520,9 @@ func (p pingServer) CumSum(
return err
}
}
if stream.Peer().Addr == "" {
return connect.NewError(connect.CodeInternal, errors.New("no peer address"))
}
stream.ResponseHeader().Set(handlerHeader, headerValue)
stream.ResponseTrailer().Set(handlerTrailer, trailerValue)
for {
Expand Down
2 changes: 2 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func NewUnaryHandler[Req, Res any](
request := &Request[Req]{
Msg: &msg,
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
}
response, err := untyped(ctx, request)
Expand Down Expand Up @@ -124,6 +125,7 @@ func NewServerStreamHandler[Req, Res any](
&Request[Req]{
Msg: &msg,
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
},
&ServerStream[Res]{conn: conn},
Expand Down
20 changes: 20 additions & 0 deletions handler_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ type ClientStream[Req any] struct {
err error
}

// Spec returns the specification for the RPC.
func (c *ClientStream[_]) Spec() Spec {
return c.conn.Spec()
}

// Peer describes the client for this RPC.
func (c *ClientStream[_]) Peer() Peer {
return c.conn.Peer()
}

// RequestHeader returns the headers received from the client.
func (c *ClientStream[Req]) RequestHeader() http.Header {
return c.conn.RequestHeader()
Expand Down Expand Up @@ -111,6 +121,16 @@ type BidiStream[Req, Res any] struct {
conn StreamingHandlerConn
}

// Spec returns the specification for the RPC.
func (b *BidiStream[_, _]) Spec() Spec {
return b.conn.Spec()
}

// Peer describes the client for this RPC.
func (b *BidiStream[_, _]) Peer() Peer {
return b.conn.Peer()
}

// RequestHeader returns the headers received from the client.
func (b *BidiStream[Req, Res]) RequestHeader() http.Header {
return b.conn.RequestHeader()
Expand Down
3 changes: 3 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ type protocolClientParams struct {
// Client is the client side of a protocol. HTTP clients typically use a single
// protocol, codec, and compressor to send requests.
type protocolClient interface {
// Peer describes the server for the RPC.
Peer() Peer

// WriteRequestHeader writes any protocol-specific request headers.
WriteRequestHeader(StreamType, http.Header)

Expand Down
Loading

0 comments on commit d3c2b89

Please sign in to comment.