diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..a46404e --- /dev/null +++ b/client/client.go @@ -0,0 +1,634 @@ +package client + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "syscall" + "time" + + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "golang.org/x/term" + + "github.com/francoismichel/ssh3" + "github.com/francoismichel/ssh3/auth" + "github.com/francoismichel/ssh3/client/winsize" + ssh3Messages "github.com/francoismichel/ssh3/message" + "github.com/francoismichel/ssh3/util" +) + +type ExitStatus struct { + StatusCode int +} + +func (e ExitStatus) Error() string { + return fmt.Sprintf("exited with status %d", e.StatusCode) +} + +type ExitSignal struct { + Signal string + ErrorMessageUTF8 string +} + +func (e ExitSignal) Error() string { + return fmt.Sprintf("exited with signal %s: %s", e.Signal, e.ErrorMessageUTF8) +} + +type NoSuitableIdentity struct{} + +func (e NoSuitableIdentity) Error() string { + return "no suitable identity found" +} + +func forwardAgent(parent context.Context, channel ssh3.Channel) error { + sockPath := os.Getenv("SSH_AUTH_SOCK") + if sockPath == "" { + return fmt.Errorf("no auth socket in SSH_AUTH_SOCK env var") + } + c, err := net.Dial("unix", sockPath) + if err != nil { + return err + } + defer c.Close() + ctx, cancel := context.WithCancelCause(parent) + go func() { + var err error = nil + var genericMessage ssh3Messages.Message + for { + select { + case <-ctx.Done(): + err = context.Cause(ctx) + if err != nil { + log.Error().Msgf("reading message stopped on channel %d: %s", channel.ChannelID(), err.Error()) + } + return + default: + genericMessage, err = channel.NextMessage() + if err != nil && err != io.EOF { + err = fmt.Errorf("error when getting message on channel %d: %s", channel.ChannelID(), err.Error()) + cancel(err) + return + } + if genericMessage == nil { + return + } + switch message := genericMessage.(type) { + case *ssh3Messages.DataOrExtendedDataMessage: + _, err = c.Write([]byte(message.Data)) + if err != nil { + err = fmt.Errorf("error when writing on unix socker for agent forwarding channel %d: %s", channel.ChannelID(), err.Error()) + cancel(err) + return + } + default: + err = fmt.Errorf("unhandled message type on agent channel %d: %T", channel.ChannelID(), message) + cancel(err) + return + } + } + } + }() + + buf := make([]byte, channel.MaxPacketSize()) + for { + select { + case <-ctx.Done(): + err = context.Cause(ctx) + if err != nil { + log.Error().Msgf("ending agent forwarding on channel %d: %s", channel.ChannelID(), err.Error()) + } + return err + default: + n, err := c.Read(buf) + if err == io.EOF { + log.Debug().Msgf("unix socket for ssh agent closed") + return nil + } else if err != nil { + cancel(err) + log.Error().Msgf("could not read on unix socket: %s", err.Error()) + return err + } + _, err = channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE) + if err != nil { + cancel(err) + log.Error().Msgf("could not write on ssh channel: %s", err.Error()) + return err + } + } + } +} + +func forwardTCPInBackground(ctx context.Context, channel ssh3.Channel, conn *net.TCPConn) { + go func() { + defer conn.CloseWrite() + for { + select { + case <-ctx.Done(): + return + default: + } + genericMessage, err := channel.NextMessage() + if err == io.EOF { + log.Info().Msgf("eof on tcp-forwarding channel %d", channel.ChannelID()) + } else if err != nil { + log.Error().Msgf("could get message from tcp forwarding channel: %s", err) + return + } + + // nothing to process + if genericMessage == nil { + return + } + + switch message := genericMessage.(type) { + case *ssh3Messages.DataOrExtendedDataMessage: + if message.DataType == ssh3Messages.SSH_EXTENDED_DATA_NONE { + _, err := conn.Write([]byte(message.Data)) + if err != nil { + log.Error().Msgf("could not write data on TCP socket: %s", err) + // signal the write error to the peer + channel.CancelRead() + return + } + } else { + log.Warn().Msgf("ignoring message data of unexpected type %d on TCP forwarding channel %d", message.DataType, channel.ChannelID()) + } + default: + log.Warn().Msgf("ignoring message of type %T on TCP forwarding channel %d", message, channel.ChannelID()) + } + } + }() + + go func() { + defer channel.Close() + defer conn.CloseRead() + buf := make([]byte, channel.MaxPacketSize()) + for { + select { + case <-ctx.Done(): + return + default: + } + n, err := conn.Read(buf) + if err != nil && err != io.EOF { + log.Error().Msgf("could read data on TCP socket: %s", err) + return + } + _, errWrite := channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE) + if errWrite != nil { + switch quicErr := errWrite.(type) { + case *quic.StreamError: + if quicErr.Remote && quicErr.ErrorCode == 42 { + log.Info().Msgf("writing was canceled by the remote, closing the socket: %s", errWrite) + } else { + log.Error().Msgf("unhandled quic stream error: %+v", quicErr) + } + default: + log.Error().Msgf("could send data on channel: %s", errWrite) + } + return + } + if err == io.EOF { + return + } + } + }() +} + +type Client struct { + qconn quic.EarlyConnection + *ssh3.Conversation +} + +func Dial(ctx context.Context, options *Options, qconn quic.EarlyConnection, + roundTripper *http3.RoundTripper, + sshAgent agent.ExtendedAgent) (*Client, error) { + + hostUrl := url.URL{} + hostUrl.Scheme = "https" + hostUrl.Host = options.URLHostnamePort() + hostUrl.Path = options.UrlPath() + urlQuery := hostUrl.Query() + urlQuery.Set("user", options.Username()) + hostUrl.RawQuery = urlQuery.Encode() + requestUrl := hostUrl.String() + + var qconf quic.Config + + qconf.MaxIncomingStreams = 10 + qconf.Allow0RTT = true + qconf.EnableDatagrams = true + qconf.KeepAlivePeriod = 1 * time.Second + + var agentKeys []ssh.PublicKey + if sshAgent != nil { + keys, err := sshAgent.List() + if err != nil { + log.Error().Msgf("Failed to list agent keys: %s", err) + return nil, err + } + for _, key := range keys { + agentKeys = append(agentKeys, key) + } + } + + // dirty hack: ensure only one QUIC connection is used + roundTripper.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + return qconn, nil + } + + // Do 0RTT GET requests here if needed + // Currently, we don't need it but we could use it to retrieve + // config or version info from the server + // We could also allow user-defined safe/idempotent commands to run with 0-RTT + qconn.HandshakeComplete() + log.Debug().Msgf("QUIC handshake complete") + // Now, we're 1-RTT, we can get the TLS exporter and create the conversation + tls := qconn.ConnectionState().TLS + conv, err := ssh3.NewClientConversation(30000, 10, &tls) + if err != nil { + return nil, err + } + + // the connection struct is created, now build the request used to establish the connection + req, err := http.NewRequest("CONNECT", requestUrl, nil) + if err != nil { + log.Fatal().Msgf("%s", err) + } + req.Proto = "ssh3" + req.Header.Set("User-Agent", ssh3.GetCurrentVersion()) + + var identity ssh3.Identity + for _, method := range options.authMethods { + switch m := method.(type) { + case *ssh3.PasswordAuthMethod: + fmt.Printf("password for %s:", hostUrl.String()) + password, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if err != nil { + log.Error().Msgf("could not get password: %s", err) + return nil, err + } + identity = m.IntoIdentity(string(password)) + case *ssh3.PrivkeyFileAuthMethod: + identity, err = m.IntoIdentityWithoutPassphrase() + // could not identify without passphrase, try agent authentication by using the key's public key + if passphraseErr, ok := err.(*ssh.PassphraseMissingError); ok { + // the pubkey may be contained in the privkey file + pubkey := passphraseErr.PublicKey + if pubkey == nil { + // if it is not the case, try to find a .pub equivalent, like OpenSSH does + pubkeyBytes, err := os.ReadFile(fmt.Sprintf("%s.pub", m.Filename())) + if err == nil { + filePubkey, _, _, _, err := ssh.ParseAuthorizedKey(pubkeyBytes) + if err == nil { + pubkey = filePubkey + } + } + } + + // now, try to see of the agent manages this key + foundAgentKey := false + if pubkey != nil { + for _, agentKey := range agentKeys { + if bytes.Equal(agentKey.Marshal(), pubkey.Marshal()) { + log.Debug().Msgf("found key in agent: %s", agentKey) + identity = ssh3.NewAgentAuthMethod(pubkey).IntoIdentity(sshAgent) + foundAgentKey = true + break + } + } + } + + // key not handled by agent, let's try to decrypt it ourselves + if !foundAgentKey { + fmt.Printf("passphrase for private key stored in %s:", m.Filename()) + var passphraseBytes []byte + passphraseBytes, err = term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if err != nil { + log.Error().Msgf("could not get passphrase: %s", err) + return nil, err + } + passphrase := string(passphraseBytes) + identity, err = m.IntoIdentityPassphrase(passphrase) + if err != nil { + log.Error().Msgf("could not load private key: %s", err) + return nil, err + } + } + } else if err != nil { + log.Warn().Msgf("Could not load private key: %s", err) + } + case *ssh3.AgentAuthMethod: + identity = m.IntoIdentity(sshAgent) + case *ssh3.OidcAuthMethod: + token, err := auth.Connect(context.Background(), m.OIDCConfig(), m.OIDCConfig().IssuerUrl, m.DoPKCE()) + if err != nil { + log.Error().Msgf("could not get token: %s", err) + return nil, err + } + identity = m.IntoIdentity(token) + } + // currently only tries a single identity (the first one), but the goal is to + // try several identities, similarly to OpenSSH + break + } + + if identity == nil { + return nil, NoSuitableIdentity{} + } + + log.Debug().Msgf("try the following Identity: %s", identity) + err = identity.SetAuthorizationHeader(req, options.Username(), conv) + if err != nil { + log.Error().Msgf("could not set authorization header in HTTP request: %s", err) + return nil, err + } + + log.Debug().Msgf("send CONNECT request to the server") + err = conv.EstablishClientConversation(req, roundTripper) + if errors.Is(err, util.Unauthorized{}) { + log.Error().Msgf("Access denied from the server: unauthorized") + return nil, err + } else if err != nil { + log.Error().Msgf("Could not establish conversation: %+v", err) + return nil, err + } + + return &Client{ + qconn: qconn, + Conversation: conv, + }, nil +} + +func (c *Client) ForwardUDP(ctx context.Context, localUDPAddr *net.UDPAddr, remoteUDPAddr *net.UDPAddr) error { + log.Debug().Msgf("start UDP forwarding from %s to %s", localUDPAddr, remoteUDPAddr) + conn, err := net.ListenUDP("udp", localUDPAddr) + if err != nil { + log.Error().Msgf("could listen on UDP socket: %s", err) + return err + } + forwardings := make(map[string]ssh3.Channel) + go func() { + buf := make([]byte, 1500) + for { + n, addr, err := conn.ReadFromUDP(buf) + if err != nil { + log.Error().Msgf("could read on UDP socket: %s", err) + return + } + channel, ok := forwardings[addr.String()] + if !ok { + channel, err = c.OpenUDPForwardingChannel(30000, 10, localUDPAddr, remoteUDPAddr) + if err != nil { + log.Error().Msgf("could open new UDP forwarding channel: %s", err) + return + } + forwardings[addr.String()] = channel + + go func() { + for { + dgram, err := channel.ReceiveDatagram(ctx) + if err != nil { + log.Error().Msgf("could open receive datagram on channel: %s", err) + return + } + _, err = conn.WriteToUDP(dgram, addr) + if err != nil { + log.Error().Msgf("could open write datagram on socket: %s", err) + return + } + } + }() + } + err = channel.SendDatagram(buf[:n]) + if err != nil { + log.Error().Msgf("could not send datagram: %s", err) + return + } + } + }() + return nil +} + +func (c *Client) ForwardTCP(ctx context.Context, localTCPAddr *net.TCPAddr, remoteTCPAddr *net.TCPAddr) error { + log.Debug().Msgf("start TCP forwarding from %s to %s", localTCPAddr, remoteTCPAddr) + conn, err := net.ListenTCP("tcp", localTCPAddr) + if err != nil { + log.Error().Msgf("could listen on TCP socket: %s", err) + return err + } + go func() { + for { + conn, err := conn.AcceptTCP() + if err != nil { + log.Error().Msgf("could read on UDP socket: %s", err) + return + } + forwardingChannel, err := c.OpenTCPForwardingChannel(30000, 10, localTCPAddr, remoteTCPAddr) + if err != nil { + log.Error().Msgf("could open new UDP forwarding channel: %s", err) + return + } + forwardTCPInBackground(ctx, forwardingChannel, conn) + } + }() + return nil +} + +func (c *Client) RunSession(tty *os.File, forwardSSHAgent bool, command ...string) error { + + ctx := c.Context() + + channel, err := c.OpenChannel("session", 30000, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "Could not open channel: %+v", err) + os.Exit(-1) + } + + log.Debug().Msgf("opened new session channel") + + if forwardSSHAgent { + _, err := channel.WriteData([]byte("forward-agent"), ssh3Messages.SSH_EXTENDED_DATA_NONE) + if err != nil { + log.Error().Msgf("could not forward agent: %s", err.Error()) + return err + } + go func() { + for { + forwardChannel, err := c.AcceptChannel(ctx) + if err != nil { + if err != context.Canceled { + log.Error().Msgf("could not accept forwarding channel: %s", err.Error()) + } + return + } else if forwardChannel.ChannelType() != "agent-connection" { + log.Error().Msgf("unexpected server-initiated channel: %s", channel.ChannelType()) + return + } + log.Debug().Msg("new agent connection, forwarding") + go func() { + err = forwardAgent(ctx, forwardChannel) + if err != nil { + log.Error().Msgf("agent forwarding error: %s", err.Error()) + c.Close() + } + }() + } + }() + } + + if len(command) == 0 { + // avoid requesting a pty on the other side if stdin is not a pty + // similar behaviour to OpenSSH + isATTY := term.IsTerminal(int(tty.Fd())) + + windowSize, err := winsize.GetWinsize(tty) + if err != nil { + log.Warn().Msgf("could not get window size: %+v", err) + } + hasWinSize := err == nil + if isATTY && hasWinSize { + err = channel.SendRequest( + &ssh3Messages.ChannelRequestMessage{ + WantReply: true, + ChannelRequest: &ssh3Messages.PtyRequest{ + Term: os.Getenv("TERM"), + CharWidth: uint64(windowSize.NCols), + CharHeight: uint64(windowSize.NRows), + PixelWidth: uint64(windowSize.PixelWidth), + PixelHeight: uint64(windowSize.PixelHeight), + }, + }, + ) + + if err != nil { + fmt.Fprintf(os.Stderr, "Could send pty request: %+v", err) + return err + } + log.Debug().Msgf("sent pty request for session") + } + + err = channel.SendRequest( + &ssh3Messages.ChannelRequestMessage{ + WantReply: true, + ChannelRequest: &ssh3Messages.ShellRequest{}, + }, + ) + if err != nil { + log.Error().Msgf("could not send shell request: %s", err) + return err + } + log.Debug().Msgf("sent shell request, hasWinSize = %t", hasWinSize) + // avoid making the terminal raw if stdin is not a TTY + // similar behaviour to OpenSSH + if isATTY { + fd := os.Stdin.Fd() + oldState, err := term.MakeRaw(int(fd)) + if err != nil { + log.Warn().Msgf("cannot make tty raw: %s", err) + } else { + defer term.Restore(int(fd), oldState) + } + } + } else { + channel.SendRequest( + &ssh3Messages.ChannelRequestMessage{ + WantReply: true, + ChannelRequest: &ssh3Messages.ExecRequest{ + Command: strings.Join(command, " "), + }, + }, + ) + log.Debug().Msgf("sent exec request for command \"%s\"", strings.Join(command, " ")) + } + + if err != nil { + fmt.Fprintf(os.Stderr, "Could send shell request: %+v", err) + return err + } + + go func() { + buf := make([]byte, channel.MaxPacketSize()) + for { + n, err := os.Stdin.Read(buf) + if n > 0 { + _, err2 := channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE) + if err2 != nil { + fmt.Fprintf(os.Stderr, "could not write data on channel: %+v", err2) + return + } + } + if err != nil { + fmt.Fprintf(os.Stderr, "could not read data from stdin: %+v", err) + return + } + } + }() + + defer fmt.Printf("\r") + + for { + genericMessage, err := channel.NextMessage() + if err != nil { + fmt.Fprintf(os.Stderr, "Could not get message: %+v\n", err) + os.Exit(-1) + } + switch message := genericMessage.(type) { + case *ssh3Messages.ChannelRequestMessage: + switch requestMessage := message.ChannelRequest.(type) { + case *ssh3Messages.PtyRequest: + fmt.Fprintf(os.Stderr, "receiving a pty request on the client is not implemented\n") + case *ssh3Messages.X11Request: + fmt.Fprintf(os.Stderr, "receiving a x11 request on the client is not implemented\n") + case *ssh3Messages.ShellRequest: + fmt.Fprintf(os.Stderr, "receiving a shell request on the client is not implemented\n") + case *ssh3Messages.ExecRequest: + fmt.Fprintf(os.Stderr, "receiving a exec request on the client is not implemented\n") + case *ssh3Messages.SubsystemRequest: + fmt.Fprintf(os.Stderr, "receiving a subsystem request on the client is not implemented\n") + case *ssh3Messages.WindowChangeRequest: + fmt.Fprintf(os.Stderr, "receiving a windowchange request on the client is not implemented\n") + case *ssh3Messages.SignalRequest: + fmt.Fprintf(os.Stderr, "receiving a signal request on the client is not implemented\n") + case *ssh3Messages.ExitStatusRequest: + log.Info().Msgf("ssh3: process exited with status: %d\n", requestMessage.ExitStatus) + // forward the process' status code to the user + return ExitStatus{StatusCode: int(requestMessage.ExitStatus)} + case *ssh3Messages.ExitSignalRequest: + log.Info().Msgf("ssh3: process exited with signal: %s: %s\n", requestMessage.SignalNameWithoutSig, requestMessage.ErrorMessageUTF8) + return ExitSignal{Signal: requestMessage.SignalNameWithoutSig, ErrorMessageUTF8: requestMessage.ErrorMessageUTF8} + } + case *ssh3Messages.DataOrExtendedDataMessage: + switch message.DataType { + case ssh3Messages.SSH_EXTENDED_DATA_NONE: + _, err = os.Stdout.Write([]byte(message.Data)) + if err != nil { + log.Fatal().Msgf("%s", err) + } + + log.Debug().Msgf("received data %s", message.Data) + case ssh3Messages.SSH_EXTENDED_DATA_STDERR: + _, err = os.Stderr.Write([]byte(message.Data)) + if err != nil { + log.Fatal().Msgf("%s", err) + } + + log.Debug().Msgf("received stderr data %s", message.Data) + } + } + } +} diff --git a/client/options.go b/client/options.go new file mode 100644 index 0000000..b0d4a4e --- /dev/null +++ b/client/options.go @@ -0,0 +1,70 @@ +package client + +import ( + "fmt" + "net" +) + +type Options struct { + username string + hostname string + port int + urlPath string + authMethods []interface{} +} + +func NewOptions(username string, hostname string, port int, urlPath string, authMethods []interface{}) (*Options, error) { + if len(urlPath) == 0 || urlPath[0] != '/' { + urlPath = "/" + urlPath + } + return &Options{ + username: username, + hostname: hostname, + port: port, + urlPath: urlPath, + authMethods: authMethods, + }, nil +} + +func (o *Options) Username() string { + return o.username +} + +func (o *Options) Hostname() string { + return o.hostname +} + +// Returns the pair hostname:port in a valid URL format. +// This means that an IPv6 will be written inside square brackets []. +// examples: "127.0.0.1:443", "example.org:1234", "[::1]:22"" +func (o *Options) URLHostnamePort() string { + hostnameIsAnIP := net.ParseIP(o.hostname) != nil + if hostnameIsAnIP { + ip := net.ParseIP(o.hostname) + if ip.To4() == nil && ip.To16() != nil { + // enforce the square-bracketed notation for ipv6 UDP addresses + return fmt.Sprintf("[%s]:%d", o.hostname, o.port) + } + } + return fmt.Sprintf("%s:%d", o.hostname, o.port) +} + +func (o *Options) Port() int { + return o.port +} +func (o *Options) UrlPath() string { + return o.urlPath +} + +// Returns the canonical host representation used by SSH3. +// The format is +// is the host:port pair in the format returned by +// URLHostnamePort() +// is the URL path, it always starts with a "/", it is therefore never empty. +func (o *Options) CanonicalHostFormat() string { + return fmt.Sprintf("%s:%d%s", o.hostname, o.port, o.urlPath) +} + +func (o *Options) AuthMethods() []interface{} { + return o.authMethods +} diff --git a/cmd/ssh3/winsize/common.go b/client/winsize/common.go similarity index 100% rename from cmd/ssh3/winsize/common.go rename to client/winsize/common.go diff --git a/client/winsize/winsize.go b/client/winsize/winsize.go new file mode 100644 index 0000000..ab347ad --- /dev/null +++ b/client/winsize/winsize.go @@ -0,0 +1,18 @@ +//go:build !windows + +package winsize + +import ( + "os" + "syscall" + "unsafe" +) + +func GetWinsize(tty *os.File) (ws WindowSize, err error) { + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(tty.Fd()), uintptr(syscall.TIOCGWINSZ), + uintptr(unsafe.Pointer(&ws))) + if errno != 0 { + err = errno + } + return ws, err +} diff --git a/cmd/ssh3/winsize/winsize_windows.go b/client/winsize/winsize_windows.go similarity index 67% rename from cmd/ssh3/winsize/winsize_windows.go rename to client/winsize/winsize_windows.go index d1f7e98..932a9cb 100644 --- a/cmd/ssh3/winsize/winsize_windows.go +++ b/client/winsize/winsize_windows.go @@ -2,13 +2,16 @@ package winsize -import "os" -import "golang.org/x/term" +import ( + "os" -func GetWinsize() (ws WindowSize, err error) { + "golang.org/x/term" +) + +func GetWinsize(tty *os.File) (ws WindowSize, err error) { // for Windows, it is a bit more complicated to get the window size in pixels, so on rely // on window size expressed in columns - width, height, err := term.GetSize(int(os.Stdout.Fd())) + width, height, err := term.GetSize(int(tty.Fd())) if err != nil { return ws, err } diff --git a/client.go b/client_auth.go similarity index 99% rename from client.go rename to client_auth.go index adb9d9b..22cc56a 100644 --- a/client.go +++ b/client_auth.go @@ -31,6 +31,10 @@ func (m *OidcAuthMethod) OIDCConfig() *auth.OIDCConfig { return m.config } +func (m *OidcAuthMethod) DoPKCE() bool { + return m.doPKCE +} + type PrivkeyFileAuthMethod struct { filename string } diff --git a/cmd/ssh3-server/main.go b/cmd/ssh3-server/main.go index 9b55871..e672094 100644 --- a/cmd/ssh3-server/main.go +++ b/cmd/ssh3-server/main.go @@ -578,7 +578,7 @@ func newDataReq(user *unix_util.User, channel ssh3.Channel, request ssh3Messages func handleAuthAgentSocketConn(conn net.Conn, conversation *ssh3.Conversation) { channel, err := conversation.OpenChannel("agent-connection", 30000, 10) if err != nil { - log.Error().Msgf("could not open channel: %s", err.Error()) + log.Error().Msgf("could not open channel from server: %s", err.Error()) return } go func() { diff --git a/cmd/ssh3/main.go b/cmd/ssh3/main.go index 92200c1..5499581 100644 --- a/cmd/ssh3/main.go +++ b/cmd/ssh3/main.go @@ -1,11 +1,7 @@ package main import ( - // "bufio" - // "bytes" - // "context" "bufio" - "bytes" "context" "crypto/tls" "crypto/x509" @@ -15,29 +11,24 @@ import ( "fmt" "io" "net" - "net/http" "net/url" "os" osuser "os/user" "path" "strconv" "strings" - "syscall" "time" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" - "golang.org/x/term" - "github.com/francoismichel/ssh3" "github.com/francoismichel/ssh3/auth" - "github.com/francoismichel/ssh3/cmd/ssh3/winsize" - ssh3Messages "github.com/francoismichel/ssh3/message" + "github.com/francoismichel/ssh3/client" "github.com/francoismichel/ssh3/util" - - "github.com/kevinburke/ssh_config" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/kevinburke/ssh_config" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -51,159 +42,144 @@ func homedir() string { } } -func forwardAgent(parent context.Context, channel ssh3.Channel) error { - sockPath := os.Getenv("SSH_AUTH_SOCK") - if sockPath == "" { - return fmt.Errorf("no auth socket in SSH_AUTH_SOCK env var") - } - c, err := net.Dial("unix", sockPath) +// Prepares the QUIC connection that will be used by SSH3 +// If non-nil, use udpConn as transport (can be used for proxy jump) +// Otherwise, create a UDPConn from udp://host:port +func setupQUICConnection(ctx context.Context, skipHostVerification bool, keylog io.Writer, ssh3Dir string, certPool *x509.CertPool, knownHostsPath string, knownHosts ssh3.KnownHosts, + oidcConfig []*auth.OIDCConfig, options *client.Options, udpConn *net.UDPConn, tty *os.File) (quic.EarlyConnection, int) { + + remoteAddr, err := net.ResolveUDPAddr("udp", options.URLHostnamePort()) if err != nil { - return err + log.Error().Msgf("could not resolve UDP address: %s", err) + return nil, -1 } - defer c.Close() - ctx, cancel := context.WithCancelCause(parent) - go func() { - var err error = nil - var genericMessage ssh3Messages.Message - for { - select { - case <-ctx.Done(): - err = context.Cause(ctx) - if err != nil { - log.Error().Msgf("reading message stopped on channel %d: %s", channel.ChannelID(), err.Error()) - } - return - default: - genericMessage, err = channel.NextMessage() - if err != nil && err != io.EOF { - err = fmt.Errorf("error when getting message on channel %d: %s", channel.ChannelID(), err.Error()) - cancel(err) - return - } - if genericMessage == nil { - return - } - switch message := genericMessage.(type) { - case *ssh3Messages.DataOrExtendedDataMessage: - _, err = c.Write([]byte(message.Data)) - if err != nil { - err = fmt.Errorf("error when writing on unix socker for agent forwarding channel %d: %s", channel.ChannelID(), err.Error()) - cancel(err) - return - } - default: - err = fmt.Errorf("unhandled message type on agent channel %d: %T", channel.ChannelID(), message) - cancel(err) - return - } - } + if udpConn == nil { + udpConn, err = net.ListenUDP("udp", nil) + if err != nil { + log.Error().Msgf("could not create UDP connection: %s", err) + return nil, -1 } - }() + } - buf := make([]byte, channel.MaxPacketSize()) - for { - select { - case <-ctx.Done(): - err = context.Cause(ctx) - if err != nil { - log.Error().Msgf("ending agent forwarding on channel %d: %s", channel.ChannelID(), err.Error()) - } - return err - default: - n, err := c.Read(buf) - if err == io.EOF { - log.Debug().Msgf("unix socket for ssh agent closed") - return nil - } else if err != nil { - cancel(err) - log.Error().Msgf("could not read on unix socket: %s", err.Error()) - return err - } - _, err = channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE) - if err != nil { - cancel(err) - log.Error().Msgf("could not write on ssh channel: %s", err.Error()) - return err - } - } + tlsConf := &tls.Config{ + RootCAs: certPool, + InsecureSkipVerify: skipHostVerification, + NextProtos: []string{http3.NextProtoH3}, + KeyLogWriter: keylog, + ServerName: options.Hostname(), } -} -func forwardTCPInBackground(ctx context.Context, channel ssh3.Channel, conn *net.TCPConn) { - go func() { - defer conn.CloseWrite() - for { - select { - case <-ctx.Done(): - return - default: - } - genericMessage, err := channel.NextMessage() - if err == io.EOF { - log.Info().Msgf("eof on tcp-forwarding channel %d", channel.ChannelID()) - } else if err != nil { - log.Error().Msgf("could get message from tcp forwarding channel: %s", err) - return - } + var qconf quic.Config - // nothing to process - if genericMessage == nil { - return - } + qconf.MaxIncomingStreams = 10 + qconf.Allow0RTT = true + qconf.EnableDatagrams = true + qconf.KeepAlivePeriod = 1 * time.Second - switch message := genericMessage.(type) { - case *ssh3Messages.DataOrExtendedDataMessage: - if message.DataType == ssh3Messages.SSH_EXTENDED_DATA_NONE { - _, err := conn.Write([]byte(message.Data)) - if err != nil { - log.Error().Msgf("could not write data on TCP socket: %s", err) - // signal the write error to the peer - channel.CancelRead() - return - } - } else { - log.Warn().Msgf("ignoring message data of unexpected type %d on TCP forwarding channel %d", message.DataType, channel.ChannelID()) - } - default: - log.Warn().Msgf("ignoring message of type %T on TCP forwarding channel %d", message, channel.ChannelID()) + if certs, ok := knownHosts[options.CanonicalHostFormat()]; ok { + foundSelfsignedSSH3 := false + + for _, cert := range certs { + certPool.AddCert(cert) + if cert.VerifyHostname("selfsigned.ssh3") == nil { + foundSelfsignedSSH3 = true } } - }() - go func() { - defer channel.Close() - defer conn.CloseRead() - buf := make([]byte, channel.MaxPacketSize()) - for { - select { - case <-ctx.Done(): - return - default: - } - n, err := conn.Read(buf) - if err != nil && err != io.EOF { - log.Error().Msgf("could read data on TCP socket: %s", err) - return - } - _, errWrite := channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE) - if errWrite != nil { - switch quicErr := errWrite.(type) { - case *quic.StreamError: - if quicErr.Remote && quicErr.ErrorCode == 42 { - log.Info().Msgf("writing was canceled by the remote, closing the socket: %s", errWrite) - } else { - log.Error().Msgf("unhandled quic stream error: %+v", quicErr) + // If no IP SAN was in the cert, then assume the self-signed cert at least matches the .ssh3 TLD + if foundSelfsignedSSH3 { + // Put "ssh3" as ServerName so that the TLS verification can succeed + // Otherwise, TLS refuses to validate a certificate without IP SANs + // if the hostname is an IP address. + tlsConf.ServerName = "selfsigned.ssh3" + } + } + + log.Debug().Msgf("dialing QUIC host at %s", remoteAddr) + qClient, err := quic.DialEarly(ctx, + udpConn, + remoteAddr, + tlsConf, + &qconf) + if err != nil { + if transportErr, ok := err.(*quic.TransportError); ok { + if transportErr.ErrorCode.IsCryptoError() { + log.Debug().Msgf("received QUIC crypto error on first connection attempt: %s", err) + if tty == nil { + log.Error().Msgf("insecure server cert in non-terminal session, aborting") + return nil, -1 + } + if _, ok := knownHosts[options.CanonicalHostFormat()]; ok { + log.Error().Msgf("The server certificate cannot be verified using the one installed in %s. "+ + "If you did not change the server certificate, it could be a machine-in-the-middle attack. "+ + "TLS error: %s", knownHostsPath, err) + log.Error().Msgf("Aborting.") + return nil, -1 + } + // bad certificates, let's mimic the OpenSSH's behaviour similar to host keys + tlsConf.InsecureSkipVerify = true + var peerCertificate *x509.Certificate + certError := fmt.Errorf("we don't want to start a totally insecure connection") + tlsConf.VerifyConnection = func(ctx tls.ConnectionState) error { + peerCertificate = ctx.PeerCertificates[0] + return certError + } + + _, err := quic.DialEarly(ctx, + udpConn, + remoteAddr, + tlsConf, + &qconf) + if !errors.Is(err, certError) { + log.Error().Msgf("could not create client QUIC connection: %s", err) + return nil, -1 + } + // let's first check that the certificate is self-signed + if err := peerCertificate.CheckSignatureFrom(peerCertificate); err != nil { + log.Error().Msgf("the peer provided an unknown, insecure certificate, that is not self-signed: %s", err) + return nil, -1 + } + // first, carriage return + _, _ = tty.WriteString("\r") + _, err = tty.WriteString("Received an unknown self-signed certificate from the server.\n\r" + + "We recommend not using self-signed certificates.\n\r" + + "This session is vulnerable a machine-in-the-middle attack.\n\r" + + "Certificate fingerprint: " + + "SHA256 " + util.Sha256Fingerprint(peerCertificate.Raw) + "\n\r" + + "Do you want to add this certificate to ~/.ssh3/known_hosts (yes/no)? ") + if err != nil { + log.Error().Msgf("cound not write on /dev/tty: %s", err) + return nil, -1 + } + + answer := "" + reader := bufio.NewReader(tty) + for { + answer, _ = reader.ReadString('\n') + answer = strings.TrimSpace(answer) + _, _ = tty.WriteString("\r") // always ensure a carriage return + if answer == "yes" || answer == "no" { + break } - default: - log.Error().Msgf("could send data on channel: %s", errWrite) + tty.WriteString("Invalid answer, answer \"yes\" or \"no\" ") } - return - } - if err == io.EOF { - return + if answer == "no" { + log.Info().Msg("Connection aborted") + return nil, 0 + } + if err := ssh3.AppendKnownHost(knownHostsPath, options.CanonicalHostFormat(), peerCertificate); err != nil { + log.Error().Msgf("could not append known host to %s: %s", knownHostsPath, err) + return nil, -1 + } + tty.WriteString(fmt.Sprintf("Successfully added the certificate to %s, please rerun the command\n\r", knownHostsPath)) + return nil, 0 } } - }() + log.Error().Msgf("could not establish client QUIC connection: %s", err) + return nil, -1 + } + + return qClient, 0 } func parseAddrPort(addrPort string) (localPort int, remoteIP net.IP, remotePort int, err error) { @@ -228,9 +204,60 @@ func parseAddrPort(addrPort string) (localPort int, remoteIP net.IP, remotePort return localPort, remoteIP, remotePort, err } +func getConfigOptions(hostUrl *url.URL, sshConfig *ssh_config.Config) (*client.Options, error) { + urlHostname, urlPort := hostUrl.Hostname(), hostUrl.Port() + + configHostname, configPort, configUser, configAuthMethods, err := ssh3.GetConfigForHost(urlHostname, sshConfig) + if err != nil { + log.Error().Msgf("could not get config for %s: %s", urlHostname, err) + return nil, err + } + + hostname := configHostname + if hostname == "" { + hostname = urlHostname + } + + port := 443 + if urlPort != "" { + if parsedPort, err := strconv.Atoi(urlPort); err == nil && parsedPort < 0xffff { + // There is a port in the CLI and the port is valid. Use the CLI port. + port = parsedPort + } else { + // There is a port in the CLI but it is not valid. + // use WithLevel(zerolog.FatalLevel) to log a fatal level, but let us handle + // program termination. log.Fatal() exits with os.Exit(1). + log.WithLevel(zerolog.FatalLevel).Str("Port", urlPort).Err(err).Msg("cli contains an invalid port") + fmt.Fprintf(os.Stderr, "Bad port '%s'\n", urlPort) + return nil, err + } + } else if configPort != -1 { + // There is no port in the CLI, but one in a config file. Use the config port. + port = configPort + } + + username := hostUrl.User.Username() + if username == "" { + username = hostUrl.Query().Get("user") + } + if username == "" { + username = configUser + } + if username == "" { + u, err := osuser.Current() + if err == nil { + username = u.Username + } else { + log.Error().Msgf("could not get current username: %s", err) + } + } + if username == "" { + return nil, fmt.Errorf("no username could be found") + } + return client.NewOptions(username, hostname, port, hostUrl.Path, configAuthMethods) +} + func mainWithStatusCode() int { - // verbose := flag.Bool("v", false, "verbose") - // quiet := flag.Bool("q", false, "don't print the data") keyLogFile := flag.String("keylog", "", "Write QUIC TLS keys and master secret in the specified keylog file: only for debugging purpose") privKeyFile := flag.String("privkey", "", "private key file") pubkeyForAgent := flag.String("pubkey-for-agent", "", "if set, use an agent key whose public key matches the one in the specified path") @@ -243,7 +270,6 @@ func mainWithStatusCode() int { forwardSSHAgent := flag.Bool("forward-agent", false, "if set, forwards ssh agent to be used with sshv2 connections on the remote host") forwardUDP := flag.String("forward-udp", "", "if set, take a localport/remoteip@remoteport forwarding localhost@localport towards remoteip@remoteport") forwardTCP := flag.String("forward-tcp", "", "if set, take a localport/remoteip@remoteport forwarding localhost@localport towards remoteip@remoteport") - // enableQlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") flag.Parse() args := flag.Args() @@ -392,124 +418,23 @@ func mainWithStatusCode() int { keyLog = f } - parsedUrl, err := url.Parse(urlFromParam) + pool, err := x509.SystemCertPool() if err != nil { log.Fatal().Msgf("%s", err) } - urlHostname, urlPort := parsedUrl.Hostname(), parsedUrl.Port() - - configHostname, configPort, configUser, configAuthMethods, err := ssh3.GetConfigForHost(urlHostname, sshConfig) + parsedUrl, err := url.Parse(urlFromParam) if err != nil { - log.Error().Msgf("could not get config for %s: %s", urlHostname, err) + log.Error().Msgf("could not parse URL: %s", err) return -1 } - - hostname := configHostname - if hostname == "" { - hostname = urlHostname - } - - hostnameIsAnIP := net.ParseIP(hostname) != nil - - var port int - if urlPort != "" { - if parsedPort, err := strconv.Atoi(urlPort); err == nil && parsedPort < 0xffff { - // There is a port in the CLI and the port is valid. Use the CLI port. - port = parsedPort - } else { - // There is a port in the CLI but it is not valid. - // use WithLevel(zerolog.FatalLevel) to log a fatal level, but let us handle - // program termination. log.Fatal() exits with os.Exit(1). - log.WithLevel(zerolog.FatalLevel).Str("Port", urlPort).Err(err).Msg("cli contains an invalid port") - fmt.Fprintf(os.Stderr, "Bad port '%s'\n", urlPort) - return -1 - } - } else if configPort != -1 { - // There is no port in the CLI, but one in a config file. Use the config port. - port = configPort - } else { - // There is no port specified, neither in the CLI, nor in the configuration. - port = 443 - } - - username := parsedUrl.User.Username() - if username == "" { - username = parsedUrl.Query().Get("user") - } - if username == "" { - username = configUser - } - if username == "" { - u, err := osuser.Current() - if err == nil { - username = u.Username - } else { - log.Error().Msgf("could not get current username: %s", err) - } - } - if username == "" { - log.Error().Msgf("no username could be found") - return -1 - } - - urlQuery := parsedUrl.Query() - urlQuery.Set("user", username) - parsedUrl.RawQuery = urlQuery.Encode() - requestUrl := parsedUrl.String() - - pool, err := x509.SystemCertPool() + configOptions, err := getConfigOptions(parsedUrl, sshConfig) if err != nil { - log.Fatal().Msgf("%s", err) - } - - tlsConf := &tls.Config{ - RootCAs: pool, - InsecureSkipVerify: *insecure, - KeyLogWriter: keyLog, - NextProtos: []string{http3.NextProtoH3}, - } - - if certs, ok := knownHosts[hostname]; ok { - foundSelfsignedSSH3 := false - - for _, cert := range certs { - pool.AddCert(cert) - if cert.VerifyHostname("selfsigned.ssh3") == nil { - foundSelfsignedSSH3 = true - } - } - - // If no IP SAN was in the cert, then assume the self-signed cert at least matches the .ssh3 TLD - if foundSelfsignedSSH3 { - // Put "ssh3" as ServerName so that the TLS verification can succeed - // Otherwise, TLS refuses to validate a certificate without IP SANs - // if the hostname is an IP address. - tlsConf.ServerName = "selfsigned.ssh3" - } - } - - var qconf quic.Config - - qconf.MaxIncomingStreams = 10 - qconf.Allow0RTT = true - qconf.EnableDatagrams = true - qconf.KeepAlivePeriod = 1 * time.Second - - roundTripper := &http3.RoundTripper{ - TLSClientConfig: tlsConf, - QuicConfig: &qconf, - EnableDatagrams: true, + log.Error().Msgf("Could not apply config to %s: %s", parsedUrl, err) + return -1 } - ctx, _ := context.WithCancelCause(context.Background()) - - defer roundTripper.Close() - - // connect to SSH agent if it exists var agentClient agent.ExtendedAgent - var agentKeys []ssh.PublicKey - socketPath := os.Getenv("SSH_AUTH_SOCK") if socketPath != "" { conn, err := net.Dial("unix", socketPath) @@ -518,136 +443,9 @@ func mainWithStatusCode() int { return -1 } agentClient = agent.NewClient(conn) - keys, err := agentClient.List() - if err != nil { - log.Error().Msgf("Failed to list agent keys: %s", err) - return -1 - } - for _, key := range keys { - agentKeys = append(agentKeys, key) - } } - log.Debug().Msgf("dialing QUIC host at %s", fmt.Sprintf("%s:%d", hostname, port)) - - if hostnameIsAnIP { - ip := net.ParseIP(hostname) - if ip.To4() == nil && ip.To16() != nil { - // enforce the square-bracketed notation for ipv6 UDP addresses - hostname = fmt.Sprintf("[%s]", hostname) - } - } - - qClient, err := quic.DialAddrEarly(ctx, - fmt.Sprintf("%s:%d", hostname, port), - tlsConf, - &qconf) - if err != nil { - if transportErr, ok := err.(*quic.TransportError); ok { - if transportErr.ErrorCode.IsCryptoError() { - log.Debug().Msgf("received QUIC crypto error on first connection attempt: %s", err) - if tty == nil { - log.Error().Msgf("insecure server cert in non-terminal session, aborting") - return -1 - } - if _, ok = knownHosts[hostname]; ok { - log.Error().Msgf("The server certificate cannot be verified using the one installed in %s. "+ - "If you did not change the server certificate, it could be a machine-in-the-middle attack. "+ - "TLS error: %s", knownHostsPath, err) - log.Error().Msgf("Aborting.") - return -1 - } - // bad certificates, let's mimic the OpenSSH's behaviour similar to host keys - tlsConf.InsecureSkipVerify = true - var peerCertificate *x509.Certificate - certError := fmt.Errorf("we don't want to start a totally insecure connection") - tlsConf.VerifyConnection = func(ctx tls.ConnectionState) error { - peerCertificate = ctx.PeerCertificates[0] - return certError - } - - _, err := quic.DialAddrEarly(ctx, - fmt.Sprintf("%s:%d", hostname, port), - tlsConf, - &qconf) - if !errors.Is(err, certError) { - log.Error().Msgf("could not create client QUIC connection: %s", err) - return -1 - } - // let's first check that the certificate is self-signed - if err := peerCertificate.CheckSignatureFrom(peerCertificate); err != nil { - log.Error().Msgf("the peer provided an unknown, insecure certificate, that is not self-signed: %s", err) - return -1 - } - // first, carriage return - _, _ = tty.WriteString("\r") - _, err = tty.WriteString("Received an unknown self-signed certificate from the server.\n\r" + - "We recommend not using self-signed certificates.\n\r" + - "This session is vulnerable a machine-in-the-middle attack.\n\r" + - "Certificate fingerprint: " + - "SHA256 " + util.Sha256Fingerprint(peerCertificate.Raw) + "\n\r" + - "Do you want to add this certificate to ~/.ssh3/known_hosts (yes/no)? ") - if err != nil { - log.Error().Msgf("cound not write on /dev/tty: %s", err) - return -1 - } - - answer := "" - reader := bufio.NewReader(tty) - for { - answer, _ = reader.ReadString('\n') - answer = strings.TrimSpace(answer) - _, _ = tty.WriteString("\r") // always ensure a carriage return - if answer == "yes" || answer == "no" { - break - } - tty.WriteString("Invalid answer, answer \"yes\" or \"no\" ") - } - if answer == "no" { - log.Info().Msg("Connection aborted") - return 0 - } - if err := ssh3.AppendKnownHost(knownHostsPath, hostname, peerCertificate); err != nil { - log.Error().Msgf("could not append known host to %s: %s", knownHostsPath, err) - return -1 - } - tty.WriteString(fmt.Sprintf("Successfully added the certificate to %s, please rerun the command\n\r", knownHostsPath)) - return 0 - } - } - log.Error().Msgf("could not establish client QUIC connection: %s", err) - return -1 - } - - // dirty hack: ensure only one QUIC connection is used - roundTripper.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - return qClient, nil - } - - // Do 0RTT GET requests here if needed - // Currently, we don't need it but we could use it to retrieve - // config or version info from the server - // We could also allow user-defined safe/idempotent commands to run with 0-RTT - qClient.HandshakeComplete() - log.Debug().Msgf("QUIC handshake complete") - // Now, we're 1-RTT, we can get the TLS exporter and create the conversation - tls := qClient.ConnectionState().TLS - conv, err := ssh3.NewClientConversation(30000, 10, &tls) - if err != nil { - log.Error().Msgf("could not create new client conversation: %s", err) - return -1 - } - - // the connection struct is created, now build the request used to establish the connection - req, err := http.NewRequest("CONNECT", requestUrl, nil) - if err != nil { - log.Fatal().Msgf("%s", err) - } - req.Proto = "ssh3" - req.Header.Set("User-Agent", ssh3.GetCurrentVersion()) - var authMethods []interface{} - // Only do privkey and agent auth if OIDC is not asked explicitly if !useOIDC { if *privKeyFile != "" { @@ -670,13 +468,7 @@ func mainWithStatusCode() int { log.Error().Msgf("could not parse public key: %s", err) return -1 } - } - - for _, candidateKey := range agentKeys { - if pubkey == nil || bytes.Equal(candidateKey.Marshal(), pubkey.Marshal()) { - log.Debug().Msgf("found key in agent: %s", candidateKey) - authMethods = append(authMethods, ssh3.NewAgentAuthMethod(candidateKey)) - } + authMethods = append(authMethods, ssh3.NewAgentAuthMethod(pubkey)) } } } @@ -694,12 +486,12 @@ func mainWithStatusCode() int { } } } else { - log.Error().Msgf("OIDC was asked explicitly bit did not find suitable issuer URL") + log.Error().Msgf("OIDC was asked explicitly but did not find suitable issuer URL") return -1 } } - authMethods = append(authMethods, configAuthMethods...) + authMethods = append(authMethods, configOptions.AuthMethods()...) if *issuerUrl == "" { for _, issuerConfig := range oidcConfig { @@ -707,351 +499,58 @@ func mainWithStatusCode() int { } } - var identity ssh3.Identity - for _, method := range authMethods { - switch m := method.(type) { - case *ssh3.PasswordAuthMethod: - fmt.Printf("password for %s:", parsedUrl.String()) - password, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Println() - if err != nil { - log.Error().Msgf("could not get password: %s", err) - return -1 - } - identity = m.IntoIdentity(string(password)) - case *ssh3.PrivkeyFileAuthMethod: - identity, err = m.IntoIdentityWithoutPassphrase() - // could not identify without passphrase, try agent authentication by using the key's public key - if passphraseErr, ok := err.(*ssh.PassphraseMissingError); ok { - // the pubkey may be contained in the privkey file - pubkey := passphraseErr.PublicKey - if pubkey == nil { - // if it is not the case, try to find a .pub equivalent, like OpenSSH does - pubkeyBytes, err := os.ReadFile(fmt.Sprintf("%s.pub", m.Filename())) - if err == nil { - filePubkey, _, _, _, err := ssh.ParseAuthorizedKey(pubkeyBytes) - if err == nil { - pubkey = filePubkey - } - } - } - - // now, try to see of the agent manages this key - foundAgentKey := false - if pubkey != nil { - for _, agentKey := range agentKeys { - if bytes.Equal(agentKey.Marshal(), pubkey.Marshal()) { - log.Debug().Msgf("found key in agent: %s", agentKey) - identity = ssh3.NewAgentAuthMethod(pubkey).IntoIdentity(agentClient) - foundAgentKey = true - break - } - } - } - - // key not handled by agent, let's try to decrypt it ourselves - if !foundAgentKey { - fmt.Printf("passphrase for private key stored in %s:", m.Filename()) - var passphraseBytes []byte - passphraseBytes, err = term.ReadPassword(int(syscall.Stdin)) - fmt.Println() - if err != nil { - log.Error().Msgf("could not get passphrase: %s", err) - return -1 - } - passphrase := string(passphraseBytes) - identity, err = m.IntoIdentityPassphrase(passphrase) - if err != nil { - log.Error().Msgf("could not load private key: %s", err) - return -1 - } - } - } else if err != nil { - log.Warn().Msgf("Could not load private key: %s", err) - } - case *ssh3.AgentAuthMethod: - identity = m.IntoIdentity(agentClient) - case *ssh3.OidcAuthMethod: - token, err := auth.Connect(context.Background(), m.OIDCConfig(), m.OIDCConfig().IssuerUrl, *doPKCE) - if err != nil { - log.Error().Msgf("could not get token: %s", err) - return -1 - } - identity = m.IntoIdentity(token) - } - // currently only tries a single identity (the first one), but the goal is to - // try several identities, similarly to OpenSSH - break - } - - if identity == nil { - log.Error().Msg("no suitable identity found") - return -1 - } - - log.Debug().Msgf("try the following Identity: %s", identity) - err = identity.SetAuthorizationHeader(req, username, conv) + options, err := client.NewOptions(configOptions.Username(), configOptions.Hostname(), configOptions.Port(), configOptions.UrlPath(), authMethods) if err != nil { - log.Error().Msgf("could not set authorization header in HTTP request: %s", err) - } - - log.Debug().Msgf("send CONNECT request to the server") - err = conv.EstablishClientConversation(req, roundTripper) - if errors.Is(err, util.Unauthorized{}) { - log.Error().Msgf("Access denied from the server: unauthorized") - return -1 - } else if err != nil { - log.Error().Msgf("Could not open channel: %+v", err) + log.Error().Msgf("could not instantiate invalid options: %s", err) return -1 } - ctx = conv.Context() + ctx := context.Background() - channel, err := conv.OpenChannel("session", 30000, 0) - if err != nil { - fmt.Fprintf(os.Stderr, "Could not open channel: %+v", err) - os.Exit(-1) - } + qconn, status := setupQUICConnection(ctx, *insecure, keyLog, ssh3Dir, pool, knownHostsPath, knownHosts, oidcConfig, options, nil, tty) - log.Debug().Msgf("opened new session channel") - - if *forwardSSHAgent { - _, err := channel.WriteData([]byte("forward-agent"), ssh3Messages.SSH_EXTENDED_DATA_NONE) - if err != nil { - log.Error().Msgf("could not forward agent: %s", err.Error()) - return -1 + if qconn == nil { + if status != 0 { + log.Error().Msgf("could not setup transport for client: %s", err) } - go func() { - for { - forwardChannel, err := conv.AcceptChannel(ctx) - if err != nil { - if err != context.Canceled { - log.Error().Msgf("could not accept forwarding channel: %s", err.Error()) - } - return - } else if forwardChannel.ChannelType() != "agent-connection" { - log.Error().Msgf("unexpected server-initiated channel: %s", channel.ChannelType()) - return - } - log.Debug().Msg("new agent connection, forwarding") - go func() { - err = forwardAgent(ctx, forwardChannel) - if err != nil { - log.Error().Msgf("agent forwarding error: %s", err.Error()) - conv.Close() - } - }() - } - }() + return status } - if len(command) == 0 { - // avoid requesting a pty on the other side if stdin is not a pty - // similar behaviour to OpenSSH - isATTY := term.IsTerminal(int(os.Stdin.Fd())) - if isATTY { - windowSize, err := winsize.GetWinsize() - if err != nil { - fmt.Fprintf(os.Stderr, "Could not get window size: %+v", err) - os.Exit(-1) - } - err = channel.SendRequest( - &ssh3Messages.ChannelRequestMessage{ - WantReply: true, - ChannelRequest: &ssh3Messages.PtyRequest{ - Term: os.Getenv("TERM"), - CharWidth: uint64(windowSize.NCols), - CharHeight: uint64(windowSize.NRows), - PixelWidth: uint64(windowSize.PixelWidth), - PixelHeight: uint64(windowSize.PixelHeight), - }, - }, - ) - - if err != nil { - fmt.Fprintf(os.Stderr, "Could send pty request: %+v", err) - return -1 - } - log.Debug().Msgf("sent pty request for session") - } - - err = channel.SendRequest( - &ssh3Messages.ChannelRequestMessage{ - WantReply: true, - ChannelRequest: &ssh3Messages.ShellRequest{}, - }, - ) - log.Debug().Msgf("sent shell request") - // avoid making the terminal raw if stdin is not a TTY - // similar behaviour to OpenSSH - if isATTY { - fd := os.Stdin.Fd() - oldState, err := term.MakeRaw(int(fd)) - if err != nil { - log.Fatal().Msgf("%s", err) - } - defer term.Restore(int(fd), oldState) - } - } else { - channel.SendRequest( - &ssh3Messages.ChannelRequestMessage{ - WantReply: true, - ChannelRequest: &ssh3Messages.ExecRequest{ - Command: strings.Join(command, " "), - }, - }, - ) - log.Debug().Msgf("sent exec request for command \"%s\"", strings.Join(command, " ")) + roundTripper := &http3.RoundTripper{ + EnableDatagrams: true, } + c, err := client.Dial(ctx, options, qconn, roundTripper, agentClient) if err != nil { - fmt.Fprintf(os.Stderr, "Could send shell request: %+v", err) + log.Error().Msgf("could not establish SSH3 conversation: %s", err) return -1 } - - go func() { - buf := make([]byte, channel.MaxPacketSize()) - for { - n, err := os.Stdin.Read(buf) - if n > 0 { - _, err2 := channel.WriteData(buf[:n], ssh3Messages.SSH_EXTENDED_DATA_NONE) - if err2 != nil { - fmt.Fprintf(os.Stderr, "could not write data on channel: %+v", err2) - return - } - } - if err != nil { - fmt.Fprintf(os.Stderr, "could not read data from stdin: %+v", err) - return - } - } - }() - - if localUDPAddr != nil && remoteUDPAddr != nil { - log.Debug().Msgf("start forwarding from %s to %s", localUDPAddr, remoteUDPAddr) - conn, err := net.ListenUDP("udp", localUDPAddr) + if localTCPAddr != nil && remoteTCPAddr != nil { + err := c.ForwardTCP(ctx, localTCPAddr, remoteTCPAddr) if err != nil { - log.Error().Msgf("could listen on UDP socket: %s", err) + log.Error().Msgf("could not forward UDP: %s", err) return -1 } - forwardings := make(map[string]ssh3.Channel) - go func() { - buf := make([]byte, 1500) - for { - n, addr, err := conn.ReadFromUDP(buf) - if err != nil { - log.Error().Msgf("could read on UDP socket: %s", err) - return - } - channel, ok := forwardings[addr.String()] - if !ok { - channel, err = conv.OpenUDPForwardingChannel(30000, 10, localUDPAddr, remoteUDPAddr) - if err != nil { - log.Error().Msgf("could open new UDP forwarding channel: %s", err) - return - } - forwardings[addr.String()] = channel - - go func() { - for { - dgram, err := channel.ReceiveDatagram(ctx) - if err != nil { - log.Error().Msgf("could open receive datagram on channel: %s", err) - return - } - _, err = conn.WriteToUDP(dgram, addr) - if err != nil { - log.Error().Msgf("could open write datagram on socket: %s", err) - return - } - } - }() - } - err = channel.SendDatagram(buf[:n]) - if err != nil { - log.Error().Msgf("could not send datagram: %s", err) - return - } - } - }() } - - if localTCPAddr != nil && remoteTCPAddr != nil { - log.Debug().Msgf("start forwarding from %s to %s", localTCPAddr, remoteTCPAddr) - conn, err := net.ListenTCP("tcp", localTCPAddr) + if localUDPAddr != nil && remoteUDPAddr != nil { + err := c.ForwardUDP(ctx, localUDPAddr, remoteUDPAddr) if err != nil { - log.Error().Msgf("could listen on TCP socket: %s", err) + log.Error().Msgf("could not forward UDP: %s", err) return -1 } - go func() { - for { - conn, err := conn.AcceptTCP() - if err != nil { - log.Error().Msgf("could read on UDP socket: %s", err) - return - } - forwardingChannel, err := conv.OpenTCPForwardingChannel(30000, 10, localTCPAddr, remoteTCPAddr) - if err != nil { - log.Error().Msgf("could open new UDP forwarding channel: %s", err) - return - } - forwardTCPInBackground(ctx, forwardingChannel, conn) - } - }() } - defer conv.Close() - defer fmt.Printf("\r") - - for { - genericMessage, err := channel.NextMessage() - if err != nil { - fmt.Fprintf(os.Stderr, "Could not get message: %+v\n", err) - os.Exit(-1) - } - switch message := genericMessage.(type) { - case *ssh3Messages.ChannelRequestMessage: - switch requestMessage := message.ChannelRequest.(type) { - case *ssh3Messages.PtyRequest: - fmt.Fprintf(os.Stderr, "receiving a pty request on the client is not implemented\n") - case *ssh3Messages.X11Request: - fmt.Fprintf(os.Stderr, "receiving a x11 request on the client is not implemented\n") - case *ssh3Messages.ShellRequest: - fmt.Fprintf(os.Stderr, "receiving a shell request on the client is not implemented\n") - case *ssh3Messages.ExecRequest: - fmt.Fprintf(os.Stderr, "receiving a exec request on the client is not implemented\n") - case *ssh3Messages.SubsystemRequest: - fmt.Fprintf(os.Stderr, "receiving a subsystem request on the client is not implemented\n") - case *ssh3Messages.WindowChangeRequest: - fmt.Fprintf(os.Stderr, "receiving a windowchange request on the client is not implemented\n") - case *ssh3Messages.SignalRequest: - fmt.Fprintf(os.Stderr, "receiving a signal request on the client is not implemented\n") - case *ssh3Messages.ExitStatusRequest: - log.Info().Msgf("ssh3: process exited with status: %d\n", requestMessage.ExitStatus) - // forward the process' status code to the user - return int(requestMessage.ExitStatus) - case *ssh3Messages.ExitSignalRequest: - log.Info().Msgf("ssh3: process exited with signal: %s: %s\n", requestMessage.SignalNameWithoutSig, requestMessage.ErrorMessageUTF8) - return -1 - } - case *ssh3Messages.DataOrExtendedDataMessage: - switch message.DataType { - case ssh3Messages.SSH_EXTENDED_DATA_NONE: - _, err = os.Stdout.Write([]byte(message.Data)) - if err != nil { - log.Fatal().Msgf("%s", err) - } - - log.Debug().Msgf("received data %s", message.Data) - case ssh3Messages.SSH_EXTENDED_DATA_STDERR: - _, err = os.Stderr.Write([]byte(message.Data)) - if err != nil { - log.Fatal().Msgf("%s", err) - } - - log.Debug().Msgf("received stderr data %s", message.Data) - } - } + err = c.RunSession(tty, *forwardSSHAgent, command...) + switch sessionError := err.(type) { + case client.ExitStatus: + log.Info().Msgf("the process exited with status %d", sessionError.StatusCode) + return sessionError.StatusCode + case client.ExitSignal: + log.Error().Msgf("the process exited with signal %s: %s", sessionError.Signal, sessionError.ErrorMessageUTF8) + return -1 + default: + log.Error().Msgf("an error was encountered when running the session: %s", sessionError) + return -1 } } diff --git a/cmd/ssh3/winsize/winsize.go b/cmd/ssh3/winsize/winsize.go deleted file mode 100644 index a47edbe..0000000 --- a/cmd/ssh3/winsize/winsize.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build !windows - -package winsize - -import ( - "syscall" - "unsafe" -) - -func GetWinsize() (ws WindowSize, err error) { - _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(syscall.Stdin), uintptr(syscall.TIOCGWINSZ), - uintptr(unsafe.Pointer(&ws))) - if errno != 0 { - err = errno - } - return ws, err -} diff --git a/integration_tests/ssh3_test.go b/integration_tests/ssh3_test.go index bc11538..cb80512 100644 --- a/integration_tests/ssh3_test.go +++ b/integration_tests/ssh3_test.go @@ -111,6 +111,7 @@ var _ = Describe("Testing the ssh3 cli", func() { var clientArgs []string getClientArgs := func(privKeyPath string, additionalArgs ...string) []string { args := []string{ + "-v", "-insecure", "-privkey", privKeyPath, } diff --git a/known_hosts.go b/known_hosts.go index be8d544..d8ffb14 100644 --- a/known_hosts.go +++ b/known_hosts.go @@ -10,6 +10,16 @@ import ( "syscall" ) +type KnownHosts map[string][]*x509.Certificate + +func (kh KnownHosts) Knows(hostname string) bool { + if len(kh) == 0 { + return false + } + _, ok := kh[hostname] + return ok +} + type InvalidKnownHost struct { line string } @@ -18,7 +28,7 @@ func (e InvalidKnownHost) Error() string { return fmt.Sprintf("invalid known host line: %s", e.line) } -func ParseKnownHosts(filename string) (knownHosts map[string][]*x509.Certificate, invalidLines []int, err error) { +func ParseKnownHosts(filename string) (knownHosts KnownHosts, invalidLines []int, err error) { knownHosts = make(map[string][]*x509.Certificate) file, err := os.Open(filename) if os.IsNotExist(err) {