diff --git a/constants.go b/constants.go index 7c3532e29d7c0..584e8578f6445 100644 --- a/constants.go +++ b/constants.go @@ -116,6 +116,9 @@ const ( // ComponentProxy is SSH proxy (SSH server forwarding connections) ComponentProxy = "proxy" + // ComponentProxyPeer is the proxy peering component of the proxy service + ComponentProxyPeer = "proxy:peer" + // ComponentApp is the application proxy service. ComponentApp = "app:service" diff --git a/go.mod b/go.mod index 1b5d19a129014..e9082dc665ec5 100644 --- a/go.mod +++ b/go.mod @@ -51,7 +51,7 @@ require ( github.com/gravitational/reporting v0.0.0-20210923183620-237377721140 github.com/gravitational/roundtrip v1.0.1 github.com/gravitational/teleport/api v0.0.0 - github.com/gravitational/trace v1.1.17 + github.com/gravitational/trace v1.1.18 github.com/gravitational/ttlmap v0.0.0-20171116003245-91fd36b9004c github.com/grpc-ecosystem/go-grpc-middleware/providers/openmetrics/v2 v2.0.0-20220308023801-e4a6915ea237 github.com/hashicorp/golang-lru v0.5.4 @@ -103,6 +103,7 @@ require ( google.golang.org/api v0.65.0 google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 google.golang.org/grpc v1.43.0 + google.golang.org/grpc/examples v0.0.0-20220317213542-f95b001a48df google.golang.org/protobuf v1.27.1 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/ini.v1 v1.62.0 diff --git a/go.sum b/go.sum index 32498cff94a05..e7f1eaff6b26e 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,9 @@ github.com/gravitational/reporting v0.0.0-20210923183620-237377721140/go.mod h1: github.com/gravitational/roundtrip v1.0.1 h1:eD/y0av12Gu9VIwNgPY/ltmpeVk0Azek/yIJvOPuTuY= github.com/gravitational/roundtrip v1.0.1/go.mod h1:qccpLd30tAJVSpx7aOEEnws4ZT3njPwdbtT8lNQxbAs= github.com/gravitational/trace v1.1.16-0.20220114165159-14a9a7dd6aaf/go.mod h1:zXqxTI6jXDdKnlf8s+nT+3c8LrwUEy3yNpO4XJL90lA= -github.com/gravitational/trace v1.1.17 h1:BkF30oLm1aKMZ5SPVbnlVbYtYEsG26zHxA4dJ+Z46dM= github.com/gravitational/trace v1.1.17/go.mod h1:n0ijrq6psJY0sOI/NzLp+xdd8xl79jjwzVOFHDY6+kQ= +github.com/gravitational/trace v1.1.18 h1:Ulobib6xd5g1ct+ZC01HPAEvODws7QerjuTY9L4U8pY= +github.com/gravitational/trace v1.1.18/go.mod h1:n0ijrq6psJY0sOI/NzLp+xdd8xl79jjwzVOFHDY6+kQ= github.com/gravitational/ttlmap v0.0.0-20171116003245-91fd36b9004c h1:C2iWDiod8vQ3YnOiCdMP9qYeg2UifQ8KSk36r0NswSE= github.com/gravitational/ttlmap v0.0.0-20171116003245-91fd36b9004c/go.mod h1:erKVikttPjeHKDCQZcqowEqiccy23cJAqPadZgfjNm8= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM= @@ -1565,6 +1566,8 @@ google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ5 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/grpc/examples v0.0.0-20200723182653-9106c3fff523/go.mod h1:5j1uub0jRGhRiSghIlrThmBUgcgLXOVJQ/l1getT4uo= google.golang.org/grpc/examples v0.0.0-20210424002626-9572fd6faeae/go.mod h1:Ly7ZA/ARzg8fnPU9TyZIxoz33sEUuWX7txiqs8lPTgE= +google.golang.org/grpc/examples v0.0.0-20220317213542-f95b001a48df h1:7Gq+gDOOhAZ1zuhvFhzTbC7jlpSfRGyxaJC4zqSzo6s= +google.golang.org/grpc/examples v0.0.0-20220317213542-f95b001a48df/go.mod h1:wKDg0brwMZpaizQ1i7IzYcJjH1TmbJudYdnQC9+J+LE= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 9923bc8b0664a..71853da8e4d5f 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -412,12 +412,12 @@ func (a *Middleware) UnaryInterceptor() grpc.UnaryServerInterceptor { if a.GRPCMetrics != nil { return utils.ChainUnaryServerInterceptors( om.UnaryServerInterceptor(a.GRPCMetrics), - utils.ErrorConvertUnaryInterceptor, + utils.GRPCServerUnaryErrorInterceptor, a.Limiter.UnaryServerInterceptorWithCustomRate(getCustomRate), a.withAuthenticatedUserUnaryInterceptor) } return utils.ChainUnaryServerInterceptors( - utils.ErrorConvertUnaryInterceptor, + utils.GRPCServerUnaryErrorInterceptor, a.Limiter.UnaryServerInterceptorWithCustomRate(getCustomRate), a.withAuthenticatedUserUnaryInterceptor) } @@ -429,12 +429,12 @@ func (a *Middleware) StreamInterceptor() grpc.StreamServerInterceptor { if a.GRPCMetrics != nil { return utils.ChainStreamServerInterceptors( om.StreamServerInterceptor(a.GRPCMetrics), - utils.ErrorConvertStreamInterceptor, + utils.GRPCServerStreamErrorInterceptor, a.Limiter.StreamServerInterceptor, a.withAuthenticatedUserStreamInterceptor) } return utils.ChainStreamServerInterceptors( - utils.ErrorConvertStreamInterceptor, + utils.GRPCServerStreamErrorInterceptor, a.Limiter.StreamServerInterceptor, a.withAuthenticatedUserStreamInterceptor) } diff --git a/lib/proxy/auth.go b/lib/proxy/auth.go index 3ae2408538566..996a2cb275c08 100644 --- a/lib/proxy/auth.go +++ b/lib/proxy/auth.go @@ -17,6 +17,7 @@ package proxy import ( "context" "crypto/tls" + "crypto/x509" "net" "github.com/gravitational/teleport/api/types" @@ -98,22 +99,52 @@ func checkProxyRole(authInfo credentials.AuthInfo) error { return trace.AccessDenied("proxy system role required") } +// getConfigForClient clones and updates the server's tls config with the +// appropriate client certificate authorities. func getConfigForClient(tlsConfig *tls.Config, ap auth.AccessCache, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) { return func(info *tls.ClientHelloInfo) (*tls.Config, error) { - clusterName, err := ap.GetClusterName() - if err != nil { - log.WithError(err).Error("Failed to retrieve cluster name.") - return nil, nil - } + tlsCopy := tlsConfig.Clone() - pool, _, err := auth.ClientCertPool(ap, clusterName.GetClusterName()) + pool, err := getCertPool(ap) if err != nil { log.WithError(err).Error("Failed to retrieve client CA pool.") - return nil, nil + return tlsCopy, nil } - tlsCopy := tlsConfig.Clone() + tlsCopy.ClientAuth = tls.RequireAndVerifyClientCert tlsCopy.ClientCAs = pool return tlsCopy, nil } } + +// getConfigForServer clones and updates the client's tls config with the +// appropriate server certificate authorities. +func getConfigForServer(tlsConfig *tls.Config, ap auth.AccessCache, log logrus.FieldLogger) func() (*tls.Config, error) { + return func() (*tls.Config, error) { + tlsCopy := tlsConfig.Clone() + + pool, err := getCertPool(ap) + if err != nil { + log.WithError(err).Error("Failed to retrieve server CA pool.") + return tlsCopy, nil + } + + tlsCopy.RootCAs = pool + return tlsCopy, nil + } +} + +// getCertPool returns a new cert pool from cache if any. +func getCertPool(ap auth.AccessCache) (*x509.CertPool, error) { + clusterName, err := ap.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + + pool, _, err := auth.ClientCertPool(ap, clusterName.GetClusterName()) + if err != nil { + return nil, trace.Wrap(err) + } + + return pool, nil +} diff --git a/lib/proxy/client.go b/lib/proxy/client.go new file mode 100644 index 0000000000000..f45585fcca58a --- /dev/null +++ b/lib/proxy/client.go @@ -0,0 +1,563 @@ +// Copyright 2022 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "crypto/tls" + "net" + "sync" + "time" + + clientapi "github.com/gravitational/teleport/api/client/proto" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/metadata" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" +) + +// ClientConfig configures a Client instance. +type ClientConfig struct { + // Context is a signalling context + Context context.Context + // ID is the ID of this server proxy + ID string + // AuthClient is an auth client + AuthClient auth.ClientI + // AccessPoint is a caching auth client + AccessPoint auth.ProxyAccessPoint + // TLSConfig is the proxy client TLS configuration. + TLSConfig *tls.Config + // Log is the proxy client logger. + Log logrus.FieldLogger + // Clock is used to control connection monitoring ticker. + Clock clockwork.Clock + // GracefulShutdownTimout is used set the graceful shutdown + // duration limit. + GracefulShutdownTimeout time.Duration + + // getConfigForServer updates the client tls config. + // configurable for testing purposes. + getConfigForServer func() (*tls.Config, error) + + // sync runs proxy and connection syncing operations + // configurable for testing purposes + sync func() +} + +// checkAndSetDefaults checks and sets default values +func (c *ClientConfig) checkAndSetDefaults() error { + if c.Log == nil { + c.Log = logrus.New() + } + + c.Log = c.Log.WithField( + trace.Component, + teleport.Component(teleport.ComponentProxyPeer), + ) + + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + + if c.Context == nil { + c.Context = context.Background() + } + + if c.GracefulShutdownTimeout == 0 { + c.GracefulShutdownTimeout = defaults.DefaultGracefulShutdownTimeout + } + + if c.ID == "" { + return trace.BadParameter("missing parameter ID") + } + + if c.AuthClient == nil { + return trace.BadParameter("missing auth client") + } + + if c.AccessPoint == nil { + return trace.BadParameter("missing access cache") + } + + if c.TLSConfig == nil { + return trace.BadParameter("missing tls config") + } + + if len(c.TLSConfig.Certificates) == 0 { + return trace.BadParameter("missing tls certificate") + } + + if c.getConfigForServer == nil { + c.getConfigForServer = getConfigForServer(c.TLSConfig, c.AccessPoint, c.Log) + } + + return nil +} + +// clientConn hold info about a dialed grpc connection +type clientConn struct { + *grpc.ClientConn + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + + id string + addr string +} + +// Client is a peer proxy service client using grpc and tls. +type Client struct { + sync.RWMutex + ctx context.Context + cancel context.CancelFunc + + config ClientConfig + conns map[string]*clientConn + metrics *clientMetrics + reporter *reporter +} + +// NewClient creats a new peer proxy client. +func NewClient(config ClientConfig) (*Client, error) { + err := config.checkAndSetDefaults() + if err != nil { + return nil, trace.Wrap(err) + } + + metrics, err := newClientMetrics() + if err != nil { + return nil, trace.Wrap(err) + } + + reporter := newReporter(metrics) + + closeContext, cancel := context.WithCancel(config.Context) + + c := &Client{ + config: config, + ctx: closeContext, + cancel: cancel, + conns: make(map[string]*clientConn), + metrics: metrics, + reporter: reporter, + } + + go c.monitor() + + if c.config.sync != nil { + go c.config.sync() + } else { + go c.sync() + } + + return c, nil +} + +// monitor monitors the status of peer proxy grpc connections. +func (c *Client) monitor() { + ticker := c.config.Clock.NewTicker(defaults.ResyncInterval) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.Chan(): + c.RLock() + c.reporter.resetConnections() + for _, conn := range c.conns { + switch conn.GetState() { + case connectivity.Idle: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Idle.String()) + case connectivity.Connecting: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Connecting.String()) + case connectivity.Ready: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Ready.String()) + case connectivity.TransientFailure: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.TransientFailure.String()) + case connectivity.Shutdown: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Shutdown.String()) + } + } + c.RUnlock() + } + } +} + +// sync runs the peer proxy watcher functionality. +func (c *Client) sync() { + proxyWatcher, err := services.NewProxyWatcher(c.ctx, services.ProxyWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.Component(teleport.ComponentProxyPeer), + Client: c.config.AccessPoint, + Log: c.config.Log, + }, + }) + if err != nil { + c.config.Log.Errorf("Error initializing proxy peer watcher: %+v.", err) + return + } + defer proxyWatcher.Close() + + for { + select { + case <-c.ctx.Done(): + c.config.Log.Debug("Stopping peer proxy sync.") + return + case proxies := <-proxyWatcher.ProxiesC: + if err := c.updateConnections(proxies); err != nil { + c.config.Log.Errorf("Error syncing peer proxies: %+v.", err) + } + } + } +} + +func (c *Client) updateConnections(proxies []types.Server) error { + c.RLock() + + toDial := make(map[string]types.Server) + for _, proxy := range proxies { + toDial[proxy.GetName()] = proxy + } + + var toDelete []string + toKeep := make(map[string]*clientConn) + for id, conn := range c.conns { + proxy, ok := toDial[id] + + // delete nonexistent connections + if !ok { + toDelete = append(toDelete, id) + continue + } + + // peer address changed + if conn.addr != proxy.GetPeerAddr() { + toDelete = append(toDelete, id) + continue + } + + toKeep[id] = conn + } + + var errs []error + for id, proxy := range toDial { + // skips itself + if id == c.config.ID { + continue + } + + // skip existing connections + if _, ok := toKeep[id]; ok { + continue + } + + // establish new connections + conn, err := c.connect(id, proxy.GetPeerAddr()) + if err != nil { + c.metrics.reportTunnelError(errorProxyPeerTunnelDial) + c.config.Log.Debugf("Error dialing peer proxy %+v at %+v", id, proxy.GetPeerAddr()) + errs = append(errs, err) + continue + } + + toKeep[id] = conn + } + c.RUnlock() + + c.Lock() + defer c.Unlock() + + for _, id := range toDelete { + if conn, ok := c.conns[id]; ok { + go c.shutdownConn(conn) + } + } + c.conns = toKeep + + return trace.NewAggregate(errs...) +} + +// DialNode dials a node through a peer proxy. +func (c *Client) DialNode( + proxyIDs []string, + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, error) { + stream, _, err := c.dial(proxyIDs) + if err != nil { + return nil, trace.ConnectionProblem(err, "error dialing peer proxies %s", proxyIDs) + } + + // send dial request as the first frame + if err = stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_DialRequest{ + DialRequest: &clientapi.DialRequest{ + NodeID: nodeID, + TunnelType: tunnelType, + Source: &clientapi.NetAddr{ + Addr: src.String(), + Network: src.Network(), + }, + Destination: &clientapi.NetAddr{ + Addr: dst.String(), + Network: dst.Network(), + }, + }, + }, + }); err != nil { + return nil, trace.ConnectionProblem(err, "error sending dial frame") + } + + msg, err := stream.Recv() + if err != nil { + return nil, trace.ConnectionProblem(err, "error receiving dial response") + } + + if msg.GetConnectionEstablished() == nil { + return nil, trace.ConnectionProblem(nil, "received malformed connection established frame") + } + + return newStreamConn(stream, src, dst), nil +} + +// Shutdown gracefully shuts down all existing client connections. +func (c *Client) Shutdown() { + c.Lock() + defer c.Unlock() + + var wg sync.WaitGroup + for _, conn := range c.conns { + wg.Add(1) + go func(conn *clientConn) { + defer wg.Done() + + timeoutCtx, cancel := context.WithTimeout(context.Background(), c.config.GracefulShutdownTimeout) + defer cancel() + + go func() { + if err := c.shutdownConn(conn); err != nil { + c.config.Log.Infof("proxy peer connection %+v graceful shutdown error: %+v", conn.id, err) + } + }() + + select { + case <-conn.ctx.Done(): + case <-timeoutCtx.Done(): + if err := c.stopConn(conn); err != nil { + c.config.Log.Infof("proxy peer connection %+v close error: %+v", conn.id, err) + } + } + }(conn) + } + wg.Wait() + c.cancel() +} + +// Stop closes all existing client connections. +func (c *Client) Stop() error { + c.Lock() + defer c.Unlock() + + var errs []error + for _, conn := range c.conns { + if err := c.stopConn(conn); err != nil { + errs = append(errs, err) + } + } + c.cancel() + return trace.NewAggregate(errs...) +} + +// shutdownConn gracefully shuts down a clientConn +// by waiting for open streams to finish. +func (c *Client) shutdownConn(conn *clientConn) error { + conn.wg.Wait() // wait for streams to gracefully end + conn.cancel() + return conn.Close() +} + +// stopConn immediately closes a clientConn +func (c *Client) stopConn(conn *clientConn) error { + conn.cancel() + return conn.Close() +} + +// dial opens a new stream to one of the supplied proxy ids. +// it tries to find an existing grpc.ClientConn or initializes a new rpc +// to one of the proxies otherwise. +// The boolean returned in the second argument is intended for testing purposes, +// to indicates whether the connection was cached or newly established. +func (c *Client) dial(proxyIDs []string) (clientapi.ProxyService_DialNodeClient, bool, error) { + conns, existing, err := c.getConnections(proxyIDs) + if err != nil { + return nil, existing, trace.Wrap(err) + } + + var errs []error + for _, conn := range conns { + stream, err := c.startStream(conn) + if err != nil { + c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) + c.config.Log.Debugf("Error opening tunnel rpc to proxy %+v at %+v", conn.id, conn.addr) + errs = append(errs, err) + continue + } + + return stream, existing, nil + } + + return nil, existing, trace.ConnectionProblem(trace.NewAggregate(errs...), "Error opening tunnel rpcs to all proxies") +} + +// getConnections returns connections to the supplied proxy ids. +// it tries to find an existing grpc.ClientConn or initializes a new one +// otherwise. +// The boolean returned in the second argument is intended for testing purposes, +// to indicates whether the connection was cached or newly established. +func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) { + ids := make(map[string]struct{}) + var conns []*clientConn + + // look for existing matching connections. + c.RLock() + for _, id := range proxyIDs { + ids[id] = struct{}{} + + conn, ok := c.conns[id] + if !ok { + continue + } + + conns = append(conns, conn) + } + c.RUnlock() + + if len(conns) != 0 { + return conns, true, nil + } + + c.metrics.reportTunnelError(errorProxyPeerTunnelNotFound) + + // try to establish new connections otherwise. + proxies, err := c.config.AuthClient.GetProxies() + if err != nil { + c.metrics.reportTunnelError(errorProxyPeerFetchProxies) + return nil, false, trace.Wrap(err) + } + + var errs []error + for _, proxy := range proxies { + id := proxy.GetName() + if _, ok := ids[id]; !ok { + continue + } + + conn, err := c.connect(id, proxy.GetPeerAddr()) + if err != nil { + c.metrics.reportTunnelError(errorProxyPeerTunnelDirectDial) + c.config.Log.Debugf("Error direct dialing peer proxy %+v at %+v", id, proxy.GetPeerAddr()) + errs = append(errs, err) + continue + } + + conns = append(conns, conn) + } + + if len(conns) == 0 { + c.metrics.reportTunnelError(errorProxyPeerProxiesUnreachable) + return nil, false, trace.ConnectionProblem(trace.NewAggregate(errs...), "Error dialing all proxies") + } + + c.Lock() + defer c.Unlock() + + for _, conn := range conns { + c.conns[conn.id] = conn + } + + return conns, false, nil +} + +// connect dials a new connection to proxyAddr. +func (c *Client) connect(id string, proxyPeerAddr string) (*clientConn, error) { + tlsConfig, err := c.config.getConfigForServer() + if err != nil { + return nil, trace.Wrap(err, "Error updating client tls config") + } + + connCtx, cancel := context.WithCancel(c.ctx) + wg := new(sync.WaitGroup) + + transportCreds := newProxyCredentials(credentials.NewTLS(tlsConfig)) + conn, err := grpc.DialContext( + connCtx, + proxyPeerAddr, + grpc.WithTransportCredentials(transportCreds), + grpc.WithStatsHandler(newStatsHandler(c.reporter)), + grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, utils.GRPCClientStreamErrorInterceptor, streamCounterInterceptor(wg)), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: peerKeepAlive, + Timeout: peerTimeout, + PermitWithoutStream: true, + }), + grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`), + ) + if err != nil { + return nil, trace.Wrap(err, "Error dialing proxy %+v", id) + } + + return &clientConn{ + ClientConn: conn, + ctx: connCtx, + cancel: cancel, + wg: wg, + id: id, + addr: proxyPeerAddr, + }, nil +} + +// startStream opens a new stream to the provided connection. +func (c *Client) startStream(conn *clientConn) (clientapi.ProxyService_DialNodeClient, error) { + client := clientapi.NewProxyServiceClient(conn.ClientConn) + + stream, err := client.DialNode(conn.ctx) + if err != nil { + return nil, trace.Wrap(err, "Error opening stream to proxy %+v", conn.id) + } + + go func() { + <-conn.ctx.Done() + if err := stream.CloseSend(); err != nil { + c.config.Log.Debugf("error closing stream: %+v", err) + } + }() + + return stream, nil +} diff --git a/lib/proxy/client_test.go b/lib/proxy/client_test.go new file mode 100644 index 0000000000000..4f7569f4e1292 --- /dev/null +++ b/lib/proxy/client_test.go @@ -0,0 +1,219 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "testing" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/tlsca" + + "github.com/stretchr/testify/require" +) + +// TestClientConn checks the client's connection caching capabilities +func TestClientConn(t *testing.T) { + ca := newSelfSignedCA(t) + + client, _ := setupClient(t, ca, ca, types.RoleProxy) + _, _, def1 := setupServer(t, "s1", ca, ca, types.RoleProxy) + server2, _, def2 := setupServer(t, "s2", ca, ca, types.RoleProxy) + + // simulate watcher finding two servers + err := client.updateConnections([]types.Server{def1, def2}) + require.NoError(t, err) + require.Len(t, client.conns, 2) + + // dial first server and send a test data frame + stream, cached, err := client.dial([]string{"s1"}) + require.NoError(t, err) + require.True(t, cached) + require.NotNil(t, stream) + sendDialRequest(t, stream) + stream.CloseSend() + + // dial second server + stream, cached, err = client.dial([]string{"s2"}) + require.NoError(t, err) + require.True(t, cached) + require.NotNil(t, stream) + stream.CloseSend() + + // redial second server + stream, cached, err = client.dial([]string{"s2"}) + require.NoError(t, err) + require.True(t, cached) + require.NotNil(t, stream) + stream.CloseSend() + + // close second server + // and attempt to redial it + server2.Shutdown() + stream, cached, err = client.dial([]string{"s2"}) + require.Error(t, err) + require.True(t, cached) + require.Nil(t, stream) +} + +// TestClientUpdate checks the client's watcher update behaviour +func TestClientUpdate(t *testing.T) { + ca := newSelfSignedCA(t) + + client, _ := setupClient(t, ca, ca, types.RoleProxy) + _, _, def1 := setupServer(t, "s1", ca, ca, types.RoleProxy) + server2, _, def2 := setupServer(t, "s2", ca, ca, types.RoleProxy) + + // watcher finds two servers + err := client.updateConnections([]types.Server{def1, def2}) + require.NoError(t, err) + require.Len(t, client.conns, 2) + require.Contains(t, client.conns, "s1") + require.Contains(t, client.conns, "s2") + + s1, _, err := client.dial([]string{"s1"}) + require.NoError(t, err) + require.NotNil(t, s1) + sendDialRequest(t, s1) + s2, _, err := client.dial([]string{"s2"}) + require.NoError(t, err) + require.NotNil(t, s2) + sendDialRequest(t, s2) + + // watcher finds one of the two servers + err = client.updateConnections([]types.Server{def1}) + require.NoError(t, err) + require.Len(t, client.conns, 1) + require.Contains(t, client.conns, "s1") + sendMsg(t, s1) // stream is not broken across updates + sendMsg(t, s2) // stream is not forcefully closed. ClientConn waits for a graceful shutdown before it closes. + + s2.CloseSend() + + // watcher finds two servers with one broken connection + server2.Shutdown() + err = client.updateConnections([]types.Server{def1, def2}) + require.NoError(t, err) // server2 is in a transient failure state but not reported as an error + require.Len(t, client.conns, 2) + require.Contains(t, client.conns, "s1") + sendMsg(t, s1) // stream is still going strong + _, _, err = client.dial([]string{"s2"}) + require.Error(t, err) // can't dial server2, obviously + + // peer address change + _, _, def3 := setupServer(t, "s1", ca, ca, types.RoleProxy) + err = client.updateConnections([]types.Server{def3}) + require.NoError(t, err) + require.Len(t, client.conns, 1) + require.Contains(t, client.conns, "s1") + sendMsg(t, s1) // stream is not forcefully closed. ClientConn waits for a graceful shutdown before it closes. + s3, _, err := client.dial([]string{"s1"}) + require.NoError(t, err) + require.NotNil(t, s3) + sendDialRequest(t, s3) // new stream is working + + s1.CloseSend() + s3.CloseSend() +} + +func TestCAChange(t *testing.T) { + clientCA := newSelfSignedCA(t) + serverCA := newSelfSignedCA(t) + + client, clientTLSConfig := setupClient(t, clientCA, serverCA, types.RoleProxy) + server, serverTLSConfig, serverDef := setupServer(t, "s1", serverCA, clientCA, types.RoleProxy) + + err := client.updateConnections([]types.Server{serverDef}) + require.NoError(t, err) + require.Len(t, client.conns, 1) + + // dial server and send a test data frame + ogStream, cached, err := client.dial([]string{"s1"}) + require.NoError(t, err) + require.True(t, cached) + require.NotNil(t, ogStream) + + sendDialRequest(t, ogStream) + ogStream.CloseSend() + + // server ca rotated + newServerCA := newSelfSignedCA(t) + + newServerTLSConfig := certFromIdentity(t, newServerCA, tlsca.Identity{ + Groups: []string{string(types.RoleProxy)}, + }) + + *serverTLSConfig = *newServerTLSConfig + + // existing connection should still be working + ogStream, cached, err = client.dial([]string{"s1"}) + require.NoError(t, err) + require.True(t, cached) + require.NotNil(t, ogStream) + sendDialRequest(t, ogStream) + + // new connection should fail because client tls config still references old + // RootCAs. + conn, err := client.connect("s1", server.config.Listener.Addr().String()) + require.NoError(t, err) + require.NotNil(t, conn) + stream, err := client.startStream(conn) + require.Error(t, err) + require.Nil(t, stream) + + // new connection should succeed because client references new RootCAs + *serverCA = *newServerCA + conn, err = client.connect("s1", server.config.Listener.Addr().String()) + require.NoError(t, err) + require.NotNil(t, conn) + stream, err = client.startStream(conn) + require.NoError(t, err) + sendDialRequest(t, stream) + stream.CloseSend() + + // for good measure, original stream should still be working + sendMsg(t, ogStream) + + // client ca rotated + newClientCA := newSelfSignedCA(t) + + newClientTLSConfig := certFromIdentity(t, newClientCA, tlsca.Identity{ + Groups: []string{string(types.RoleProxy)}, + }) + + *clientTLSConfig = *newClientTLSConfig + + // new connection should fail because server tls config still references old + // ClientCAs. + conn, err = client.connect("s1", server.config.Listener.Addr().String()) + require.NoError(t, err) + require.NotNil(t, conn) + stream, err = client.startStream(conn) + require.Error(t, err) + require.Nil(t, stream) + + // new connection should succeed because client references new RootCAs + *clientCA = *newClientCA + conn, err = client.connect("s1", server.config.Listener.Addr().String()) + require.NoError(t, err) + require.NotNil(t, conn) + stream, err = client.startStream(conn) + require.NoError(t, err) + sendDialRequest(t, stream) + stream.CloseSend() + + // and one final time, original stream should still be working + sendMsg(t, ogStream) + ogStream.CloseSend() +} diff --git a/lib/proxy/clientmetrics.go b/lib/proxy/clientmetrics.go new file mode 100644 index 0000000000000..8500300057cc0 --- /dev/null +++ b/lib/proxy/clientmetrics.go @@ -0,0 +1,166 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "github.com/gravitational/teleport/lib/utils" + + "github.com/gravitational/trace" + "github.com/prometheus/client_golang/prometheus" +) + +const ( + errorProxyPeerTunnelNotFound = "TUNNEL_NOT_FOUND" + errorProxyPeerTunnelDial = "TUNNEL_DIAL" + errorProxyPeerTunnelDirectDial = "TUNNEL_DIRECT_DIAL" + errorProxyPeerTunnelRPC = "TUNNEL_RPC" + errorProxyPeerFetchProxies = "FETCH_PROXIES" + errorProxyPeerProxiesUnreachable = "PROXIES_UNREACHABLE" +) + +// clientMetrics represents a collection of metrics for a proxy peer client +type clientMetrics struct { + dialErrors *prometheus.CounterVec + connections *prometheus.GaugeVec + rpcs *prometheus.GaugeVec + rpcTotal *prometheus.CounterVec + rpcDuration *prometheus.HistogramVec + messageSent *prometheus.HistogramVec + messageReceived *prometheus.HistogramVec +} + +// newClientMetrics inits and registers client metrics prometheus collectors. +func newClientMetrics() (*clientMetrics, error) { + cm := &clientMetrics{ + dialErrors: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "dial_error_total", + Help: "Total number of errors encountered dialing peer proxies.", + }, + []string{"error_type"}, + ), + + connections: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "connections", + Help: "Number of currently opened connection to proxy peer servers.", + }, + []string{"local_id", "remote_id", "state"}, + ), + + rpcs: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "rpc", + Help: "Number of current client RPC requests.", + }, + []string{"service", "method"}, + ), + + rpcTotal: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "rpc_total", + Help: "Total number of client RPC requests.", + }, + []string{"service", "method", "code"}, + ), + + rpcDuration: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "rpc_duration_seconds", + Help: "Duration in seconds of RPCs sent by the client.", + }, + []string{"service", "handler", "code"}, + ), + + messageSent: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "message_sent_size", + Help: "Size of messages sent by the client.", + }, + []string{"service", "handler"}, + ), + + messageReceived: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "proxy_peer", + Subsystem: "client", + Name: "message_received_size", + Help: "Size of messages received by the client.", + }, + []string{"service", "handler"}, + ), + } + + if err := utils.RegisterPrometheusCollectors( + cm.dialErrors, + cm.connections, + cm.rpcs, + cm.rpcTotal, + cm.rpcDuration, + cm.messageSent, + cm.messageReceived, + ); err != nil { + return nil, trace.Wrap(err) + } + + return cm, nil +} + +// reportTunnelError reports errors encountered dialing an existing peer tunnel. +func (c *clientMetrics) reportTunnelError(errorType string) { + c.dialErrors.WithLabelValues(errorType).Inc() +} + +// getConnectionGauge is a getter for the connections collector. +func (c *clientMetrics) getConnectionGauge() *prometheus.GaugeVec { + return c.connections +} + +// getRPCGauge is a getter for the rpcs collector. +func (c *clientMetrics) getRPCGauge() *prometheus.GaugeVec { + return c.rpcs +} + +// getRPCCounter is a getter for the rpcTotal collector. +func (c *clientMetrics) getRPCCounter() *prometheus.CounterVec { + return c.rpcTotal +} + +// getRPCDuration is a getter for the rpcDuration collector. +func (c *clientMetrics) getRPCDuration() *prometheus.HistogramVec { + return c.rpcDuration +} + +// getMessageSent is a getter for the messageSent collector. +func (c *clientMetrics) getMessageSent() *prometheus.HistogramVec { + return c.messageSent +} + +// getMessageReceived is a getter for the messageReceived collector. +func (c *clientMetrics) getMessageReceived() *prometheus.HistogramVec { + return c.messageReceived +} diff --git a/lib/proxy/helpers_test.go b/lib/proxy/helpers_test.go new file mode 100644 index 0000000000000..b6a91d4a5d653 --- /dev/null +++ b/lib/proxy/helpers_test.go @@ -0,0 +1,263 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "net" + "testing" + "time" + + clientapi "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/tlsca" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +type mockAuthClient struct { + auth.ClientI +} + +func (c mockAuthClient) GetProxies() ([]types.Server, error) { + return []types.Server{}, nil +} + +type mockAccessCache struct { + auth.AccessCache +} + +type mockProxyAccessPoint struct { + auth.ProxyAccessPoint +} + +type mockProxyService struct{} + +func (s *mockProxyService) DialNode(stream clientapi.ProxyService_DialNodeServer) error { + sendErr := make(chan error) + recvErr := make(chan error) + + frame, err := stream.Recv() + if err != nil { + return trace.Wrap(err) + } + + if frame.GetDialRequest() == nil { + return trace.BadParameter("invalid dial request") + } + + err = stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_ConnectionEstablished{ + ConnectionEstablished: &clientapi.ConnectionEstablished{}, + }, + }) + if err != nil { + return trace.Wrap(err) + } + + go func() { + for { + if _, err := stream.Recv(); err != nil { + recvErr <- err + close(recvErr) + return + } + } + }() + + go func() { + for { + err := stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_Data{ + &clientapi.Data{Bytes: []byte("pong")}, + }, + }) + if err != nil { + sendErr <- err + close(sendErr) + return + } + } + }() + + select { + case <-stream.Context().Done(): + return stream.Context().Err() + case err := <-recvErr: + return err + case err := <-sendErr: + return err + } +} + +// newSelfSignedCA creates a new CA for testing. +func newSelfSignedCA(t *testing.T) *tlsca.CertAuthority { + rsaKey, err := ssh.ParseRawPrivateKey(fixtures.PEMBytes["rsa"]) + require.NoError(t, err) + + cert, err := tlsca.GenerateSelfSignedCAWithSigner( + rsaKey.(*rsa.PrivateKey), pkix.Name{}, nil, defaults.CATTL, + ) + require.NoError(t, err) + + ca, err := tlsca.FromCertAndSigner(cert, rsaKey.(*rsa.PrivateKey)) + require.NoError(t, err) + + return ca +} + +// certFromIdentity creates a tls config for a given CA and identity. +func certFromIdentity(t *testing.T, ca *tlsca.CertAuthority, ident tlsca.Identity) *tls.Config { + if ident.Username == "" { + ident.Username = "test-user" + } + + subj, err := ident.Subject() + require.NoError(t, err) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + clock := clockwork.NewRealClock() + + request := tlsca.CertificateRequest{ + Clock: clock, + PublicKey: privateKey.Public(), + Subject: subj, + NotAfter: clock.Now().UTC().Add(time.Minute), + DNSNames: []string{"127.0.0.1"}, + } + certBytes, err := ca.GenerateCertificate(request) + require.NoError(t, err) + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + cert, err := tls.X509KeyPair(certBytes, keyPEM) + require.NoError(t, err) + + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + return config +} + +// setupClients return a Client object. +func setupClient(t *testing.T, clientCA, serverCA *tlsca.CertAuthority, role types.SystemRole) (*Client, *tls.Config) { + tlsConf := certFromIdentity(t, clientCA, tlsca.Identity{ + Groups: []string{string(role)}, + }) + + getConfigForServer := func() (*tls.Config, error) { + config := tlsConf.Clone() + rootCAs := x509.NewCertPool() + rootCAs.AddCert(serverCA.Cert) + config.RootCAs = rootCAs + return config, nil + } + + client, err := NewClient(ClientConfig{ + ID: "client-proxy", + AuthClient: mockAuthClient{}, + AccessPoint: &mockProxyAccessPoint{}, + TLSConfig: tlsConf, + Clock: clockwork.NewFakeClock(), + GracefulShutdownTimeout: time.Second, + getConfigForServer: getConfigForServer, + sync: func() {}, + }) + require.NoError(t, err) + + t.Cleanup(func() { + client.Shutdown() + }) + + return client, tlsConf +} + +// setupServer return a Server object. +func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthority, role types.SystemRole) (*Server, *tls.Config, types.Server) { + tlsConf := certFromIdentity(t, serverCA, tlsca.Identity{ + Groups: []string{string(role)}, + }) + + getConfigForClient := func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + config := tlsConf.Clone() + config.ClientAuth = tls.RequireAndVerifyClientCert + clientCAs := x509.NewCertPool() + clientCAs.AddCert(clientCA.Cert) + config.ClientCAs = clientCAs + return config, nil + } + + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + server, err := NewServer(ServerConfig{ + AccessCache: &mockAccessCache{}, + Listener: listener, + TLSConfig: tlsConf, + ClusterDialer: &mockClusterDialer{}, + getConfigForClient: getConfigForClient, + service: &mockProxyService{}, + }) + require.NoError(t, err) + + ts, err := types.NewServer( + name, types.KindProxy, + types.ServerSpecV2{PeerAddr: listener.Addr().String()}, + ) + require.NoError(t, err) + + go server.Serve() + t.Cleanup(func() { + require.NoError(t, server.Close()) + }) + + return server, tlsConf, ts +} + +func sendDialRequest(t *testing.T, stream clientapi.ProxyService_DialNodeClient) { + err := stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_DialRequest{ + DialRequest: &clientapi.DialRequest{}, + }, + }) + require.NoError(t, err) + + frame, err := stream.Recv() + require.NoError(t, err) + require.NotNil(t, frame.GetConnectionEstablished()) +} + +func sendMsg(t *testing.T, stream clientapi.ProxyService_DialNodeClient) { + err := stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_Data{ + &clientapi.Data{Bytes: []byte("ping")}, + }, + }) + require.NoError(t, err) +} diff --git a/lib/proxy/interceptor.go b/lib/proxy/interceptor.go new file mode 100644 index 0000000000000..de838322c48ff --- /dev/null +++ b/lib/proxy/interceptor.go @@ -0,0 +1,77 @@ +// Copyright 2022 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "sync" + + "github.com/gravitational/trace/trail" + "google.golang.org/grpc" +) + +// streamWrapper wraps around the embedded grpc.ClientStream +// and intercepts the RecvMsg method calls decreading the number of +// streams counter. +type streamWrapper struct { + grpc.ClientStream + wg *sync.WaitGroup + once sync.Once +} + +func (s *streamWrapper) CloseSend() error { + err := s.ClientStream.CloseSend() + s.decreaseCounter() + return err +} + +func (s *streamWrapper) SendMsg(m interface{}) error { + err := s.ClientStream.SendMsg(m) + if err != nil { + s.decreaseCounter() + } + return err +} + +func (s *streamWrapper) RecvMsg(m interface{}) error { + err := s.ClientStream.RecvMsg(m) + if err != nil { + s.decreaseCounter() + } + return err +} + +func (s *streamWrapper) decreaseCounter() { + s.once.Do(func() { + s.wg.Done() + }) +} + +// streamCounterInterceptor is GPRC client stream interceptor that +// counts the number of current open streams for the purpose of +// gracefully shutdown a draining grpc client. +func streamCounterInterceptor(wg *sync.WaitGroup) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + s, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, trail.ToGRPC(err) + } + wg.Add(1) + return &streamWrapper{ + ClientStream: s, + wg: wg, + }, nil + } +} diff --git a/lib/proxy/middleware.go b/lib/proxy/middleware.go deleted file mode 100644 index 3db35914c9822..0000000000000 --- a/lib/proxy/middleware.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2022 Gravitational, Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy - -import ( - "github.com/gravitational/trace" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// errorStreamInterceptor is a GPRC stream interceptor that handles converting -// errors to the appropriate grpc status code. -func errorStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - err := handler(srv, stream) - return toGRPCError(err) -} - -// errorHandler converts trace errors to grpc errors with appropriate status codes. -func toGRPCError(err error) error { - if err == nil { - return nil - } - if trace.IsNotFound(err) { - return status.Error(codes.NotFound, err.Error()) - } - if trace.IsBadParameter(err) { - return status.Error(codes.InvalidArgument, err.Error()) - } - return status.Error(codes.Internal, err.Error()) -} diff --git a/lib/proxy/reporter.go b/lib/proxy/reporter.go new file mode 100644 index 0000000000000..8da0d11f5e981 --- /dev/null +++ b/lib/proxy/reporter.go @@ -0,0 +1,88 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +type metrics interface { + getConnectionGauge() *prometheus.GaugeVec + getRPCGauge() *prometheus.GaugeVec + getRPCCounter() *prometheus.CounterVec + getRPCDuration() *prometheus.HistogramVec + getMessageSent() *prometheus.HistogramVec + getMessageReceived() *prometheus.HistogramVec +} + +// reporter is grpc request specific metrics reporter. +type reporter struct { + metrics +} + +// newReporter returns a new reporter object that is used to +// report metrics relative to proxy peer client or server. +func newReporter(m metrics) *reporter { + return &reporter{ + metrics: m, + } +} + +// resetConnections resets the current number of connections. +func (r *reporter) resetConnections() { + r.getConnectionGauge().Reset() +} + +// incConnection increases the current number of connections. +func (r *reporter) incConnection(localID, remoteID, state string) { + r.getConnectionGauge().WithLabelValues(localID, remoteID, state).Inc() +} + +// decConnection decreases the current number of connections. +func (r *reporter) decConnection(localAddr, remoteAddr, state string) { + r.getConnectionGauge().WithLabelValues(localAddr, remoteAddr, state).Dec() +} + +// incRPC increases the current number of rpcs. +func (r *reporter) incRPC(service, method string) { + r.getRPCGauge().WithLabelValues(service, method).Inc() +} + +// decRPC decreases the current number of rpcs. +func (r *reporter) decRPC(service, method string) { + r.getRPCGauge().WithLabelValues(service, method).Dec() +} + +// countRPC reports the total number of handled rpcs. +func (r *reporter) countRPC(service, method, code string) { + r.getRPCCounter().WithLabelValues(service, method, code).Inc() +} + +// timeRPC reports the duration of handled rpcs. +func (r *reporter) timeRPC(service, method, code string, duration time.Duration) { + r.getRPCDuration().WithLabelValues(service, method, code).Observe(duration.Seconds()) +} + +// measureMessageSent reports the size of sent messages. +func (r *reporter) measureMessageSent(service, method string, size float64) { + r.getMessageSent().WithLabelValues(service, method).Observe(size) +} + +// measureMessageReceived reports the size of received messages. +func (r *reporter) measureMessageReceived(service, method string, size float64) { + r.getMessageReceived().WithLabelValues(service, method).Observe(size) +} diff --git a/lib/proxy/server.go b/lib/proxy/server.go index 55ca9695b7460..f60c78f7238f2 100644 --- a/lib/proxy/server.go +++ b/lib/proxy/server.go @@ -23,6 +23,8 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -44,17 +46,32 @@ type ServerConfig struct { Log logrus.FieldLogger // getConfigForClient gets the client tls config. + // configurable for testing purposes. getConfigForClient func(*tls.ClientHelloInfo) (*tls.Config, error) + + // service is a custom ProxyServiceServer + // configurable for testing purposes. + service proto.ProxyServiceServer } // checkAndSetDefaults checks and sets default values func (c *ServerConfig) checkAndSetDefaults() error { + if c.Log == nil { + c.Log = logrus.New() + } + c.Log = c.Log.WithField( + trace.Component, + teleport.Component(teleport.ComponentProxy, "peer"), + ) + if c.AccessCache == nil { return trace.BadParameter("missing access cache") } + if c.Listener == nil { return trace.BadParameter("missing listener") } + if c.ClusterDialer == nil { return trace.BadParameter("missing cluster dialer server") } @@ -62,16 +79,10 @@ func (c *ServerConfig) checkAndSetDefaults() error { if c.TLSConfig == nil { return trace.BadParameter("missing tls config") } + if len(c.TLSConfig.Certificates) == 0 { return trace.BadParameter("missing tls certificate") } - if c.Log == nil { - c.Log = logrus.New() - } - c.Log = c.Log.WithField( - trace.Component, - teleport.Component(teleport.ComponentProxy, "peer"), - ) c.TLSConfig = c.TLSConfig.Clone() c.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert @@ -79,16 +90,22 @@ func (c *ServerConfig) checkAndSetDefaults() error { if c.getConfigForClient == nil { c.getConfigForClient = getConfigForClient(c.TLSConfig, c.AccessCache, c.Log) } - c.TLSConfig.GetConfigForClient = c.getConfigForClient + if c.service == nil { + c.service = &proxyService{ + c.ClusterDialer, + c.Log, + } + } + return nil } // Server is a proxy service server using grpc and tls. type Server struct { - server *grpc.Server config ServerConfig + server *grpc.Server } // NewServer creates a new proxy server instance. @@ -98,15 +115,18 @@ func NewServer(config ServerConfig) (*Server, error) { return nil, trace.Wrap(err) } - service := &proxyService{ - config.ClusterDialer, - config.Log, + metrics, err := newServerMetrics() + if err != nil { + return nil, trace.Wrap(err) } + reporter := newReporter(metrics) + transportCreds := newProxyCredentials(credentials.NewTLS(config.TLSConfig)) server := grpc.NewServer( grpc.Creds(transportCreds), - grpc.ChainStreamInterceptor(metadata.StreamServerInterceptor, errorStreamInterceptor), + grpc.StatsHandler(newStatsHandler(reporter)), + grpc.ChainStreamInterceptor(metadata.StreamServerInterceptor, utils.GRPCServerStreamErrorInterceptor), grpc.KeepaliveParams(keepalive.ServerParameters{ Time: peerKeepAlive, Timeout: peerTimeout, @@ -116,18 +136,21 @@ func NewServer(config ServerConfig) (*Server, error) { PermitWithoutStream: true, }), ) - proto.RegisterProxyServiceServer(server, service) + + proto.RegisterProxyServiceServer(server, config.service) return &Server{ - server: server, config: config, + server: server, }, nil } // Serve starts the proxy server. func (s *Server) Serve() error { - err := s.server.Serve(s.config.Listener) - return trace.Wrap(err) + if err := s.server.Serve(s.config.Listener); err != nil && err != grpc.ErrServerStopped { + return trace.Wrap(err) + } + return nil } // Close closes the proxy server immediately. diff --git a/lib/proxy/server_test.go b/lib/proxy/server_test.go index e3d83e77d9a34..b93ace0436f65 100644 --- a/lib/proxy/server_test.go +++ b/lib/proxy/server_test.go @@ -15,166 +15,46 @@ package proxy import ( - "context" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "net" "testing" - "time" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/fixtures" - "github.com/gravitational/teleport/lib/tlsca" - "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) -// newSelfSignedCA creates a new CA for testing. -func newSelfSignedCA(t *testing.T) *tlsca.CertAuthority { - rsaKey, err := ssh.ParseRawPrivateKey(fixtures.PEMBytes["rsa"]) - require.NoError(t, err) - - cert, err := tlsca.GenerateSelfSignedCAWithSigner( - rsaKey.(*rsa.PrivateKey), pkix.Name{}, nil, defaults.CATTL, - ) - require.NoError(t, err) - - ca, err := tlsca.FromCertAndSigner(cert, rsaKey.(*rsa.PrivateKey)) - require.NoError(t, err) - - return ca -} - -// certFromIdentity creates a tls config for a given CA and identity. -func certFromIdentity(t *testing.T, ca *tlsca.CertAuthority, ident tlsca.Identity) *tls.Config { - if ident.Username == "" { - ident.Username = "test-user" - } - - subj, err := ident.Subject() - require.NoError(t, err) - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - clock := clockwork.NewRealClock() - - request := tlsca.CertificateRequest{ - Clock: clock, - PublicKey: privateKey.Public(), - Subject: subj, - NotAfter: clock.Now().UTC().Add(time.Minute), - DNSNames: []string{"127.0.0.1"}, - } - certBytes, err := ca.GenerateCertificate(request) - require.NoError(t, err) - - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) - cert, err := tls.X509KeyPair(certBytes, keyPEM) - require.NoError(t, err) - - pool := x509.NewCertPool() - pool.AddCert(ca.Cert) - - config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: pool, - } - - return config -} - -type mockAccessCache struct { - auth.AccessCache -} - // TestServerTLS ensures that only trusted certificates with the proxy role // are accepted by the server. func TestServerTLS(t *testing.T) { ca1 := newSelfSignedCA(t) ca2 := newSelfSignedCA(t) - tests := []struct { - desc string - server *tls.Config - client *tls.Config - assertErr require.ErrorAssertionFunc - }{ - { - desc: "trusted certificates with proxy roles", - server: certFromIdentity(t, ca1, tlsca.Identity{ - Groups: []string{string(types.RoleProxy)}, - }), - client: certFromIdentity(t, ca1, tlsca.Identity{ - Groups: []string{string(types.RoleProxy)}, - }), - assertErr: require.NoError, - }, - { - desc: "trusted certificates with incorrect server role", - server: certFromIdentity(t, ca1, tlsca.Identity{ - Groups: []string{string(types.RoleAdmin)}, - }), - client: certFromIdentity(t, ca1, tlsca.Identity{ - Groups: []string{string(types.RoleProxy)}, - }), - assertErr: require.Error, - }, - { - desc: "certificates with correct role from different CAs", - server: certFromIdentity(t, ca1, tlsca.Identity{ - Groups: []string{string(types.RoleProxy)}, - }), - client: certFromIdentity(t, ca2, tlsca.Identity{ - Groups: []string{string(types.RoleProxy)}, - }), - assertErr: require.Error, - }, - } - - for _, tc := range tests { - t.Run(tc.desc, func(t *testing.T) { - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - - clientCAs := tc.server.RootCAs - tc.server.RootCAs = nil - - server, err := NewServer(ServerConfig{ - AccessCache: &mockAccessCache{}, - Listener: listener, - TLSConfig: tc.server, - ClusterDialer: &mockClusterDialer{}, - getConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) { - config := tc.server.Clone() - config.ClientAuth = tls.RequireAndVerifyClientCert - config.ClientCAs = clientCAs - return config, nil - }, - }) - require.NoError(t, err) - go server.Serve() - t.Cleanup(func() { server.Close() }) - - creds := newProxyCredentials(credentials.NewTLS(tc.client)) - conn, err := grpc.Dial(listener.Addr().String(), grpc.WithTransportCredentials(creds)) - require.NoError(t, err) - - defer conn.Close() - - client := proto.NewProxyServiceClient(conn) - _, err = client.DialNode(context.Background()) - tc.assertErr(t, err) - }) - } + // trusted certificates with proxy roles. + client1, _ := setupClient(t, ca1, ca1, types.RoleProxy) + _, _, serverDef1 := setupServer(t, "s1", ca1, ca1, types.RoleProxy) + err := client1.updateConnections([]types.Server{serverDef1}) + require.NoError(t, err) + stream, _, err := client1.dial([]string{"s1"}) + require.NoError(t, err) + require.NotNil(t, stream) + sendDialRequest(t, stream) + stream.CloseSend() + + // trusted certificates with incorrect server role. + client2, _ := setupClient(t, ca1, ca1, types.RoleNode) + _, _, serverDef2 := setupServer(t, "s2", ca1, ca1, types.RoleProxy) + err = client2.updateConnections([]types.Server{serverDef2}) + require.NoError(t, err) // connection succeeds but is in transient failure state + _, _, err = client2.dial([]string{"s2"}) + require.Error(t, err) + + // certificates with correct role from different CAs + client3, _ := setupClient(t, ca1, ca2, types.RoleProxy) + _, _, serverDef3 := setupServer(t, "s3", ca2, ca1, types.RoleProxy) + err = client3.updateConnections([]types.Server{serverDef3}) + require.NoError(t, err) + stream, _, err = client3.dial([]string{"s3"}) + require.NoError(t, err) + require.NotNil(t, stream) + sendDialRequest(t, stream) + stream.CloseSend() } diff --git a/lib/proxy/servermetrics.go b/lib/proxy/servermetrics.go new file mode 100644 index 0000000000000..a5f0fc986a75e --- /dev/null +++ b/lib/proxy/servermetrics.go @@ -0,0 +1,140 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "github.com/gravitational/teleport/lib/utils" + + "github.com/gravitational/trace" + "github.com/prometheus/client_golang/prometheus" +) + +// serverMetrics represents a collection of metrics for a proxy peer server +type serverMetrics struct { + connections *prometheus.GaugeVec + rpcs *prometheus.GaugeVec + rpcTotal *prometheus.CounterVec + rpcDuration *prometheus.HistogramVec + messageSent *prometheus.HistogramVec + messageReceived *prometheus.HistogramVec +} + +// newServerMetrics inits and registers client metrics prometheus collectors. +func newServerMetrics() (*serverMetrics, error) { + sm := &serverMetrics{ + connections: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "proxy_peer", + Subsystem: "server", + Name: "connections", + Help: "Number of currently opened connection to proxy peer clients.", + }, + []string{"local_id", "remote_id", "state"}, + ), + + rpcs: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "proxy_peer", + Subsystem: "server", + Name: "rpc", + Help: "Number of current server RPC requests.", + }, + []string{"service", "method"}, + ), + + rpcTotal: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "proxy_peer", + Subsystem: "server", + Name: "rpc_total", + Help: "Total number of server RPC requests.", + }, + []string{"service", "method", "code"}, + ), + + rpcDuration: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "proxy_peer", + Subsystem: "server", + Name: "rpc_duration_seconds", + Help: "Duration in seconds of RPCs sent by the server.", + }, + []string{"service", "handler", "code"}, + ), + + messageSent: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "proxy_peer", + Subsystem: "server", + Name: "message_sent_size", + Help: "Size of messages sent by the server.", + }, + []string{"service", "handler"}, + ), + + messageReceived: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "proxy_peer", + Subsystem: "server", + Name: "message_received_size", + Help: "Size of messages received by the server.", + }, + []string{"service", "handler"}, + ), + } + + if err := utils.RegisterPrometheusCollectors( + sm.connections, + sm.rpcs, + sm.rpcTotal, + sm.rpcDuration, + sm.messageSent, + sm.messageReceived, + ); err != nil { + return nil, trace.Wrap(err) + } + + return sm, nil +} + +// getConnectionGauge is a getter for the connectionCounter collector. +func (s *serverMetrics) getConnectionGauge() *prometheus.GaugeVec { + return s.connections +} + +// getRPCGauge is a getter for the rpcs collector. +func (s *serverMetrics) getRPCGauge() *prometheus.GaugeVec { + return s.rpcs +} + +// getRPCCounter is a getter for the rpcTotal collector. +func (s *serverMetrics) getRPCCounter() *prometheus.CounterVec { + return s.rpcTotal +} + +// getRPCDuration is a getter for the rpcDuration collector. +func (s *serverMetrics) getRPCDuration() *prometheus.HistogramVec { + return s.rpcDuration +} + +// getMessageSent is a getter for the messageSent collector. +func (s *serverMetrics) getMessageSent() *prometheus.HistogramVec { + return s.messageSent +} + +// getMessageReceived is a getter for the messageReceived collector. +func (s *serverMetrics) getMessageReceived() *prometheus.HistogramVec { + return s.messageReceived +} diff --git a/lib/proxy/stats.go b/lib/proxy/stats.go new file mode 100644 index 0000000000000..0d37b741718e9 --- /dev/null +++ b/lib/proxy/stats.go @@ -0,0 +1,129 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "io" + "strings" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" +) + +type ctxKey struct{} + +type ( + serviceKey ctxKey + methodKey ctxKey + remoteAddrKey ctxKey + localAddrKey ctxKey +) + +// StatsHandler is for gRPC stats +type statsHandler struct { + reporter *reporter +} + +func newStatsHandler(r *reporter) stats.Handler { + return &statsHandler{ + reporter: r, + } +} + +// TagConn implements per-Connection context tagging. +func (s *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + ctx = context.WithValue(ctx, remoteAddrKey{}, info.RemoteAddr.String()) + ctx = context.WithValue(ctx, localAddrKey{}, info.LocalAddr.String()) + return ctx +} + +// HandleRPC implements per-Connection stats reporting. +func (s *statsHandler) HandleConn(ctx context.Context, connStats stats.ConnStats) { + // client connection stats are monitored by the monitor() function in client.go + if connStats.IsClient() { + return + } + + remoteAddr, _ := ctx.Value(remoteAddrKey{}).(string) + localAddr, _ := ctx.Value(localAddrKey{}).(string) + + switch connStats.(type) { + case *stats.ConnBegin: + s.reporter.incConnection(localAddr, remoteAddr, "SERVER_CONN") + case *stats.ConnEnd: + s.reporter.decConnection(localAddr, remoteAddr, "SERVER_CONN") + } +} + +// TagRPC implements per-RPC context tagging. +func (s *statsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + service, method := split(info.FullMethodName) + ctx = context.WithValue(ctx, serviceKey{}, service) + ctx = context.WithValue(ctx, methodKey{}, method) + return ctx +} + +// HandleRPC implements per-RPC stats reporting. +func (s *statsHandler) HandleRPC(ctx context.Context, rpcStats stats.RPCStats) { + service, _ := ctx.Value(serviceKey{}).(string) + method, _ := ctx.Value(methodKey{}).(string) + + switch rs := rpcStats.(type) { + case *stats.InPayload: + s.reporter.measureMessageReceived(service, method, float64(rs.WireLength)) + case *stats.OutPayload: + s.reporter.measureMessageSent(service, method, float64(rs.WireLength)) + case *stats.Begin: + s.reporter.incRPC(service, method) + case *stats.End: + code := codes.OK.String() + if isError(rs.Error) { + code = status.Code(rs.Error).String() + } + s.reporter.decRPC(service, method) + s.reporter.countRPC(service, method, code) + s.reporter.timeRPC(service, method, code, rs.EndTime.Sub(rs.BeginTime)) + } +} + +// split splits a grpc request path into service and method strings +// request format /%s/%s +func split(request string) (string, string) { + if i := strings.LastIndex(request, "/"); i >= 0 { + return request[1:i], request[i+1:] + } + return "unknown", "unknown" +} + +// isError returns false if the supplied error +// - is nil +// - has a codes.OK code +// - is io.EOF +func isError(err error) bool { + if err == nil { + return false + } + + grpcErr := status.Convert(err) + code := grpcErr.Code() + if code == codes.OK { + return false + } + + eof := status.Convert(io.EOF) + return code != eof.Code() || grpcErr.Message() != eof.Message() +} diff --git a/lib/service/service.go b/lib/service/service.go index 7f09bdf1a2c91..4de5eae16d302 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3235,11 +3235,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if alpnRouter != nil { grpcServer = grpc.NewServer( grpc.ChainUnaryInterceptor( - utils.ErrorConvertUnaryInterceptor, + utils.GRPCServerUnaryErrorInterceptor , proxyLimiter.UnaryServerInterceptor(), ), grpc.ChainStreamInterceptor( - utils.ErrorConvertStreamInterceptor, + utils.GRPCServerStreamErrorInterceptor , proxyLimiter.StreamServerInterceptor, ), ) diff --git a/lib/utils/grpc.go b/lib/utils/grpc.go index 15b67b125f28e..941911e66c70b 100644 --- a/lib/utils/grpc.go +++ b/lib/utils/grpc.go @@ -20,16 +20,72 @@ import ( "context" "github.com/gravitational/trace/trail" + "google.golang.org/grpc" ) -func ErrorConvertUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { +// grpcServerStreamWrapper wraps around the embedded grpc.ServerStream +// and intercepts the RecvMsg and SendMsg method calls converting errors to the +// appropriate grpc status error. +type grpcServerStreamWrapper struct { + grpc.ServerStream +} + +// SendMsg wraps around ServerStream.SendMsg and adds metrics reporting +func (s *grpcServerStreamWrapper) SendMsg(m interface{}) error { + return trail.FromGRPC(s.ServerStream.SendMsg(m)) +} + +// RecvMsg wraps around ServerStream.RecvMsg and adds metrics reporting +func (s *grpcServerStreamWrapper) RecvMsg(m interface{}) error { + return trail.FromGRPC(s.ServerStream.RecvMsg(m)) +} + +// grpcClientStreamWrapper wraps around the embedded grpc.ClientStream +// and intercepts the RecvMsg and SendMsg method calls converting errors to the +// appropriate grpc status error. +type grpcClientStreamWrapper struct { + grpc.ClientStream +} + +// SendMsg wraps around ClientStream.SendMsg +func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error { + return trail.FromGRPC(s.ClientStream.SendMsg(m)) +} + +// RecvMsg wraps around ClientStream.RecvMsg +func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error { + return trail.FromGRPC(s.ClientStream.RecvMsg(m)) +} + +// GRPCServerUnaryErrorInterceptor is a GPRC unary server interceptor that +// handles converting errors to the appropriate grpc status error. +func GRPCServerUnaryErrorInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { resp, err := handler(ctx, req) return resp, trail.ToGRPC(err) } -func ErrorConvertStreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - return trail.ToGRPC(handler(srv, ss)) +// GRPCClientUnaryErrorInterceptor is a GPRC unary client interceptor that +// handles converting errors to the appropriate grpc status error. +func GRPCClientUnaryErrorInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return trail.FromGRPC(invoker(ctx, method, req, reply, cc, opts...)) +} + +// GRPCServerStreamErrorInterceptor is a GPRC server stream interceptor that +// handles converting errors to the appropriate grpc status error. +func GRPCServerStreamErrorInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + serverWrapper := &grpcServerStreamWrapper{ss} + return trail.ToGRPC(handler(srv, serverWrapper)) +} + +// GRPCClientStreamErrorInterceptor is GPRC client stream interceptor that +// handles converting errors to the appropriate grpc status error. +func GRPCClientStreamErrorInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + s, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, trail.ToGRPC(err) + } + return &grpcClientStreamWrapper{s}, nil } // ChainUnaryServerInterceptors takes 1 or more grpc.UnaryServerInterceptors and diff --git a/lib/utils/grpc_test.go b/lib/utils/grpc_test.go index 6027dc4d2b810..79bdb7d99e379 100644 --- a/lib/utils/grpc_test.go +++ b/lib/utils/grpc_test.go @@ -19,10 +19,14 @@ package utils import ( "context" "fmt" + "net" "testing" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "google.golang.org/grpc" + + pb "google.golang.org/grpc/examples/features/proto/echo" ) func TestChainUnaryServerInterceptors(t *testing.T) { @@ -66,3 +70,60 @@ func TestChainStreamServerInterceptors(t *testing.T) { err := chainedInterceptor(nil, nil, nil, handler) require.Equal(t, "1 2 3 4 handler", err.Error()) } + +// service is used to implement EchoServer +type service struct { + pb.UnimplementedEchoServer +} + +func (s *service) UnaryEcho(ctx context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) { + return nil, trace.NotFound("not found") +} + +func (s *service) BidirectionalStreamingEcho(stream pb.Echo_BidirectionalStreamingEchoServer) error { + return trace.AlreadyExists("already exists") +} + +// TestGRPCErrorWrapping tests the error wrapping capability of the client +// and server unary and stream interceptors +func TestGRPCErrorWrapping(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + server := grpc.NewServer( + grpc.ChainUnaryInterceptor(GRPCServerUnaryErrorInterceptor), + grpc.ChainStreamInterceptor(GRPCServerStreamErrorInterceptor), + ) + pb.RegisterEchoServer(server, &service{}) + go func() { + server.Serve(listener) + }() + defer server.Stop() + + conn, err := grpc.Dial( + listener.Addr().String(), + grpc.WithInsecure(), + grpc.WithChainUnaryInterceptor(GRPCClientUnaryErrorInterceptor), + grpc.WithChainStreamInterceptor(GRPCClientStreamErrorInterceptor), + ) + require.NoError(t, err) + defer conn.Close() + + // test unary interceptor + client := pb.NewEchoClient(conn) + resp, err := client.UnaryEcho(context.Background(), &pb.EchoRequest{Message: "Hi!"}) + require.Nil(t, resp) + require.True(t, trace.IsNotFound(err)) + require.Equal(t, err.Error(), "not found") + + // test stream interceptor + stream, err := client.BidirectionalStreamingEcho(context.Background()) + require.NoError(t, err) + + err = stream.Send(&pb.EchoRequest{Message: "Hi!"}) + require.NoError(t, err) + + resp, err = stream.Recv() + require.True(t, trace.IsAlreadyExists(err)) + require.Equal(t, err.Error(), "already exists") +} diff --git a/lib/utils/prometheus.go b/lib/utils/prometheus.go index 13f26610555dd..8ba1d1246f03c 100644 --- a/lib/utils/prometheus.go +++ b/lib/utils/prometheus.go @@ -31,15 +31,16 @@ import ( // - returns an error if a collector does not fulfill the consistency and // uniqueness criteria func RegisterPrometheusCollectors(collectors ...prometheus.Collector) error { + var errs []error for _, c := range collectors { if err := prometheus.Register(c); err != nil { if _, ok := err.(prometheus.AlreadyRegisteredError); ok { continue } - return trace.Wrap(err) + errs = append(errs, err) } } - return nil + return trace.NewAggregate(errs...) } // BuildCollector provides a Collector that contains build information gauge