Skip to content

Commit

Permalink
Keep multiple per-node remoteConns in localSite
Browse files Browse the repository at this point in the history
  • Loading branch information
espadolini committed Mar 15, 2022
1 parent d83886e commit 4d464c3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
10 changes: 0 additions & 10 deletions lib/reversetunnel/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package reversetunnel

import (
"context"
"fmt"
"net"
"sync"
Expand Down Expand Up @@ -67,11 +66,6 @@ type remoteConn struct {
// Used to make sure calling Close on the connection multiple times is safe.
closed int32

// closeContext and closeCancel are used to signal to any waiting goroutines
// that the remoteConn is now closed and to release any resources.
closeContext context.Context
closeCancel context.CancelFunc

// clock is used to control time in tests.
clock clockwork.Clock

Expand Down Expand Up @@ -115,8 +109,6 @@ func newRemoteConn(cfg *connConfig) *remoteConn {
newProxiesC: make(chan []types.Server, 100),
}

c.closeContext, c.closeCancel = context.WithCancel(context.Background())

return c
}

Expand All @@ -125,8 +117,6 @@ func (c *remoteConn) String() string {
}

func (c *remoteConn) Close() error {
defer c.closeCancel()

// If the connection has already been closed, return right away.
if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
return nil
Expand Down
62 changes: 38 additions & 24 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
accessPoint: accessPoint,
certificateCache: certificateCache,
domainName: domainName,
remoteConns: make(map[connKey]*remoteConn),
remoteConns: make(map[connKey][]*remoteConn),
clock: srv.Clock,
log: log.WithFields(log.Fields{
trace.Component: teleport.ComponentReverseTunnelServer,
Expand All @@ -89,8 +89,6 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
//
// it implements RemoteSite interface
type localSite struct {
sync.Mutex

log log.FieldLogger
domainName string
srv *server
Expand All @@ -104,8 +102,11 @@ type localSite struct {
// certificateCache caches host certificates for the forwarding server.
certificateCache *certificateCache

// remoteConns maps UUID and connection type to an remote connection.
remoteConns map[connKey]*remoteConn
// remoteConns maps UUID and connection type to remote connections, oldest to newest.
remoteConns map[connKey][]*remoteConn

// remoteConnsMtx protects remoteConns.
remoteConnsMtx sync.Mutex

// clock is used to control time in tests.
clock clockwork.Clock
Expand All @@ -117,8 +118,8 @@ type localSite struct {

// GetTunnelsCount always the number of tunnel connections to this cluster.
func (s *localSite) GetTunnelsCount() int {
s.Lock()
defer s.Unlock()
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

return len(s.remoteConns)
}
Expand Down Expand Up @@ -349,8 +350,8 @@ with the cluster.`
}

func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) {
s.Lock()
defer s.Unlock()
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

rconn := newRemoteConn(&connConfig{
conn: conn,
Expand All @@ -365,7 +366,7 @@ func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.C
uuid: nodeID,
connType: connType,
}
s.remoteConns[key] = rconn
s.remoteConns[key] = append(s.remoteConns[key], rconn)

return rconn, nil
}
Expand All @@ -374,10 +375,13 @@ func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.C
// list so that remote connection can notify the remote agent
// about the list update
func (s *localSite) fanOutProxies(proxies []types.Server) {
s.Lock()
defer s.Unlock()
for _, conn := range s.remoteConns {
conn.updateProxies(proxies)
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

for _, conns := range s.remoteConns {
for _, conn := range conns {
conn.updateProxies(proxies)
}
}
}

Expand Down Expand Up @@ -446,30 +450,40 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
}

func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) {
s.Lock()
defer s.Unlock()
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

// Loop over all connections and remove and invalid connections from the
// connection map.
for key := range s.remoteConns {
if s.remoteConns[key].isInvalid() {
for key, conns := range s.remoteConns {
validConns := conns[:0]
for _, conn := range conns {
if !conn.isInvalid() {
validConns = append(validConns, conn)
}
}
if len(validConns) == 0 {
delete(s.remoteConns, key)
} else {
s.remoteConns[key] = validConns
}
}

key := connKey{
uuid: dreq.ServerID,
connType: dreq.ConnType,
}
rconn, ok := s.remoteConns[key]
if !ok {
if len(s.remoteConns[key]) == 0 {
return nil, trace.NotFound("no %v reverse tunnel for %v found", dreq.ConnType, dreq.ServerID)
}
if !rconn.isReady() {
return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID)
}

return rconn, nil
conns := s.remoteConns[key]
for i := len(conns) - 1; i >= 0; i-- {
if conns[i].isReady() {
return conns[i], nil
}
}
return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID)
}

func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq) (net.Conn, error) {
Expand Down

0 comments on commit 4d464c3

Please sign in to comment.