diff --git a/transport/ssh/ssh.go b/transport/ssh/ssh.go index 0275fc1..e030a82 100644 --- a/transport/ssh/ssh.go +++ b/transport/ssh/ssh.go @@ -19,9 +19,10 @@ type Transport struct { sess *ssh.Session stdin io.WriteCloser - // indicate that we "own" the client and should close it with the session - // when the transport is closed. - ownedClient bool + // set to true if the transport is managing the underlying ssh connection + // and should close it when the transport is closed. This is is set to true + // when used with `Dial`. + managed bool *framer } @@ -40,10 +41,36 @@ func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) ( if err != nil { return nil, err } + + // Setup a go routine to monitor the context and close the connection. This + // is needed as the underlying ssh library doesn't support contexts so this + // approximates a context based cancelation/timeout for the ssh handshake. + // + // An alternative would be timeout based with conn.SetDeadline(), but then we + // would manage two timeouts. One for tcp connection and one for ssh + // handshake and wouldn't support any other event based cancelation. + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + // context is canceled so close the underlying connection. Will + // will catch ctx.Err() later. + conn.Close() + case <-done: + } + }() + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) if err != nil { + // if there is a context timeout return that error instead of the actual + // error from ssh.NewClientConn. + if ctx.Err() != nil { + return nil, ctx.Err() + } return nil, err } + close(done) // make sure we cleanup the context monitor routine + client := ssh.NewClient(sshConn, chans, reqs) return newTransport(client, true) } @@ -56,7 +83,7 @@ func NewTransport(client *ssh.Client) (*Transport, error) { return newTransport(client, false) } -func newTransport(client *ssh.Client, owned bool) (*Transport, error) { +func newTransport(client *ssh.Client, managed bool) (*Transport, error) { sess, err := client.NewSession() if err != nil { return nil, fmt.Errorf("failed to create ssh session: %w", err) @@ -78,10 +105,10 @@ func newTransport(client *ssh.Client, owned bool) (*Transport, error) { } return &Transport{ - c: client, - ownedClient: owned, - sess: sess, - stdin: w, + c: client, + managed: managed, + sess: sess, + stdin: w, framer: transport.NewFramer(r, w), }, nil @@ -104,7 +131,7 @@ func (t *Transport) Close() error { retErr = fmt.Errorf("failed to close ssh channel: %w", err) } - if t.ownedClient { + if t.managed { if err := t.c.Close(); err != nil { return fmt.Errorf("failed to close ssh connnection: %w", t.c.Close()) }