diff --git a/ssh/client_auth.go b/ssh/client_auth.go index 409b5ea1d4..a531b4c81e 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -35,23 +35,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error { // RFC 8308, Section 2.4. extensions := make(map[string][]byte) if len(packet) > 0 && packet[0] == msgExtInfo { - var extInfo extInfoMsg - if err := Unmarshal(packet, &extInfo); err != nil { + extensions, err = parseExtInfoMsg(packet) + if err != nil { return err } - payload := extInfo.Payload - for i := uint32(0); i < extInfo.NumExtensions; i++ { - name, rest, ok := parseString(payload) - if !ok { - return parseError(msgExtInfo) - } - value, rest, ok := parseString(rest) - if !ok { - return parseError(msgExtInfo) - } - extensions[string(name)] = value - payload = rest - } packet, err = c.transport.readPacket() if err != nil { return err diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index a6bedbfcaa..35b62e3311 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -132,9 +132,7 @@ func TestClientAuthPublicKey(t *testing.T) { if err := tryAuth(t, config); err != nil { t.Fatalf("unable to dial remote side: %s", err) } - // Once the server implements the server-sig-algs extension, this will turn - // into KeyAlgoRSASHA256. - if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSA { + if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 { t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used) } } diff --git a/ssh/common.go b/ssh/common.go index 2a47a61ded..ba781a5303 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -24,6 +24,14 @@ const ( serviceSSH = "ssh-connection" ) +// These are string constants related to extensions and extension negotiation. +// See RFC 8308 +const ( + extInfoServer = "ext-info-s" + extInfoClient = "ext-info-c" + extServerSigAlgs = "server-sig-algs" +) + // supportedCiphers lists ciphers we support but might not recommend. var supportedCiphers = []string{ "aes128-ctr", "aes192-ctr", "aes256-ctr", @@ -89,6 +97,15 @@ var supportedMACs = []string{ var supportedCompressions = []string{compressionNone} +// supportedServerSigAlgs defines the algorithms supported for pubkey authentication. +// Order should not matter, but to avoid any issues we use the same order as OpenSSH. +// See RFC 8308, Section 3.1. +var supportedServerSigAlgs = []string{KeyAlgoED25519, KeyAlgoSKED25519, + KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512, + KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoSKECDSA256, +} + // hashFuncs keeps the mapping of supported signature algorithms to their // respective hashes needed for signing and verification. var hashFuncs = map[string]crypto.Hash{ @@ -129,6 +146,31 @@ func parseError(tag uint8) error { return fmt.Errorf("ssh: parse error in message type %d", tag) } +// parseExtInfoMsg returns the extensions from an extInfoMsg packet. +// packet must be an already validated extInfoMsg +func parseExtInfoMsg(packet []byte) (map[string][]byte, error) { + extensions := make(map[string][]byte) + var extInfo extInfoMsg + + if err := Unmarshal(packet, &extInfo); err != nil { + return nil, err + } + payload := extInfo.Payload + for i := uint32(0); i < extInfo.NumExtensions; i++ { + name, rest, ok := parseString(payload) + if !ok { + return nil, parseError(msgExtInfo) + } + value, rest, ok := parseString(rest) + if !ok { + return nil, parseError(msgExtInfo) + } + extensions[string(name)] = value + payload = rest + } + return extensions, nil +} + func findCommon(what string, client []string, server []string) (common string, err error) { for _, c := range client { for _, s := range server { @@ -180,6 +222,10 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) if err != nil { return + } else if result.kex == extInfoClient || result.kex == extInfoServer { + // According to RFC8308 section 2.2 if either the client or server extension signal + // is chosen as the kex algorithm the parties must disconnect. + return result, fmt.Errorf("ssh: invalid kex algorithm chosen: %s", result.kex) } result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) diff --git a/ssh/handshake.go b/ssh/handshake.go index 653dc4d2cf..694b853d60 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -11,6 +11,7 @@ import ( "io" "log" "net" + "strings" "sync" ) @@ -456,6 +457,7 @@ func (t *handshakeTransport) sendKexInit() error { io.ReadFull(rand.Reader, msg.Cookie[:]) isServer := len(t.hostKeys) > 0 + firstKeyExchange := t.sessionID == nil if isServer { for _, k := range t.hostKeys { // If k is an AlgorithmSigner, presume it supports all signature algorithms @@ -474,16 +476,24 @@ func (t *handshakeTransport) sendKexInit() error { msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) } } + // As a server we add ext-info-s to the KEX algorithms to indicate that we support + // the Extension Negotiation Mechanism. The ext-info-s indicator must be added only + // in the first key exchange. See RFC 8308, Section 2.1. + if firstKeyExchange { + msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1) + msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) + msg.KexAlgos = append(msg.KexAlgos, extInfoServer) + } } else { msg.ServerHostKeyAlgos = t.hostKeyAlgorithms // As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what // algorithms the server supports for public key authentication. See RFC // 8308, Section 2.1. - if firstKeyExchange := t.sessionID == nil; firstKeyExchange { + if firstKeyExchange { msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1) msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) - msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") + msg.KexAlgos = append(msg.KexAlgos, extInfoClient) } } @@ -615,7 +625,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { return err } - if t.sessionID == nil { + firstKeyExchange := t.sessionID == nil + if firstKeyExchange { t.sessionID = result.H } result.SessionID = t.sessionID @@ -632,6 +643,29 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { return unexpectedMessageError(msgNewKeys, packet[0]) } + if !isClient { + // We're on the server side, if this is the first key exchange + // and the client sent the ext-info-c indicator, we send an SSH_MSG_EXT_INFO + // message with the server-sig-algs extension. See RFC 8308, Section 3.1. + if firstKeyExchange && contains(clientInit.KexAlgos, extInfoClient) { + extensions := map[string][]byte{} + extensions[extServerSigAlgs] = []byte(strings.Join(supportedServerSigAlgs, ",")) + + extInfo := &extInfoMsg{ + NumExtensions: uint32(len(extensions)), + } + for k, v := range extensions { + extInfo.Payload = appendInt(extInfo.Payload, len(k)) + extInfo.Payload = append(extInfo.Payload, k...) + extInfo.Payload = appendInt(extInfo.Payload, len(v)) + extInfo.Payload = append(extInfo.Payload, v...) + } + if err := t.conn.writePacket(Marshal(extInfo)); err != nil { + return err + } + } + } + return nil } diff --git a/ssh/server.go b/ssh/server.go index 70045bdfd8..20c09e58b4 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -256,11 +256,24 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) // We just did the key change, so the session ID is established. s.sessionID = s.transport.getSessionID() + // the client could send a SSH_MSG_EXT_INFO after the first SSH_MSG_NEWKEYS + // and so before SSH_MSG_SERVICE_REQUEST. See RFC 8308, Section 2.4. var packet []byte if packet, err = s.transport.readPacket(); err != nil { return nil, err } + if len(packet) > 0 && packet[0] == msgExtInfo { + // read SSH_MSG_EXT_INFO + if _, err := parseExtInfoMsg(packet); err != nil { + return nil, err + } + // read the next packet + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + } + var serviceRequest serviceRequestMsg if err = Unmarshal(packet, &serviceRequest); err != nil { return nil, err @@ -286,7 +299,8 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) func isAcceptableAlgo(algo string) bool { switch algo { case KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519, - CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01, + CertAlgoRSASHA256v01, CertAlgoRSASHA512v01: return true } return false