Skip to content

Commit

Permalink
Merge pull request #365 from gravitational/ev/tk
Browse files Browse the repository at this point in the history
Refactoring of reverse tunnel agent
  • Loading branch information
klizhentas committed Apr 17, 2016
2 parents a612f20 + 8cb2716 commit 48d35b2
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 170 deletions.
1 change: 0 additions & 1 deletion integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func SetTestTimeouts(ms int) {
ms = 10
}
testVal := time.Duration(time.Millisecond * time.Duration(ms))
defaults.ReverseTunnelsRefreshPeriod = testVal
defaults.ReverseTunnelAgentHeartbeatPeriod = testVal
defaults.ServerHeartbeatTTL = testVal
defaults.SessionRefreshPeriod = testVal
Expand Down
10 changes: 1 addition & 9 deletions lib/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,8 @@ const (
)

var (
// ReverseTunnelsRefreshPeriod is a period for agents to refresh their
// state of the reverse tunnels (this will be removed once we roll
// events streams)
ReverseTunnelsRefreshPeriod = 3 * time.Second

// ReverseTunnelAgentReconnectPeriod is the period between agent reconnect attempts
ReverseTunnelAgentReconnectPeriod = 3 * time.Second

// ReverseTunnelAgentHeartbeatPeriod is the period between agent heartbeat messages
ReverseTunnelAgentHeartbeatPeriod = 3 * time.Second
ReverseTunnelAgentHeartbeatPeriod = 5 * time.Second

// ServerHeartbeatTTL is a period between heartbeats
// Median sleep time between node pings is this value / 2 + random
Expand Down
220 changes: 64 additions & 156 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,45 +104,10 @@ func (a *Agent) Close() error {

// Start starts agent that attempts to connect to remote server part
func (a *Agent) Start() error {
if err := a.reconnect(); err != nil {
return trace.Wrap(err)
}
go a.handleDisconnect()
return nil
}

func (a *Agent) handleDisconnect() {
a.log.Infof("handle disconnects")
for {
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, returning")
return
case <-a.disconnectC:
a.log.Infof("detected disconnect, reconnecting")
a.reconnect()
}
}
}

func (a *Agent) reconnect() error {
ticker := time.NewTicker(defaults.ReverseTunnelAgentReconnectPeriod)
defer ticker.Stop()
var err error
i := 0
for {
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, return")
return nil
case <-ticker.C:
if err = a.connect(); err != nil {
i++
continue
}
return nil
}
}
conn, err := a.connect()
// start the heartbeat even if an error happened; it will reconnect
go a.runHeartbeat(conn)
return err
}

// Wait waits until all outstanding operations are completed
Expand Down Expand Up @@ -180,79 +145,21 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub
"no matching keys found when checking server's host signature")
}

func (a *Agent) connect() error {
func (a *Agent) connect() (conn *ssh.Client, err error) {
if a.addr.IsEmpty() {
err := trace.BadParameter("reverse tunnel cannot be created: target address is empty")
a.log.Error(err)
return err
return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty")
}

var c *ssh.Client
var err error
for _, authMethod := range a.authMethods {
c, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{
conn, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{
User: a.domainName,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: a.hostKeyCallback,
})
if c != nil {
if conn != nil {
break
}
}
if c == nil {
a.log.Errorf("connect err: %v", err)
return trace.Wrap(err)
}

go a.startHeartbeat(c)
go a.handleAccessPoint(c.HandleChannelOpen(chanAccessPoint))
go a.handleTransport(c.HandleChannelOpen(chanTransport))

return nil
}

func (a *Agent) handleAccessPoint(newC <-chan ssh.NewChannel) {
for {
var nch ssh.NewChannel
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, return")
return
case nch = <-newC:
if nch == nil {
a.log.Infof("connection closed, return")
return
}
}
a.log.Infof("got access point request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
}
go a.proxyAccessPoint(ch, req)
}
}

func (a *Agent) handleTransport(newC <-chan ssh.NewChannel) {
for {
var nch ssh.NewChannel
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, return")
return
case nch = <-newC:
if nch == nil {
a.log.Infof("connection closed, return")
return
}
}
a.log.Infof("got transport request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
}
go a.proxyTransport(ch, req)
}
return conn, err
}

func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) {
Expand Down Expand Up @@ -328,75 +235,76 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
wg.Wait()
}

func (a *Agent) startHeartbeat(conn ssh.Conn) {
defer func() {
a.disconnectC <- true
a.log.Infof("sent disconnect message")
}()

defer conn.Close()

hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil)
if err != nil {
a.log.Errorf("failed to open channel: %v", err)
return
}

closeC := make(chan bool)
errC := make(chan error, 2)

ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
defer ticker.Stop()

go func() {
_, err := hb.SendRequest("ping", false, nil)
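// runHeartbeat is a blocking function which runs in a loop sending heartbeats
// over the given SSH connection and accepting the access point and transport
// channel requests that arrive on it. If the loop ends with an error (or conn
// is nil), it sleeps for one heartbeat period and calls Start again to reconnect.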
func (a *Agent) runHeartbeat(conn *ssh.Client) {
heartbeatLoop := func() error {
if conn == nil {
return trace.Errorf("heartbeat cannot ping: need to reconnect")
}
defer conn.Close()
hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil)
if err != nil {
a.log.Errorf("failed to send heartbeat: %v", err)
errC <- err
return
return trace.Wrap(err)
}
newAccesspointC := conn.HandleChannelOpen(chanAccessPoint)
newTransportC := conn.HandleChannelOpen(chanTransport)

// send first ping right away, then start a ping timer:
hb.SendRequest("ping", false, nil)
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
defer ticker.Stop()

for {
select {
// need to exit:
case <-a.broadcastClose.C:
a.log.Infof("agent is closing")
return
case <-closeC:
a.log.Infof("asked to exit")
return
return nil
// time to ping:
case <-ticker.C:
_, err := hb.SendRequest("ping", false, nil)
if err != nil {
a.log.Errorf("failed to send heartbeat: %v", err)
errC <- err
return
return trace.Wrap(err)
}
}
}
}()

go func() {
for {
select {
case <-a.broadcastClose.C:
a.log.Infof("agent is closing")
return
case <-closeC:
log.Infof("asked to exit")
return
// ssh channel closed:
case req := <-reqC:
if req == nil {
errC <- trace.Errorf("heartbeat: connection closed")
return
return trace.Errorf("heartbeat: connection closed")
}
// new access point request:
case nch := <-newAccesspointC:
if nch == nil {
continue
}
a.log.Infof("got out of band request: %v", req)
a.log.Infof("reverseTunnel.Agent: access point request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
continue
}
go a.proxyAccessPoint(ch, req)
// new transport request:
case nch := <-newTransportC:
if nch == nil {
continue
}
a.log.Infof("reverseTunnel.Agent: transport request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
continue
}
go a.proxyTransport(ch, req)
}
}
}()
}

a.log.Infof("got error: %v", <-errC)
(&sync.Once{}).Do(func() {
close(closeC)
})
err := heartbeatLoop()
if err != nil || conn == nil {
time.Sleep(defaults.ReverseTunnelAgentHeartbeatPeriod)
a.Start()
}
}

const (
Expand Down
3 changes: 1 addition & 2 deletions lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (m *AgentPool) FetchAndSyncAgents() error {
}

func (m *AgentPool) pollAndSyncAgents() {
ticker := time.NewTicker(defaults.ReverseTunnelsRefreshPeriod)
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
defer ticker.Stop()
m.FetchAndSyncAgents()
for {
Expand Down Expand Up @@ -137,7 +137,6 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error {
return trace.Wrap(err)
}
agentsToAdd, agentsToRemove := diffTunnels(m.agents, keys)
//m.Infof("agents to add: %v agents to remove: %v", agentsToAdd, agentsToRemove)
for _, key := range agentsToRemove {
m.Infof("removing %v", &key)
agent := m.agents[key]
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,8 @@ func (s *tunnelSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
return
}
s.setLastActive(time.Now())
case <-time.After(2 * defaults.ReverseTunnelAgentHeartbeatPeriod):
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 2 heartbeats"))
case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod):
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats"))
conn.sshConn.Close()
}
}
Expand Down

0 comments on commit 48d35b2

Please sign in to comment.