diff --git a/core/os/device/remotessh/commands.go b/core/os/device/remotessh/commands.go index ee3ca62315..dff1346a24 100644 --- a/core/os/device/remotessh/commands.go +++ b/core/os/device/remotessh/commands.go @@ -32,21 +32,20 @@ import ( "github.com/google/gapid/core/os/device" "github.com/google/gapid/core/os/shell" "github.com/google/gapid/core/text" - "golang.org/x/crypto/ssh" ) // remoteProcess is the interface to a running process, as started by a Target. type remoteProcess struct { - session *ssh.Session wg sync.WaitGroup + session *pooledSession } func (r *remoteProcess) Kill() error { - return r.session.Signal(ssh.SIGSEGV) + return r.session.kill() } func (r *remoteProcess) Wait(ctx context.Context) error { - ret := r.session.Wait() + ret := r.session.wait() r.wg.Wait() return ret } @@ -57,17 +56,17 @@ type sshShellTarget struct{ b *binding } // Start starts the given command in the remote shell. func (t sshShellTarget) Start(cmd shell.Cmd) (shell.Process, error) { - session, err := t.b.connection.NewSession() + pooled, err := t.b.newPooledSession() if err != nil { return nil, err } p := &remoteProcess{ - session: session, + session: pooled, wg: sync.WaitGroup{}, } if cmd.Stdin != nil { - stdin, err := session.StdinPipe() + stdin, err := pooled.session.StdinPipe() if err != nil { return nil, err } @@ -78,7 +77,7 @@ func (t sshShellTarget) Start(cmd shell.Cmd) (shell.Process, error) { } if cmd.Stdout != nil { - stdout, err := session.StdoutPipe() + stdout, err := pooled.session.StdoutPipe() if err != nil { return nil, err } @@ -90,7 +89,7 @@ func (t sshShellTarget) Start(cmd shell.Cmd) (shell.Process, error) { } if cmd.Stderr != nil { - stderr, err := session.StderrPipe() + stderr, err := pooled.session.StderrPipe() if err != nil { return nil, err } @@ -121,7 +120,7 @@ func (t sshShellTarget) Start(cmd shell.Cmd) (shell.Process, error) { } val := prefix + cmd.Name + " " + strings.Join(cmd.Args, " ") - if err := session.Start(val); err != nil { + if err := pooled.session.Start(val); err != nil { return nil, err } diff --git a/core/os/device/remotessh/device.go b/core/os/device/remotessh/device.go index 092b357d50..9e9fb7b439 100644 --- a/core/os/device/remotessh/device.go +++ b/core/os/device/remotessh/device.go @@ -49,6 +49,10 @@ type Device interface { WriteFile(ctx context.Context, contents io.Reader, mode os.FileMode, destPath string) error } +// MaxNumberOfSSHConnections defines the max number of ssh connections to each +// ssh remote device that can be used to run commands concurrently. +const MaxNumberOfSSHConnections = 15 + // binding represents an attached SSH client. type binding struct { bind.Simple @@ -59,6 +63,63 @@ type binding struct { // We duplicate OS here because we need to use it // before we get the rest of the information os device.OSKind + + // pool to limit the maximum number of connections + ch chan int +} + +type pooledSession struct { + ch chan int + session *ssh.Session +} + +func (p *pooledSession) kill() error { + select { + case <-p.ch: + default: + } + <-p.ch + return p.session.Signal(ssh.SIGSEGV) +} + +func (p *pooledSession) wait() error { + ret := p.session.Wait() + select { + case <-p.ch: + default: + } + return ret +} + +func newBinding(conn *ssh.Client, conf *Configuration, env *shell.Env) *binding { + b := &binding{ + connection: conn, + configuration: conf, + env: env, + ch: make(chan int, MaxNumberOfSSHConnections), + Simple: bind.Simple{ + To: &device.Instance{ + Serial: "", + Configuration: &device.Configuration{}, + }, + LastStatus: bind.Status_Online, + }, + } + return b +} + +func (b *binding) newPooledSession() (*pooledSession, error) { + b.ch <- int(0) + session, err := b.connection.NewSession() + if err != nil { + <-b.ch + err = fmt.Errorf("New SSH Session Error: %v, Current maximum number of ssh connections GAPID can issue to each remote device is: %v", err, MaxNumberOfSSHConnections) + return nil, err + } + return &pooledSession{ + ch: b.ch, + session: session, + }, nil } var _ Device = &binding{} @@ -145,18 +206,7 @@ func GetConnectedDevice(ctx context.Context, c Configuration) (Device, error) { env.Add(e) } - b := &binding{ - connection: connection, - configuration: &c, - env: env, - Simple: bind.Simple{ - To: &device.Instance{ - Serial: "", - Configuration: &device.Configuration{}, - }, - LastStatus: bind.Status_Online, - }, - } + b := newBinding(connection, &c, env) kind := device.UnknownOS