From b95136da24c447c7d766040124e8cdb782746f55 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Wed, 30 Mar 2022 12:29:44 +0200 Subject: [PATCH] ssh: add support for extension negotiation (rfc 8308) This is a rebase of the following PR https://github.com/golang/crypto/pull/197 with some changes and improvements: - added support for client certificate authentication - removed read loop from server handshake - adapted extInfoMsg to upstream changes Signed-off-by: Nicola Murino --- ssh/common.go | 33 +++++++++++++++++++++++++++++++++ ssh/handshake.go | 37 ++++++++++++++++++++++++++++++++++++- ssh/server.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/ssh/common.go b/ssh/common.go index 44f71de3e1..3035060c1c 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -25,6 +25,18 @@ const ( serviceSSH = "ssh-connection" ) +// These are string constants related to extensions and extension negotiation +const ( + extInfoServer = "ext-info-s" + extInfoClient = "ext-info-c" + ExtServerSigAlgs = "server-sig-algs" +) + +// defaultExtensions lists extensions enabled by default. +var defaultExtensions = []string{ + ExtServerSigAlgs, +} + // supportedCiphers lists ciphers we support but might not recommend. var supportedCiphers = []string{ "aes128-ctr", "aes192-ctr", "aes256-ctr", @@ -90,6 +102,15 @@ var supportedMACs = []string{ var supportedCompressions = []string{compressionNone} +// supportedServerSigAlgs defines the algorithms supported for pubkey authentication +// in no particular order. +var supportedServerSigAlgs = []string{KeyAlgoRSASHA256, + KeyAlgoRSASHA512, KeyAlgoRSA, + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519, + KeyAlgoDSA, +} + // hashFuncs keeps the mapping of supported signature algorithms to their // respective hashes needed for signing and verification. var hashFuncs = map[string]crypto.Hash{ @@ -203,6 +224,10 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) if err != nil { return + } else if result.kex == extInfoClient || result.kex == extInfoServer { + // According to RFC8308 section 2.2 if either the client or server extension signal + // is chosen as the kex algorithm the parties must disconnect. + return result, fmt.Errorf("ssh: invalid kex algorithm chosen: %s", result.kex) } result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) @@ -280,6 +305,10 @@ type Config struct { // The allowed MAC algorithms. If unspecified then a sensible default // is used. MACs []string + + // A list of enabled extensions. If unspecified then a sensible + // default is used + Extensions []string } // SetDefaults sets sensible values for unset fields in config. This is @@ -309,6 +338,10 @@ func (c *Config) SetDefaults() { c.MACs = supportedMACs } + if c.Extensions == nil { + c.Extensions = defaultExtensions + } + if c.RekeyThreshold == 0 { // cipher specific default } else if c.RekeyThreshold < minRekeyThreshold { diff --git a/ssh/handshake.go b/ssh/handshake.go index 07a1843e0a..476246cf70 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -11,6 +11,7 @@ import ( "io" "log" "net" + "strings" "sync" ) @@ -95,6 +96,10 @@ type handshakeTransport struct { // The session ID or nil if first kex did not complete yet. sessionID []byte + + // True if the first ext info message has been sent immediately following + // SSH_MSG_NEWKEYS, false otherwise. + extInfoSent bool } type pendingKex struct { @@ -477,6 +482,9 @@ func (t *handshakeTransport) sendKexInit() error { msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) } } + if contains(t.config.Extensions, ExtServerSigAlgs) { + msg.KexAlgos = append(msg.KexAlgos, extInfoServer) + } } else { msg.ServerHostKeyAlgos = t.hostKeyAlgorithms @@ -486,7 +494,7 @@ func (t *handshakeTransport) sendKexInit() error { if firstKeyExchange := t.sessionID == nil; firstKeyExchange { msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1) msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) - msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") + msg.KexAlgos = append(msg.KexAlgos, extInfoClient) } } @@ -663,6 +671,33 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { return unexpectedMessageError(msgNewKeys, packet[0]) } + if !isClient { + // We're on the server side, see if the client sent the extension signal + if !t.extInfoSent && contains(clientInit.KexAlgos, extInfoClient) && contains(t.config.Extensions, ExtServerSigAlgs) { + // The other side supports ext info, an ext info message hasn't been sent this session, + // and we have at least one extension enabled, so send an SSH_MSG_EXT_INFO message. + extensions := map[string][]byte{} + // We're the server, the client supports SSH_MSG_EXT_INFO and server-sig-algs + // is enabled. Prepare the server-sig-algos extension message to send. + extensions[ExtServerSigAlgs] = []byte(strings.Join(supportedServerSigAlgs, ",")) + var payload []byte + for k, v := range extensions { + payload = appendInt(payload, len(k)) + payload = append(payload, k...) + payload = appendInt(payload, len(v)) + payload = append(payload, v...) + } + extInfo := extInfoMsg{ + NumExtensions: uint32(len(extensions)), + Payload: payload, + } + if err := t.conn.writePacket(Marshal(&extInfo)); err != nil { + return err + } + t.extInfoSent = true + } + } + return nil } diff --git a/ssh/server.go b/ssh/server.go index 3d4f176436..1588634efa 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -274,11 +274,41 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) // We just did the key change, so the session ID is established. s.sessionID = s.transport.getSessionID() + // the client could send a SSH_MSG_EXT_INFO before SSH_MSG_SERVICE_REQUEST var packet []byte if packet, err = s.transport.readPacket(); err != nil { return nil, err } + // be permissive and don't add contains(s.transport.config.Extensions, ExtServerSigAlgs) + if len(packet) > 0 && packet[0] == msgExtInfo { + // read SSH_MSG_EXT_INFO + var extInfo extInfoMsg + extensions := make(map[string][]byte) + if err := Unmarshal(packet, &extInfo); err != nil { + return nil, err + } + payload := extInfo.Payload + for i := uint32(0); i < extInfo.NumExtensions; i++ { + name, rest, ok := parseString(payload) + if !ok { + return nil, parseError(msgExtInfo) + } + value, rest, ok := parseString(rest) + if !ok { + return nil, parseError(msgExtInfo) + } + extensions[string(name)] = value + payload = rest + } + + // read the next packet + packet = nil + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + } + var serviceRequest serviceRequestMsg if err = Unmarshal(packet, &serviceRequest); err != nil { return nil, err