From d3c2b89dbb841423a25418802ea737723d717c37 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Fri, 23 Sep 2022 10:43:25 -0700 Subject: [PATCH] Expose peer information to handlers and clients (#364) * Add Spec to some user-facing streams Where the generated code replaces Request with a stream, expose the Spec. * Expose peer information to servers and clients On requests and streams, expose the peer's address. For clients, the address is the host or host:port, parsed from the URL. For servers, the address is an IP:port pair, provided by `net/http` as `Request.RemoteAddr`. In #357, a handful of users asked for this information for logging, per-IP rate limiting, and a variety of other server-side concerns. It's also necessary for OpenTelemetry interceptors (#344). Fixes #357. --- client.go | 8 +++-- client_ext_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++ client_stream.go | 20 +++++++++++ connect.go | 27 +++++++++++++- connect_ext_test.go | 15 ++++++++ handler.go | 2 ++ handler_stream.go | 20 +++++++++++ protocol.go | 3 ++ protocol_connect.go | 29 +++++++++++++++ protocol_grpc.go | 16 +++++++++ 10 files changed, 224 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index e79886d2..ad223040 100644 --- a/client.go +++ b/client.go @@ -70,7 +70,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - conn := protocolClient.NewConn(ctx, unarySpec, request.Header()) + conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -94,9 +94,11 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien unaryFunc = interceptor.WrapUnary(unaryFunc) } client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { - // To make the specification and RPC headers visible to the full interceptor - // chain (as though they were supplied by the caller), we'll add them here. + // To make the specification, peer, and RPC headers visible to the full + // interceptor chain (as though they were supplied by the caller), we'll + // add them here. request.spec = unarySpec + request.peer = client.protocolClient.Peer() protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) response, err := unaryFunc(ctx, request) if err != nil { diff --git a/client_ext_test.go b/client_ext_test.go index bfbefc55..1497b55b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "net/http" + "net/http/httptest" "testing" "connectrpc.com/connect" @@ -68,3 +69,90 @@ func TestNewClient_InitFailure(t *testing.T) { validateExpectedError(t, err) }) } + +func TestClientPeer(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + + run := func(t *testing.T, opts ...connect.ClientOption) { + t.Helper() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + connect.WithClientOptions(opts...), + connect.WithInterceptors(&assertPeerInterceptor{t}), + ) + ctx := context.Background() + // unary + _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) + assert.Nil(t, err) + // client streaming + clientStream := client.Sum(ctx) + t.Cleanup(func() { + _, closeErr := clientStream.CloseAndReceive() + assert.Nil(t, closeErr) + }) + assert.NotNil(t, clientStream.Peer().Addr) + err = clientStream.Send(&pingv1.SumRequest{}) + assert.Nil(t, err) + // server streaming + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + t.Cleanup(func() { + assert.Nil(t, serverStream.Close()) + }) + assert.Nil(t, err) + // bidi streaming + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) + }) + assert.NotNil(t, bidiStream.Peer().Addr) + err = bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) + } + + t.Run("connect", func(t *testing.T) { + t.Parallel() + run(t) + }) + t.Run("grpc", func(t *testing.T) { + t.Parallel() + run(t, connect.WithGRPC()) + }) + t.Run("grpcweb", func(t *testing.T) { + t.Parallel() + run(t, connect.WithGRPCWeb()) + }) +} + +type assertPeerInterceptor struct { + tb testing.TB +} + +func (a *assertPeerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + assert.NotZero(a.tb, req.Peer().Addr) + return next(ctx, req) + } +} + +func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + conn := next(ctx, spec) + assert.NotZero(a.tb, conn.Peer().Addr) + return conn + } +} + +func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + assert.NotZero(a.tb, conn.Peer().Addr) + return next(ctx, conn) + } +} diff --git a/client_stream.go b/client_stream.go index db84121d..6c07c5f0 100644 --- a/client_stream.go +++ b/client_stream.go @@ -30,6 +30,16 @@ type ClientStreamForClient[Req, Res any] struct { err error } +// Spec returns the specification for the RPC. +func (c *ClientStreamForClient[_, _]) Spec() Spec { + return c.conn.Spec() +} + +// Peer describes the server for the RPC. +func (c *ClientStreamForClient[_, _]) Peer() Peer { + return c.conn.Peer() +} + // RequestHeader returns the request headers. Headers are sent to the server with the // first call to Send. func (c *ClientStreamForClient[Req, Res]) RequestHeader() http.Header { @@ -164,6 +174,16 @@ type BidiStreamForClient[Req, Res any] struct { err error } +// Spec returns the specification for the RPC. +func (b *BidiStreamForClient[_, _]) Spec() Spec { + return b.conn.Spec() +} + +// Peer describes the server for the RPC. +func (b *BidiStreamForClient[_, _]) Peer() Peer { + return b.conn.Peer() +} + // RequestHeader returns the request headers. Headers are sent with the first // call to Send. func (b *BidiStreamForClient[Req, Res]) RequestHeader() http.Header { diff --git a/connect.go b/connect.go index ab705ba6..3f66ba1f 100644 --- a/connect.go +++ b/connect.go @@ -28,6 +28,7 @@ import ( "errors" "io" "net/http" + "net/url" ) // Version is the semantic version of the connect module. @@ -68,6 +69,7 @@ const ( // StreamingHandlerConn implementations do not need to be safe for concurrent use. type StreamingHandlerConn interface { Spec() Spec + Peer() Peer Receive(any) error RequestHeader() http.Header @@ -97,8 +99,9 @@ type StreamingHandlerConn interface { // implementations must support limited concurrent use. See the comments on // each group of methods for details. type StreamingClientConn interface { - // Spec must be safe to call concurrently with all other methods. + // Spec and Peer must be safe to call concurrently with all other methods. Spec() Spec + Peer() Peer // Send, RequestHeader, and CloseRequest may race with each other, but must // be safe to call concurrently with all other methods. @@ -121,6 +124,7 @@ type Request[T any] struct { Msg *T spec Spec + peer Peer header http.Header } @@ -144,6 +148,11 @@ func (r *Request[_]) Spec() Spec { return r.spec } +// Peer describes the other party for this RPC. +func (r *Request[_]) Peer() Peer { + return r.peer +} + // Header returns the HTTP headers for this request. func (r *Request[_]) Header() http.Header { if r.header == nil { @@ -164,6 +173,7 @@ func (r *Request[_]) internalOnly() {} type AnyRequest interface { Any() any Spec() Spec + Peer() Peer Header() http.Header internalOnly() @@ -243,6 +253,21 @@ type Spec struct { IsClient bool // otherwise we're in a handler } +// Peer describes the other party to an RPC. When accessed client-side, Addr +// contains the host or host:port from the server's URL. When accessed +// server-side, Addr contains the client's address in IP:port format. +type Peer struct { + Addr string +} + +func newPeerFromURL(s string) Peer { + u, err := url.Parse(s) + if err != nil { + return Peer{} + } + return Peer{Addr: u.Host} +} + // handlerConnCloser extends HandlerConn with a method for handlers to // terminate the message exchange (and optionally send an error to the client). type handlerConnCloser interface { diff --git a/connect_ext_test.go b/connect_ext_test.go index faaa696d..caad1ddb 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1431,6 +1431,9 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } + if request.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } response := connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.Number, @@ -1446,6 +1449,9 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } + if request.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } err := connect.NewError(connect.Code(request.Msg.Code), errors.New(errorMessage)) err.Meta().Set(handlerHeader, headerValue) err.Meta().Set(handlerTrailer, trailerValue) @@ -1461,6 +1467,9 @@ func (p pingServer) Sum( return nil, err } } + if stream.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } var sum int64 for stream.Receive() { sum += stream.Msg().Number @@ -1482,6 +1491,9 @@ func (p pingServer) CountUp( if err := expectClientHeader(p.checkMetadata, request); err != nil { return err } + if request.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } if request.Msg.Number <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", @@ -1508,6 +1520,9 @@ func (p pingServer) CumSum( return err } } + if stream.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } stream.ResponseHeader().Set(handlerHeader, headerValue) stream.ResponseTrailer().Set(handlerTrailer, trailerValue) for { diff --git a/handler.go b/handler.go index 86eda1b8..28c6417c 100644 --- a/handler.go +++ b/handler.go @@ -62,6 +62,7 @@ func NewUnaryHandler[Req, Res any]( request := &Request[Req]{ Msg: &msg, spec: conn.Spec(), + peer: conn.Peer(), header: conn.RequestHeader(), } response, err := untyped(ctx, request) @@ -124,6 +125,7 @@ func NewServerStreamHandler[Req, Res any]( &Request[Req]{ Msg: &msg, spec: conn.Spec(), + peer: conn.Peer(), header: conn.RequestHeader(), }, &ServerStream[Res]{conn: conn}, diff --git a/handler_stream.go b/handler_stream.go index 7f5e034d..eec9d8f1 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -30,6 +30,16 @@ type ClientStream[Req any] struct { err error } +// Spec returns the specification for the RPC. +func (c *ClientStream[_]) Spec() Spec { + return c.conn.Spec() +} + +// Peer describes the client for this RPC. +func (c *ClientStream[_]) Peer() Peer { + return c.conn.Peer() +} + // RequestHeader returns the headers received from the client. func (c *ClientStream[Req]) RequestHeader() http.Header { return c.conn.RequestHeader() @@ -111,6 +121,16 @@ type BidiStream[Req, Res any] struct { conn StreamingHandlerConn } +// Spec returns the specification for the RPC. +func (b *BidiStream[_, _]) Spec() Spec { + return b.conn.Spec() +} + +// Peer describes the client for this RPC. +func (b *BidiStream[_, _]) Peer() Peer { + return b.conn.Peer() +} + // RequestHeader returns the headers received from the client. func (b *BidiStream[Req, Res]) RequestHeader() http.Header { return b.conn.RequestHeader() diff --git a/protocol.go b/protocol.go index e4b9f416..8f0e7060 100644 --- a/protocol.go +++ b/protocol.go @@ -111,6 +111,9 @@ type protocolClientParams struct { // Client is the client side of a protocol. HTTP clients typically use a single // protocol, codec, and compressor to send requests. type protocolClient interface { + // Peer describes the server for the RPC. + Peer() Peer + // WriteRequestHeader writes any protocol-specific request headers. WriteRequestHeader(StreamType, http.Header) diff --git a/protocol_connect.go b/protocol_connect.go index 7b845da8..c75ea35c 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -151,9 +151,11 @@ func (h *connectHandler) NewConn( codec := h.Codecs.Get(codecName) // handler.go guarantees this is not nil var conn handlerConnCloser + peer := Peer{Addr: request.RemoteAddr} if h.Spec.StreamType == StreamTypeUnary { conn = &connectUnaryHandlerConn{ spec: h.Spec, + peer: peer, request: request, responseWriter: responseWriter, marshaler: connectUnaryMarshaler{ @@ -178,6 +180,7 @@ func (h *connectHandler) NewConn( } else { conn = &connectStreamingHandlerConn{ spec: h.Spec, + peer: peer, request: request, responseWriter: responseWriter, marshaler: connectStreamingMarshaler{ @@ -217,6 +220,10 @@ type connectClient struct { protocolClientParams } +func (c *connectClient) Peer() Peer { + return newPeerFromURL(c.URL) +} + func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.Header) { // We know these header keys are in canonical form, so we can bypass all the // checks in Header.Set. @@ -263,6 +270,7 @@ func (c *connectClient) NewConn( if spec.StreamType == StreamTypeUnary { unaryConn := &connectUnaryClientConn{ spec: spec, + peer: c.Peer(), duplexCall: duplexCall, compressionPools: c.CompressionPools, bufferPool: c.BufferPool, @@ -290,6 +298,7 @@ func (c *connectClient) NewConn( } else { streamingConn := &connectStreamingClientConn{ spec: spec, + peer: c.Peer(), duplexCall: duplexCall, compressionPools: c.CompressionPools, bufferPool: c.BufferPool, @@ -323,6 +332,7 @@ func (c *connectClient) NewConn( type connectUnaryClientConn struct { spec Spec + peer Peer duplexCall *duplexHTTPCall compressionPools readOnlyCompressionPools bufferPool *bufferPool @@ -336,6 +346,10 @@ func (cc *connectUnaryClientConn) Spec() Spec { return cc.spec } +func (cc *connectUnaryClientConn) Peer() Peer { + return cc.peer +} + func (cc *connectUnaryClientConn) Send(msg any) error { if err := cc.marshaler.Marshal(msg); err != nil { return err @@ -416,6 +430,7 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err type connectStreamingClientConn struct { spec Spec + peer Peer duplexCall *duplexHTTPCall compressionPools readOnlyCompressionPools bufferPool *bufferPool @@ -430,6 +445,10 @@ func (cc *connectStreamingClientConn) Spec() Spec { return cc.spec } +func (cc *connectStreamingClientConn) Peer() Peer { + return cc.peer +} + func (cc *connectStreamingClientConn) Send(msg any) error { if err := cc.marshaler.Marshal(msg); err != nil { return err @@ -507,6 +526,7 @@ func (cc *connectStreamingClientConn) validateResponse(response *http.Response) type connectUnaryHandlerConn struct { spec Spec + peer Peer request *http.Request responseWriter http.ResponseWriter marshaler connectUnaryMarshaler @@ -519,6 +539,10 @@ func (hc *connectUnaryHandlerConn) Spec() Spec { return hc.spec } +func (hc *connectUnaryHandlerConn) Peer() Peer { + return hc.peer +} + func (hc *connectUnaryHandlerConn) Receive(msg any) error { if err := hc.unmarshaler.Unmarshal(msg); err != nil { return err @@ -583,6 +607,7 @@ func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { type connectStreamingHandlerConn struct { spec Spec + peer Peer request *http.Request responseWriter http.ResponseWriter marshaler connectStreamingMarshaler @@ -594,6 +619,10 @@ func (hc *connectStreamingHandlerConn) Spec() Spec { return hc.spec } +func (hc *connectStreamingHandlerConn) Peer() Peer { + return hc.peer +} + func (hc *connectStreamingHandlerConn) Receive(msg any) error { if err := hc.unmarshaler.Unmarshal(msg); err != nil { // Clients may not send end-of-stream metadata, so we don't need to handle diff --git a/protocol_grpc.go b/protocol_grpc.go index 79a0f84c..8cbef8b9 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -163,6 +163,7 @@ func (g *grpcHandler) NewConn( codec := g.Codecs.Get(codecName) // handler.go guarantees this is not nil conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{ spec: g.Spec, + peer: Peer{Addr: request.RemoteAddr}, web: g.web, bufferPool: g.BufferPool, protobuf: g.Codecs.Protobuf(), // for errors @@ -205,6 +206,10 @@ type grpcClient struct { web bool } +func (g *grpcClient) Peer() Peer { + return newPeerFromURL(g.URL) +} + func (g *grpcClient) WriteRequestHeader(_ StreamType, header http.Header) { // We know these header keys are in canonical form, so we can bypass all the // checks in Header.Set. @@ -248,6 +253,7 @@ func (g *grpcClient) NewConn( ) conn := &grpcClientConn{ spec: spec, + peer: g.Peer(), duplexCall: duplexCall, compressionPools: g.CompressionPools, bufferPool: g.BufferPool, @@ -292,6 +298,7 @@ func (g *grpcClient) NewConn( // grpcClientConn works for both gRPC and gRPC-Web. type grpcClientConn struct { spec Spec + peer Peer duplexCall *duplexHTTPCall compressionPools readOnlyCompressionPools bufferPool *bufferPool @@ -307,6 +314,10 @@ func (cc *grpcClientConn) Spec() Spec { return cc.spec } +func (cc *grpcClientConn) Peer() Peer { + return cc.peer +} + func (cc *grpcClientConn) Send(msg any) error { if err := cc.marshaler.Marshal(msg); err != nil { return err @@ -393,6 +404,7 @@ func (cc *grpcClientConn) validateResponse(response *http.Response) *Error { type grpcHandlerConn struct { spec Spec + peer Peer web bool bufferPool *bufferPool protobuf Codec // for errors @@ -409,6 +421,10 @@ func (hc *grpcHandlerConn) Spec() Spec { return hc.spec } +func (hc *grpcHandlerConn) Peer() Peer { + return hc.peer +} + func (hc *grpcHandlerConn) Receive(msg any) error { if err := hc.unmarshaler.Unmarshal(msg); err != nil { return err // already coded