-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring of reverse tunnel agent #365
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,45 +104,10 @@ func (a *Agent) Close() error { | |
|
||
// Start starts agent that attempts to connect to remote server part | ||
func (a *Agent) Start() error { | ||
if err := a.reconnect(); err != nil { | ||
return trace.Wrap(err) | ||
} | ||
go a.handleDisconnect() | ||
return nil | ||
} | ||
|
||
func (a *Agent) handleDisconnect() { | ||
a.log.Infof("handle disconnects") | ||
for { | ||
select { | ||
case <-a.broadcastClose.C: | ||
a.log.Infof("is closed, returning") | ||
return | ||
case <-a.disconnectC: | ||
a.log.Infof("detected disconnect, reconnecting") | ||
a.reconnect() | ||
} | ||
} | ||
} | ||
|
||
func (a *Agent) reconnect() error { | ||
ticker := time.NewTicker(defaults.ReverseTunnelAgentReconnectPeriod) | ||
defer ticker.Stop() | ||
var err error | ||
i := 0 | ||
for { | ||
select { | ||
case <-a.broadcastClose.C: | ||
a.log.Infof("is closed, return") | ||
return nil | ||
case <-ticker.C: | ||
if err = a.connect(); err != nil { | ||
i++ | ||
continue | ||
} | ||
return nil | ||
} | ||
} | ||
conn, err := a.connect() | ||
// start heartbeat even if error happend, it will reconnect | ||
go a.runHeartbeat(conn) | ||
return err | ||
} | ||
|
||
// Wait waits until all outstanding operations are completed | ||
|
@@ -180,79 +145,21 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub | |
"no matching keys found when checking server's host signature") | ||
} | ||
|
||
func (a *Agent) connect() error { | ||
func (a *Agent) connect() (conn *ssh.Client, err error) { | ||
if a.addr.IsEmpty() { | ||
err := trace.BadParameter("reverse tunnel cannot be created: target address is empty") | ||
a.log.Error(err) | ||
return err | ||
return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty") | ||
} | ||
|
||
var c *ssh.Client | ||
var err error | ||
for _, authMethod := range a.authMethods { | ||
c, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{ | ||
conn, err = ssh.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{ | ||
User: a.domainName, | ||
Auth: []ssh.AuthMethod{authMethod}, | ||
HostKeyCallback: a.hostKeyCallback, | ||
}) | ||
if c != nil { | ||
if conn != nil { | ||
break | ||
} | ||
} | ||
if c == nil { | ||
a.log.Errorf("connect err: %v", err) | ||
return trace.Wrap(err) | ||
} | ||
|
||
go a.startHeartbeat(c) | ||
go a.handleAccessPoint(c.HandleChannelOpen(chanAccessPoint)) | ||
go a.handleTransport(c.HandleChannelOpen(chanTransport)) | ||
|
||
return nil | ||
} | ||
|
||
func (a *Agent) handleAccessPoint(newC <-chan ssh.NewChannel) { | ||
for { | ||
var nch ssh.NewChannel | ||
select { | ||
case <-a.broadcastClose.C: | ||
a.log.Infof("is closed, return") | ||
return | ||
case nch = <-newC: | ||
if nch == nil { | ||
a.log.Infof("connection closed, return") | ||
return | ||
} | ||
} | ||
a.log.Infof("got access point request: %v", nch.ChannelType()) | ||
ch, req, err := nch.Accept() | ||
if err != nil { | ||
a.log.Errorf("failed to accept request: %v", err) | ||
} | ||
go a.proxyAccessPoint(ch, req) | ||
} | ||
} | ||
|
||
func (a *Agent) handleTransport(newC <-chan ssh.NewChannel) { | ||
for { | ||
var nch ssh.NewChannel | ||
select { | ||
case <-a.broadcastClose.C: | ||
a.log.Infof("is closed, return") | ||
return | ||
case nch = <-newC: | ||
if nch == nil { | ||
a.log.Infof("connection closed, return") | ||
return | ||
} | ||
} | ||
a.log.Infof("got transport request: %v", nch.ChannelType()) | ||
ch, req, err := nch.Accept() | ||
if err != nil { | ||
a.log.Errorf("failed to accept request: %v", err) | ||
} | ||
go a.proxyTransport(ch, req) | ||
} | ||
return conn, err | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. trace.Wrap There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these errors aren't returned, they're logged (with line number) and discarded. |
||
} | ||
|
||
func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) { | ||
|
@@ -328,75 +235,76 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { | |
wg.Wait() | ||
} | ||
|
||
func (a *Agent) startHeartbeat(conn ssh.Conn) { | ||
defer func() { | ||
a.disconnectC <- true | ||
a.log.Infof("sent disconnect message") | ||
}() | ||
|
||
defer conn.Close() | ||
|
||
hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) | ||
if err != nil { | ||
a.log.Errorf("failed to open channel: %v", err) | ||
return | ||
} | ||
|
||
closeC := make(chan bool) | ||
errC := make(chan error, 2) | ||
|
||
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) | ||
defer ticker.Stop() | ||
|
||
go func() { | ||
_, err := hb.SendRequest("ping", false, nil) | ||
// runHeartbeat is a blocking function which runs in a loop sending heartbeats | ||
// to the given SSH connection. | ||
// | ||
func (a *Agent) runHeartbeat(conn *ssh.Client) { | ||
heartbeatLoop := func() error { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be a separate function |
||
if conn == nil { | ||
return trace.Errorf("heartbeat cannot ping: need to reconnect") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
defer conn.Close() | ||
hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) | ||
if err != nil { | ||
a.log.Errorf("failed to send heartbeat: %v", err) | ||
errC <- err | ||
return | ||
return trace.Wrap(err) | ||
} | ||
newAccesspointC := conn.HandleChannelOpen(chanAccessPoint) | ||
newTransportC := conn.HandleChannelOpen(chanTransport) | ||
|
||
// send first ping right away, then start a ping timer: | ||
hb.SendRequest("ping", false, nil) | ||
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) | ||
defer ticker.Stop() | ||
|
||
for { | ||
select { | ||
// need to exit: | ||
case <-a.broadcastClose.C: | ||
a.log.Infof("agent is closing") | ||
return | ||
case <-closeC: | ||
a.log.Infof("asked to exit") | ||
return | ||
return nil | ||
// time to ping: | ||
case <-ticker.C: | ||
_, err := hb.SendRequest("ping", false, nil) | ||
if err != nil { | ||
a.log.Errorf("failed to send heartbeat: %v", err) | ||
errC <- err | ||
return | ||
return trace.Wrap(err) | ||
} | ||
} | ||
} | ||
}() | ||
|
||
go func() { | ||
for { | ||
select { | ||
case <-a.broadcastClose.C: | ||
a.log.Infof("agent is closing") | ||
return | ||
case <-closeC: | ||
log.Infof("asked to exit") | ||
return | ||
// ssh channel closed: | ||
case req := <-reqC: | ||
if req == nil { | ||
errC <- trace.Errorf("heartbeat: connection closed") | ||
return | ||
return trace.Errorf("heartbeat: connection closed") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that function needs an error object + message. I only have message here. |
||
} | ||
// new access point request: | ||
case nch := <-newAccesspointC: | ||
if nch == nil { | ||
continue | ||
} | ||
a.log.Infof("got out of band request: %v", req) | ||
a.log.Infof("reverseTunnel.Agent: access point request: %v", nch.ChannelType()) | ||
ch, req, err := nch.Accept() | ||
if err != nil { | ||
a.log.Errorf("failed to accept request: %v", err) | ||
continue | ||
} | ||
go a.proxyAccessPoint(ch, req) | ||
// new transport request: | ||
case nch := <-newTransportC: | ||
if nch == nil { | ||
continue | ||
} | ||
a.log.Infof("reverseTunnel.Agent: transport request: %v", nch.ChannelType()) | ||
ch, req, err := nch.Accept() | ||
if err != nil { | ||
a.log.Errorf("failed to accept request: %v", err) | ||
continue | ||
} | ||
go a.proxyTransport(ch, req) | ||
} | ||
} | ||
}() | ||
} | ||
|
||
a.log.Infof("got error: %v", <-errC) | ||
(&sync.Once{}).Do(func() { | ||
close(closeC) | ||
}) | ||
err := heartbeatLoop() | ||
if err != nil || conn == nil { | ||
time.Sleep(defaults.ReverseTunnelAgentHeartbeatPeriod) | ||
a.Start() | ||
} | ||
} | ||
|
||
const ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -595,8 +595,8 @@ func (s *tunnelSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch | |
return | ||
} | ||
s.setLastActive(time.Now()) | ||
case <-time.After(2 * defaults.ReverseTunnelAgentHeartbeatPeriod): | ||
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 2 heartbeats")) | ||
case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be a constant too There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't. tests set it to 50ms to run faster. |
||
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats")) | ||
conn.sshConn.Close() | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
trace.Wrap