From 084455ce260618cf18fb7b00fbee3e19890e8ec1 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 18 Mar 2022 01:21:33 +0100 Subject: [PATCH] Keep multiple per-node remoteConns in localSite (#11074) (#11186) --- lib/reversetunnel/conn.go | 10 --- lib/reversetunnel/localsite.go | 62 +++++++++++------- lib/reversetunnel/localsite_test.go | 99 +++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 34 deletions(-) create mode 100644 lib/reversetunnel/localsite_test.go diff --git a/lib/reversetunnel/conn.go b/lib/reversetunnel/conn.go index 0da13e92edf87..834e159b176a5 100644 --- a/lib/reversetunnel/conn.go +++ b/lib/reversetunnel/conn.go @@ -17,7 +17,6 @@ limitations under the License. package reversetunnel import ( - "context" "fmt" "net" "sync" @@ -69,11 +68,6 @@ type remoteConn struct { // Used to make sure calling Close on the connection multiple times is safe. closed int32 - // closeContext and closeCancel are used to signal to any waiting goroutines - // that the remoteConn is now closed and to release any resources. - closeContext context.Context - closeCancel context.CancelFunc - // clock is used to control time in tests. clock clockwork.Clock @@ -120,8 +114,6 @@ func newRemoteConn(cfg *connConfig) *remoteConn { newProxiesC: make(chan []types.Server, 100), } - c.closeContext, c.closeCancel = context.WithCancel(context.Background()) - return c } @@ -130,8 +122,6 @@ func (c *remoteConn) String() string { } func (c *remoteConn) Close() error { - defer c.closeCancel() - // If the connection has already been closed, return right away. if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) { return nil diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 6599e5af7fb60..9f85e03bf9078 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -67,7 +67,7 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi accessPoint: accessPoint, certificateCache: certificateCache, domainName: domainName, - remoteConns: make(map[connKey]*remoteConn), + remoteConns: make(map[connKey][]*remoteConn), clock: srv.Clock, log: log.WithFields(log.Fields{ trace.Component: teleport.ComponentReverseTunnelServer, @@ -89,8 +89,6 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi // // it implements RemoteSite interface type localSite struct { - sync.Mutex - log log.FieldLogger domainName string srv *server @@ -104,8 +102,11 @@ type localSite struct { // certificateCache caches host certificates for the forwarding server. certificateCache *certificateCache - // remoteConns maps UUID and connection type to an remote connection. - remoteConns map[connKey]*remoteConn + // remoteConns maps UUID and connection type to remote connections, oldest to newest. + remoteConns map[connKey][]*remoteConn + + // remoteConnsMtx protects remoteConns. + remoteConnsMtx sync.Mutex // clock is used to control time in tests. clock clockwork.Clock @@ -117,8 +118,8 @@ type localSite struct { // GetTunnelsCount always the number of tunnel connections to this cluster. func (s *localSite) GetTunnelsCount() int { - s.Lock() - defer s.Unlock() + s.remoteConnsMtx.Lock() + defer s.remoteConnsMtx.Unlock() return len(s.remoteConns) } @@ -346,8 +347,8 @@ with the cluster.`, params.ConnType, dreq.Address, tunnelErr) } func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) { - s.Lock() - defer s.Unlock() + s.remoteConnsMtx.Lock() + defer s.remoteConnsMtx.Unlock() rconn := newRemoteConn(&connConfig{ conn: conn, @@ -363,7 +364,7 @@ func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.C uuid: nodeID, connType: connType, } - s.remoteConns[key] = rconn + s.remoteConns[key] = append(s.remoteConns[key], rconn) return rconn, nil } @@ -372,10 +373,13 @@ func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.C // list so that remote connection can notify the remote agent // about the list update func (s *localSite) fanOutProxies(proxies []types.Server) { - s.Lock() - defer s.Unlock() - for _, conn := range s.remoteConns { - conn.updateProxies(proxies) + s.remoteConnsMtx.Lock() + defer s.remoteConnsMtx.Unlock() + + for _, conns := range s.remoteConns { + for _, conn := range conns { + conn.updateProxies(proxies) + } } } @@ -442,14 +446,22 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch } func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) { - s.Lock() - defer s.Unlock() + s.remoteConnsMtx.Lock() + defer s.remoteConnsMtx.Unlock() // Loop over all connections and remove and invalid connections from the // connection map. - for key := range s.remoteConns { - if s.remoteConns[key].isInvalid() { + for key, conns := range s.remoteConns { + validConns := conns[:0] + for _, conn := range conns { + if !conn.isInvalid() { + validConns = append(validConns, conn) + } + } + if len(validConns) == 0 { delete(s.remoteConns, key) + } else { + s.remoteConns[key] = validConns } } @@ -457,15 +469,17 @@ func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) { uuid: dreq.ServerID, connType: dreq.ConnType, } - rconn, ok := s.remoteConns[key] - if !ok { + if len(s.remoteConns[key]) == 0 { return nil, trace.NotFound("no %v reverse tunnel for %v found", dreq.ConnType, dreq.ServerID) } - if !rconn.isReady() { - return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID) - } - return rconn, nil + conns := s.remoteConns[key] + for i := len(conns) - 1; i >= 0; i-- { + if conns[i].isReady() { + return conns[i], nil + } + } + return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID) } func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq) (net.Conn, error) { diff --git a/lib/reversetunnel/localsite_test.go b/lib/reversetunnel/localsite_test.go new file mode 100644 index 0000000000000..27f61faefd422 --- /dev/null +++ b/lib/reversetunnel/localsite_test.go @@ -0,0 +1,99 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reversetunnel + +import ( + "context" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestLocalSiteOverlap(t *testing.T) { + t.Parallel() + + // to stop (*localSite).periodicFunctions() + ctx, ctxCancel := context.WithCancel(context.Background()) + ctxCancel() + + srv := &server{ + ctx: ctx, + newAccessPoint: auth.NoCache, + } + + site, err := newlocalSite(srv, "clustername", &mockLocalSiteClient{}) + require.NoError(t, err) + + nodeID := uuid.NewString() + connType := types.NodeTunnel + dreq := &sshutils.DialReq{ + ServerID: nodeID, + ConnType: connType, + } + + conn1, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil) + require.NoError(t, err) + + conn2, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil) + require.NoError(t, err) + + c, err := site.getRemoteConn(dreq) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, c) + + conn1.setLastHeartbeat(time.Now()) + c, err = site.getRemoteConn(dreq) + require.NoError(t, err) + require.Equal(t, conn1, c) + + conn2.setLastHeartbeat(time.Now()) + c, err = site.getRemoteConn(dreq) + require.NoError(t, err) + require.Equal(t, conn2, c) + + conn2.markInvalid(nil) + c, err = site.getRemoteConn(dreq) + require.NoError(t, err) + require.Equal(t, conn1, c) + + conn1.markInvalid(nil) + c, err = site.getRemoteConn(dreq) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, c) +} + +type mockLocalSiteClient struct { + auth.Client +} + +// called by (*localSite).sshTunnelStats() as part of (*localSite).periodicFunctions() +func (mockLocalSiteClient) GetNodes(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.Server, error) { + return nil, nil +} + +type mockRemoteConnConn struct { + net.Conn +} + +// called for logging by (*remoteConn).markInvalid() +func (mockRemoteConnConn) RemoteAddr() net.Addr { return nil }