Skip to content

Commit

Permalink
Use built-in ssh impl for all non-pty operations
Browse files Browse the repository at this point in the history
Windows is not guaranteed to have the SSH feature installed, so prefer the use
of the built-in ssh client for all operations other than podman machine ssh,
which requires terminal pty logic. This restores previous behavior in 4.x.

Signed-off-by: Jason T. Greene <[email protected]>
  • Loading branch information
n1hility committed Mar 25, 2024
1 parent 2aad385 commit 11415b3
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cmd/podman/machine/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func ssh(cmd *cobra.Command, args []string) error {
username = mc.SSH.RemoteUsername
}

err = machine.CommonSSH(username, mc.SSH.IdentityPath, mc.Name, mc.SSH.Port, sshOpts.Args)
err = machine.CommonSSHShell(username, mc.SSH.IdentityPath, mc.Name, mc.SSH.Port, sshOpts.Args)
return utils.HandleOSExecError(err)
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ require (
github.com/vbauerster/mpb/v8 v8.7.2
github.com/vishvananda/netlink v1.2.1-beta.2
go.etcd.io/bbolt v1.3.9
golang.org/x/crypto v0.21.0
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/net v0.22.0
golang.org/x/sync v0.6.0
Expand Down Expand Up @@ -210,7 +211,6 @@ require (
go.opentelemetry.io/otel/sdk v1.21.0 // indirect
go.opentelemetry.io/otel/trace v1.22.0 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.15.0 // indirect
golang.org/x/oauth2 v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
Expand Down
3 changes: 1 addition & 2 deletions pkg/machine/hyperv/volumes.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ func removeShares(mc *vmconfigs.MachineConfig) error {

func startShares(mc *vmconfigs.MachineConfig) error {
for _, mount := range mc.Mounts {
args := []string{"-q", "--"}

var args []string
cleanTarget := path.Clean(mount.Target)
requiresChattr := !strings.HasPrefix(cleanTarget, "/home") && !strings.HasPrefix(cleanTarget, "/mnt")
if requiresChattr {
Expand Down
4 changes: 2 additions & 2 deletions pkg/machine/qemu/stubber.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (q *QEMUStubber) MountVolumesToVM(mc *vmconfigs.MachineConfig, quiet bool)
// create mountpoint directory if it doesn't exist
// because / is immutable, we have to monkey around with permissions
// if we dont mount in /home or /mnt
args := []string{"-q", "--"}
var args []string
if !strings.HasPrefix(mount.Target, "/home") && !strings.HasPrefix(mount.Target, "/mnt") {
args = append(args, "sudo", "chattr", "-i", "/", ";")
}
Expand All @@ -333,7 +333,7 @@ func (q *QEMUStubber) MountVolumesToVM(mc *vmconfigs.MachineConfig, quiet bool)
if mount.ReadOnly {
mountOptions = append(mountOptions, []string{"-o", "ro"}...)
}
err = machine.CommonSSH(mc.SSH.RemoteUsername, mc.SSH.IdentityPath, mc.Name, mc.SSH.Port, append([]string{"-q", "--", "sudo", "mount"}, mountOptions...))
err = machine.CommonSSH(mc.SSH.RemoteUsername, mc.SSH.IdentityPath, mc.Name, mc.SSH.Port, append([]string{"sudo", "mount"}, mountOptions...))
if err != nil {
return err
}
Expand Down
98 changes: 90 additions & 8 deletions pkg/machine/ssh.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,115 @@
package machine

import (
"bufio"
"fmt"
"io"
"os"
"os/exec"
"strconv"
"strings"

"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

// CommonSSH is a common function for ssh'ing to a podman machine using system-connections
// and a port
// TODO This should probably be taught about an machineconfig to reduce input
func CommonSSH(username, identityPath, name string, sshPort int, inputArgs []string) error {
return commonSSH(username, identityPath, name, sshPort, inputArgs, false, os.Stdin)
return commonBuiltinSSH(username, identityPath, name, sshPort, inputArgs, true, os.Stdin)
}

func CommonSSHShell(username, identityPath, name string, sshPort int, inputArgs []string) error {
return commonNativeSSH(username, identityPath, name, sshPort, inputArgs, os.Stdin)
}

func CommonSSHSilent(username, identityPath, name string, sshPort int, inputArgs []string) error {
return commonSSH(username, identityPath, name, sshPort, inputArgs, true, os.Stdin)
return commonBuiltinSSH(username, identityPath, name, sshPort, inputArgs, false, nil)
}

func CommonSSHWithStdin(username, identityPath, name string, sshPort int, inputArgs []string, stdin io.Reader) error {
return commonSSH(username, identityPath, name, sshPort, inputArgs, false, stdin)
return commonBuiltinSSH(username, identityPath, name, sshPort, inputArgs, true, stdin)
}

func commonBuiltinSSH(username, identityPath, name string, sshPort int, inputArgs []string, passOutput bool, stdin io.Reader) error {
config, err := createConfig(username, identityPath)
if err != nil {
return err
}

client, err := ssh.Dial("tcp", fmt.Sprintf("localhost:%d", sshPort), config)
if err != nil {
return err
}
defer client.Close()

session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()

cmd := strings.Join(inputArgs, " ")
logrus.Debugf("Running ssh command on machine %q: %s", name, cmd)
session.Stdin = stdin
if passOutput {
session.Stdout = os.Stdout
session.Stderr = os.Stderr
} else if logrus.IsLevelEnabled(logrus.DebugLevel) {
return runSessionWithDebug(session, cmd)
}

return session.Run(cmd)
}

func commonSSH(username, identityPath, name string, sshPort int, inputArgs []string, silent bool, stdin io.Reader) error {
func runSessionWithDebug(session *ssh.Session, cmd string) error {
outPipe, err := session.StdoutPipe()
if err != nil {
return err
}
errPipe, err := session.StderrPipe()
if err != nil {
return err
}
logOuput := func(pipe io.Reader, done chan struct{}) {
scanner := bufio.NewScanner(pipe)
for scanner.Scan() {
logrus.Debugf("ssh output: %s", scanner.Text())
}
done <- struct{}{}
}
if err := session.Start(cmd); err != nil {
return err
}
completed := make(chan struct{}, 2)
go logOuput(outPipe, completed)
go logOuput(errPipe, completed)
<-completed
<-completed

return session.Wait()
}

func createConfig(user string, identityPath string) (*ssh.ClientConfig, error) {
key, err := os.ReadFile(identityPath)
if err != nil {
return nil, err
}

signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return nil, err
}

return &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}, nil
}

func commonNativeSSH(username, identityPath, name string, sshPort int, inputArgs []string, stdin io.Reader) error {
sshDestination := username + "@localhost"
port := strconv.Itoa(sshPort)
interactive := true
Expand All @@ -45,10 +129,8 @@ func commonSSH(username, identityPath, name string, sshPort int, inputArgs []str
cmd := exec.Command("ssh", args...)
logrus.Debugf("Executing: ssh %v\n", args)

if !silent {
if err := setupIOPassthrough(cmd, interactive, stdin); err != nil {
return err
}
if err := setupIOPassthrough(cmd, interactive, stdin); err != nil {
return err
}

return cmd.Run()
Expand Down

0 comments on commit 11415b3

Please sign in to comment.