Skip to content

Commit

Permalink
Keep multiple per-node remoteConns in localSite (#11074) (#11186)
Browse files Browse the repository at this point in the history
  • Loading branch information
espadolini authored Mar 18, 2022
1 parent ef8d60f commit 084455c
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 34 deletions.
10 changes: 0 additions & 10 deletions lib/reversetunnel/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package reversetunnel

import (
"context"
"fmt"
"net"
"sync"
Expand Down Expand Up @@ -69,11 +68,6 @@ type remoteConn struct {
// Used to make sure calling Close on the connection multiple times is safe.
closed int32

// closeContext and closeCancel are used to signal to any waiting goroutines
// that the remoteConn is now closed and to release any resources.
closeContext context.Context
closeCancel context.CancelFunc

// clock is used to control time in tests.
clock clockwork.Clock

Expand Down Expand Up @@ -120,8 +114,6 @@ func newRemoteConn(cfg *connConfig) *remoteConn {
newProxiesC: make(chan []types.Server, 100),
}

c.closeContext, c.closeCancel = context.WithCancel(context.Background())

return c
}

Expand All @@ -130,8 +122,6 @@ func (c *remoteConn) String() string {
}

func (c *remoteConn) Close() error {
defer c.closeCancel()

// If the connection has already been closed, return right away.
if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
return nil
Expand Down
62 changes: 38 additions & 24 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
accessPoint: accessPoint,
certificateCache: certificateCache,
domainName: domainName,
remoteConns: make(map[connKey]*remoteConn),
remoteConns: make(map[connKey][]*remoteConn),
clock: srv.Clock,
log: log.WithFields(log.Fields{
trace.Component: teleport.ComponentReverseTunnelServer,
Expand All @@ -89,8 +89,6 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
//
// it implements RemoteSite interface
type localSite struct {
sync.Mutex

log log.FieldLogger
domainName string
srv *server
Expand All @@ -104,8 +102,11 @@ type localSite struct {
// certificateCache caches host certificates for the forwarding server.
certificateCache *certificateCache

// remoteConns maps UUID and connection type to an remote connection.
remoteConns map[connKey]*remoteConn
// remoteConns maps UUID and connection type to remote connections, oldest to newest.
remoteConns map[connKey][]*remoteConn

// remoteConnsMtx protects remoteConns.
remoteConnsMtx sync.Mutex

// clock is used to control time in tests.
clock clockwork.Clock
Expand All @@ -117,8 +118,8 @@ type localSite struct {

// GetTunnelsCount always the number of tunnel connections to this cluster.
func (s *localSite) GetTunnelsCount() int {
s.Lock()
defer s.Unlock()
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

return len(s.remoteConns)
}
Expand Down Expand Up @@ -346,8 +347,8 @@ with the cluster.`, params.ConnType, dreq.Address, tunnelErr)
}

func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) {
s.Lock()
defer s.Unlock()
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

rconn := newRemoteConn(&connConfig{
conn: conn,
Expand All @@ -363,7 +364,7 @@ func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.C
uuid: nodeID,
connType: connType,
}
s.remoteConns[key] = rconn
s.remoteConns[key] = append(s.remoteConns[key], rconn)

return rconn, nil
}
Expand All @@ -372,10 +373,13 @@ func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.C
// list so that remote connection can notify the remote agent
// about the list update
func (s *localSite) fanOutProxies(proxies []types.Server) {
s.Lock()
defer s.Unlock()
for _, conn := range s.remoteConns {
conn.updateProxies(proxies)
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

for _, conns := range s.remoteConns {
for _, conn := range conns {
conn.updateProxies(proxies)
}
}
}

Expand Down Expand Up @@ -442,30 +446,40 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
}

func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) {
s.Lock()
defer s.Unlock()
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()

// Loop over all connections and remove and invalid connections from the
// connection map.
for key := range s.remoteConns {
if s.remoteConns[key].isInvalid() {
for key, conns := range s.remoteConns {
validConns := conns[:0]
for _, conn := range conns {
if !conn.isInvalid() {
validConns = append(validConns, conn)
}
}
if len(validConns) == 0 {
delete(s.remoteConns, key)
} else {
s.remoteConns[key] = validConns
}
}

key := connKey{
uuid: dreq.ServerID,
connType: dreq.ConnType,
}
rconn, ok := s.remoteConns[key]
if !ok {
if len(s.remoteConns[key]) == 0 {
return nil, trace.NotFound("no %v reverse tunnel for %v found", dreq.ConnType, dreq.ServerID)
}
if !rconn.isReady() {
return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID)
}

return rconn, nil
conns := s.remoteConns[key]
for i := len(conns) - 1; i >= 0; i-- {
if conns[i].isReady() {
return conns[i], nil
}
}
return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID)
}

func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq) (net.Conn, error) {
Expand Down
99 changes: 99 additions & 0 deletions lib/reversetunnel/localsite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package reversetunnel

import (
"context"
"net"
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

func TestLocalSiteOverlap(t *testing.T) {
t.Parallel()

// to stop (*localSite).periodicFunctions()
ctx, ctxCancel := context.WithCancel(context.Background())
ctxCancel()

srv := &server{
ctx: ctx,
newAccessPoint: auth.NoCache,
}

site, err := newlocalSite(srv, "clustername", &mockLocalSiteClient{})
require.NoError(t, err)

nodeID := uuid.NewString()
connType := types.NodeTunnel
dreq := &sshutils.DialReq{
ServerID: nodeID,
ConnType: connType,
}

conn1, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil)
require.NoError(t, err)

conn2, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil)
require.NoError(t, err)

c, err := site.getRemoteConn(dreq)
require.True(t, trace.IsNotFound(err))
require.Nil(t, c)

conn1.setLastHeartbeat(time.Now())
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn1, c)

conn2.setLastHeartbeat(time.Now())
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn2, c)

conn2.markInvalid(nil)
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn1, c)

conn1.markInvalid(nil)
c, err = site.getRemoteConn(dreq)
require.True(t, trace.IsNotFound(err))
require.Nil(t, c)
}

type mockLocalSiteClient struct {
auth.Client
}

// called by (*localSite).sshTunnelStats() as part of (*localSite).periodicFunctions()
func (mockLocalSiteClient) GetNodes(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.Server, error) {
return nil, nil
}

type mockRemoteConnConn struct {
net.Conn
}

// called for logging by (*remoteConn).markInvalid()
func (mockRemoteConnConn) RemoteAddr() net.Addr { return nil }

0 comments on commit 084455c

Please sign in to comment.