diff --git a/integration/helpers.go b/integration/helpers.go index 54df7b2ede726..a432c76fdc9e9 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -29,7 +29,6 @@ func SetTestTimeouts(ms int) { ms = 10 } testVal := time.Duration(time.Millisecond * time.Duration(ms)) - defaults.ReverseTunnelsRefreshPeriod = testVal defaults.ReverseTunnelAgentHeartbeatPeriod = testVal defaults.ServerHeartbeatTTL = testVal defaults.SessionRefreshPeriod = testVal diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 75337cc3162ef..b9348bd7bc919 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -117,16 +117,8 @@ const ( ) var ( - // ReverseTunnelsRefreshPeriod is a period for agents to refresh their - // state of the reverse tunnels (this will be removed once we roll - // events streams) - ReverseTunnelsRefreshPeriod = 3 * time.Second - - // ReverseTunnelAgentReconnectPeriod is the period between agent reconnect attempts - ReverseTunnelAgentReconnectPeriod = 3 * time.Second - // ReverseTunnelAgentHeartbeatPeriod is the period between agent heartbeat messages - ReverseTunnelAgentHeartbeatPeriod = 3 * time.Second + ReverseTunnelAgentHeartbeatPeriod = 5 * time.Second // ServerHeartbeatTTL is a period between heartbeats // Median sleep time between node pings is this value / 2 + random diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 1c16d7bb5d378..cbd4f278d116c 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -104,45 +104,10 @@ func (a *Agent) Close() error { // Start starts agent that attempts to connect to remote server part func (a *Agent) Start() error { - if err := a.reconnect(); err != nil { - return trace.Wrap(err) - } - go a.handleDisconnect() - return nil -} - -func (a *Agent) handleDisconnect() { - a.log.Infof("handle disconnects") - for { - select { - case <-a.broadcastClose.C: - a.log.Infof("is closed, returning") - return - case <-a.disconnectC: - a.log.Infof("detected disconnect, reconnecting") - a.reconnect() - } - } -} - -func (a *Agent) reconnect() error { - ticker := time.NewTicker(defaults.ReverseTunnelAgentReconnectPeriod) - defer ticker.Stop() - var err error - i := 0 - for { - select { - case <-a.broadcastClose.C: - a.log.Infof("is closed, return") - return nil - case <-ticker.C: - if err = a.connect(); err != nil { - i++ - continue - } - return nil - } - } + conn, err := a.connect() + // start heartbeat even if error happend, it will reconnect + go a.runHeartbeat(conn) + return err } // Wait waits until all outstanding operations are completed @@ -180,79 +145,21 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub "no matching keys found when checking server's host signature") } -func (a *Agent) connect() error { +func (a *Agent) connect() (conn *ssh.Client, err error) { if a.addr.IsEmpty() { - err := trace.BadParameter("reverse tunnel cannot be created: target address is empty") - a.log.Error(err) - return err + return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty") } - - var c *ssh.Client - var err error for _, authMethod := range a.authMethods { - c, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{ + conn, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{ User: a.domainName, Auth: []ssh.AuthMethod{authMethod}, HostKeyCallback: a.hostKeyCallback, }) - if c != nil { + if conn != nil { break } } - if c == nil { - a.log.Errorf("connect err: %v", err) - return trace.Wrap(err) - } - - go a.startHeartbeat(c) - go a.handleAccessPoint(c.HandleChannelOpen(chanAccessPoint)) - go a.handleTransport(c.HandleChannelOpen(chanTransport)) - - return nil -} - -func (a *Agent) handleAccessPoint(newC <-chan ssh.NewChannel) { - for { - var nch ssh.NewChannel - select { - case <-a.broadcastClose.C: - a.log.Infof("is closed, return") - return - case nch = <-newC: - if nch == nil { - a.log.Infof("connection closed, return") - return - } - } - a.log.Infof("got access point request: %v", nch.ChannelType()) - ch, req, err := nch.Accept() - if err != nil { - a.log.Errorf("failed to accept request: %v", err) - } - go a.proxyAccessPoint(ch, req) - } -} - -func (a *Agent) handleTransport(newC <-chan ssh.NewChannel) { - for { - var nch ssh.NewChannel - select { - case <-a.broadcastClose.C: - a.log.Infof("is closed, return") - return - case nch = <-newC: - if nch == nil { - a.log.Infof("connection closed, return") - return - } - } - a.log.Infof("got transport request: %v", nch.ChannelType()) - ch, req, err := nch.Accept() - if err != nil { - a.log.Errorf("failed to accept request: %v", err) - } - go a.proxyTransport(ch, req) - } + return conn, err } func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) { @@ -328,75 +235,76 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { wg.Wait() } -func (a *Agent) startHeartbeat(conn ssh.Conn) { - defer func() { - a.disconnectC <- true - a.log.Infof("sent disconnect message") - }() - - defer conn.Close() - - hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) - if err != nil { - a.log.Errorf("failed to open channel: %v", err) - return - } - - closeC := make(chan bool) - errC := make(chan error, 2) - - ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) - defer ticker.Stop() - - go func() { - _, err := hb.SendRequest("ping", false, nil) +// runHeartbeat is a blocking function which runs in a loop sending heartbeats +// to the given SSH connection. +// +func (a *Agent) runHeartbeat(conn *ssh.Client) { + heartbeatLoop := func() error { + if conn == nil { + return trace.Errorf("heartbeat cannot ping: need to reconnect") + } + defer conn.Close() + hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) if err != nil { - a.log.Errorf("failed to send heartbeat: %v", err) - errC <- err - return + return trace.Wrap(err) } + newAccesspointC := conn.HandleChannelOpen(chanAccessPoint) + newTransportC := conn.HandleChannelOpen(chanTransport) + + // send first ping right away, then start a ping timer: + hb.SendRequest("ping", false, nil) + ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) + defer ticker.Stop() + for { select { + // need to exit: case <-a.broadcastClose.C: - a.log.Infof("agent is closing") - return - case <-closeC: - a.log.Infof("asked to exit") - return + return nil + // time to ping: case <-ticker.C: _, err := hb.SendRequest("ping", false, nil) if err != nil { - a.log.Errorf("failed to send heartbeat: %v", err) - errC <- err - return + return trace.Wrap(err) } - } - } - }() - - go func() { - for { - select { - case <-a.broadcastClose.C: - a.log.Infof("agent is closing") - return - case <-closeC: - log.Infof("asked to exit") - return + // ssh channel closed: case req := <-reqC: if req == nil { - errC <- trace.Errorf("heartbeat: connection closed") - return + return trace.Errorf("heartbeat: connection closed") + } + // new access point request: + case nch := <-newAccesspointC: + if nch == nil { + continue } - a.log.Infof("got out of band request: %v", req) + a.log.Infof("reverseTunnel.Agent: access point request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.log.Errorf("failed to accept request: %v", err) + continue + } + go a.proxyAccessPoint(ch, req) + // new transport request: + case nch := <-newTransportC: + if nch == nil { + continue + } + a.log.Infof("reverseTunnel.Agent: transport request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.log.Errorf("failed to accept request: %v", err) + continue + } + go a.proxyTransport(ch, req) } } - }() + } - a.log.Infof("got error: %v", <-errC) - (&sync.Once{}).Do(func() { - close(closeC) - }) + err := heartbeatLoop() + if err != nil || conn == nil { + time.Sleep(defaults.ReverseTunnelAgentHeartbeatPeriod) + a.Start() + } } const ( diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index f69c53bff6223..4aa7dd61a961f 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -105,7 +105,7 @@ func (m *AgentPool) FetchAndSyncAgents() error { } func (m *AgentPool) pollAndSyncAgents() { - ticker := time.NewTicker(defaults.ReverseTunnelsRefreshPeriod) + ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) defer ticker.Stop() m.FetchAndSyncAgents() for { @@ -137,7 +137,6 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { return trace.Wrap(err) } agentsToAdd, agentsToRemove := diffTunnels(m.agents, keys) - //m.Infof("agents to add: %v agents to remove: %v", agentsToAdd, agentsToRemove) for _, key := range agentsToRemove { m.Infof("removing %v", &key) agent := m.agents[key] diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 2aaa78703582b..0348bed4ecca2 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -595,8 +595,8 @@ func (s *tunnelSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch return } s.setLastActive(time.Now()) - case <-time.After(2 * defaults.ReverseTunnelAgentHeartbeatPeriod): - conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 2 heartbeats")) + case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod): + conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats")) conn.sshConn.Close() } }