Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of reverse tunnel agent #365

Merged
merged 1 commit into from
Apr 17, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func SetTestTimeouts(ms int) {
ms = 10
}
testVal := time.Duration(time.Millisecond * time.Duration(ms))
defaults.ReverseTunnelsRefreshPeriod = testVal
defaults.ReverseTunnelAgentHeartbeatPeriod = testVal
defaults.ServerHeartbeatTTL = testVal
defaults.SessionRefreshPeriod = testVal
Expand Down
10 changes: 1 addition & 9 deletions lib/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,8 @@ const (
)

var (
// ReverseTunnelsRefreshPeriod is a period for agents to refresh their
// state of the reverse tunnels (this will be removed once we roll
// event streams)
ReverseTunnelsRefreshPeriod = 3 * time.Second

// ReverseTunnelAgentReconnectPeriod is the period between agent reconnect attempts
ReverseTunnelAgentReconnectPeriod = 3 * time.Second

// ReverseTunnelAgentHeartbeatPeriod is the period between agent heartbeat messages
ReverseTunnelAgentHeartbeatPeriod = 3 * time.Second
ReverseTunnelAgentHeartbeatPeriod = 5 * time.Second

// ServerHeartbeatTTL is a period between heartbeats
// Median sleep time between node pings is this value / 2 + random
Expand Down
220 changes: 64 additions & 156 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,45 +104,10 @@ func (a *Agent) Close() error {

// Start starts agent that attempts to connect to remote server part
func (a *Agent) Start() error {
if err := a.reconnect(); err != nil {
return trace.Wrap(err)
}
go a.handleDisconnect()
return nil
}

func (a *Agent) handleDisconnect() {
a.log.Infof("handle disconnects")
for {
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, returning")
return
case <-a.disconnectC:
a.log.Infof("detected disconnect, reconnecting")
a.reconnect()
}
}
}

func (a *Agent) reconnect() error {
ticker := time.NewTicker(defaults.ReverseTunnelAgentReconnectPeriod)
defer ticker.Stop()
var err error
i := 0
for {
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, return")
return nil
case <-ticker.C:
if err = a.connect(); err != nil {
i++
continue
}
return nil
}
}
conn, err := a.connect()
// start the heartbeat even if an error happened; it will reconnect
go a.runHeartbeat(conn)
return err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace.Wrap

}

// Wait waits until all outstanding operations are completed
Expand Down Expand Up @@ -180,79 +145,21 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub
"no matching keys found when checking server's host signature")
}

func (a *Agent) connect() error {
func (a *Agent) connect() (conn *ssh.Client, err error) {
if a.addr.IsEmpty() {
err := trace.BadParameter("reverse tunnel cannot be created: target address is empty")
a.log.Error(err)
return err
return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty")
}

var c *ssh.Client
var err error
for _, authMethod := range a.authMethods {
c, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{
conn, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{
User: a.domainName,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: a.hostKeyCallback,
})
if c != nil {
if conn != nil {
break
}
}
if c == nil {
a.log.Errorf("connect err: %v", err)
return trace.Wrap(err)
}

go a.startHeartbeat(c)
go a.handleAccessPoint(c.HandleChannelOpen(chanAccessPoint))
go a.handleTransport(c.HandleChannelOpen(chanTransport))

return nil
}

func (a *Agent) handleAccessPoint(newC <-chan ssh.NewChannel) {
for {
var nch ssh.NewChannel
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, return")
return
case nch = <-newC:
if nch == nil {
a.log.Infof("connection closed, return")
return
}
}
a.log.Infof("got access point request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
}
go a.proxyAccessPoint(ch, req)
}
}

func (a *Agent) handleTransport(newC <-chan ssh.NewChannel) {
for {
var nch ssh.NewChannel
select {
case <-a.broadcastClose.C:
a.log.Infof("is closed, return")
return
case nch = <-newC:
if nch == nil {
a.log.Infof("connection closed, return")
return
}
}
a.log.Infof("got transport request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
}
go a.proxyTransport(ch, req)
}
return conn, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace.Wrap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these errors aren't returned, they're logged (with line number) and discarded.

}

func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) {
Expand Down Expand Up @@ -328,75 +235,76 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
wg.Wait()
}

func (a *Agent) startHeartbeat(conn ssh.Conn) {
defer func() {
a.disconnectC <- true
a.log.Infof("sent disconnect message")
}()

defer conn.Close()

hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil)
if err != nil {
a.log.Errorf("failed to open channel: %v", err)
return
}

closeC := make(chan bool)
errC := make(chan error, 2)

ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
defer ticker.Stop()

go func() {
_, err := hb.SendRequest("ping", false, nil)
// runHeartbeat is a blocking function which runs in a loop sending heartbeats
// to the given SSH connection.
//
func (a *Agent) runHeartbeat(conn *ssh.Client) {
heartbeatLoop := func() error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a separate function

if conn == nil {
return trace.Errorf("heartbeat cannot ping: need to reconnect")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace.BadParameter

}
defer conn.Close()
hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil)
if err != nil {
a.log.Errorf("failed to send heartbeat: %v", err)
errC <- err
return
return trace.Wrap(err)
}
newAccesspointC := conn.HandleChannelOpen(chanAccessPoint)
newTransportC := conn.HandleChannelOpen(chanTransport)

// send first ping right away, then start a ping timer:
hb.SendRequest("ping", false, nil)
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
defer ticker.Stop()

for {
select {
// need to exit:
case <-a.broadcastClose.C:
a.log.Infof("agent is closing")
return
case <-closeC:
a.log.Infof("asked to exit")
return
return nil
// time to ping:
case <-ticker.C:
_, err := hb.SendRequest("ping", false, nil)
if err != nil {
a.log.Errorf("failed to send heartbeat: %v", err)
errC <- err
return
return trace.Wrap(err)
}
}
}
}()

go func() {
for {
select {
case <-a.broadcastClose.C:
a.log.Infof("agent is closing")
return
case <-closeC:
log.Infof("asked to exit")
return
// ssh channel closed:
case req := <-reqC:
if req == nil {
errC <- trace.Errorf("heartbeat: connection closed")
return
return trace.Errorf("heartbeat: connection closed")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace.ConnectionProblem

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That function needs an error object plus a message; I only have a message here.

}
// new access point request:
case nch := <-newAccesspointC:
if nch == nil {
continue
}
a.log.Infof("got out of band request: %v", req)
a.log.Infof("reverseTunnel.Agent: access point request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
continue
}
go a.proxyAccessPoint(ch, req)
// new transport request:
case nch := <-newTransportC:
if nch == nil {
continue
}
a.log.Infof("reverseTunnel.Agent: transport request: %v", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Errorf("failed to accept request: %v", err)
continue
}
go a.proxyTransport(ch, req)
}
}
}()
}

a.log.Infof("got error: %v", <-errC)
(&sync.Once{}).Do(func() {
close(closeC)
})
err := heartbeatLoop()
if err != nil || conn == nil {
time.Sleep(defaults.ReverseTunnelAgentHeartbeatPeriod)
a.Start()
}
}

const (
Expand Down
3 changes: 1 addition & 2 deletions lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (m *AgentPool) FetchAndSyncAgents() error {
}

func (m *AgentPool) pollAndSyncAgents() {
ticker := time.NewTicker(defaults.ReverseTunnelsRefreshPeriod)
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
defer ticker.Stop()
m.FetchAndSyncAgents()
for {
Expand Down Expand Up @@ -137,7 +137,6 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error {
return trace.Wrap(err)
}
agentsToAdd, agentsToRemove := diffTunnels(m.agents, keys)
//m.Infof("agents to add: %v agents to remove: %v", agentsToAdd, agentsToRemove)
for _, key := range agentsToRemove {
m.Infof("removing %v", &key)
agent := m.agents[key]
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,8 @@ func (s *tunnelSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
return
}
s.setLastActive(time.Now())
case <-time.After(2 * defaults.ReverseTunnelAgentHeartbeatPeriod):
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 2 heartbeats"))
case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a constant too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't: the tests set it to 50ms so they run faster.

conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats"))
conn.sshConn.Close()
}
}
Expand Down