Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssh: add support for extension negotiation (rfc 8308) #3

Merged
merged 8 commits into from
Aug 23, 2022
Merged
17 changes: 2 additions & 15 deletions ssh/client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
// RFC 8308, Section 2.4.
extensions := make(map[string][]byte)
if len(packet) > 0 && packet[0] == msgExtInfo {
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
extensions, err = parseExtInfoMsg(packet)
if err != nil {
return err
}
payload := extInfo.Payload
for i := uint32(0); i < extInfo.NumExtensions; i++ {
name, rest, ok := parseString(payload)
if !ok {
return parseError(msgExtInfo)
}
value, rest, ok := parseString(rest)
if !ok {
return parseError(msgExtInfo)
}
extensions[string(name)] = value
payload = rest
}
packet, err = c.transport.readPacket()
if err != nil {
return err
Expand Down
4 changes: 1 addition & 3 deletions ssh/client_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ func TestClientAuthPublicKey(t *testing.T) {
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
// Once the server implements the server-sig-algs extension, this will turn
// into KeyAlgoRSASHA256.
if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSA {
if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 {
t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used)
}
}
Expand Down
46 changes: 46 additions & 0 deletions ssh/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ const (
serviceSSH = "ssh-connection"
)

// These are string constants related to extensions and extension negotiation.
// See RFC 8308
const (
extInfoServer = "ext-info-s"
extInfoClient = "ext-info-c"
extServerSigAlgs = "server-sig-algs"
)

// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
Expand Down Expand Up @@ -89,6 +97,15 @@ var supportedMACs = []string{

var supportedCompressions = []string{compressionNone}

// supportedServerSigAlgs defines the algorithms supported for pubkey authentication.
// Order should not matter, but to avoid any issues we use the same order as OpenSSH.
// See RFC 8308, Section 3.1.
var supportedServerSigAlgs = []string{KeyAlgoED25519, KeyAlgoSKED25519,
KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512,
KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoSKECDSA256,
}

// hashFuncs keeps the mapping of supported signature algorithms to their
// respective hashes needed for signing and verification.
var hashFuncs = map[string]crypto.Hash{
Expand Down Expand Up @@ -129,6 +146,31 @@ func parseError(tag uint8) error {
return fmt.Errorf("ssh: parse error in message type %d", tag)
}

// parseExtInfoMsg returns the extensions from an extInfoMsg packet.
// packet must be an already validated extInfoMsg
func parseExtInfoMsg(packet []byte) (map[string][]byte, error) {
extensions := make(map[string][]byte)
var extInfo extInfoMsg

if err := Unmarshal(packet, &extInfo); err != nil {
return nil, err
}
payload := extInfo.Payload
for i := uint32(0); i < extInfo.NumExtensions; i++ {
name, rest, ok := parseString(payload)
if !ok {
return nil, parseError(msgExtInfo)
}
value, rest, ok := parseString(rest)
if !ok {
return nil, parseError(msgExtInfo)
}
extensions[string(name)] = value
payload = rest
}
return extensions, nil
}

func findCommon(what string, client []string, server []string) (common string, err error) {
for _, c := range client {
for _, s := range server {
Expand Down Expand Up @@ -180,6 +222,10 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if err != nil {
return
} else if result.kex == extInfoClient || result.kex == extInfoServer {
// According to RFC8308 section 2.2 if either the client or server extension signal
// is chosen as the kex algorithm the parties must disconnect.
return result, fmt.Errorf("ssh: invalid kex algorithm chosen: %s", result.kex)
}

result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
Expand Down
40 changes: 37 additions & 3 deletions ssh/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"log"
"net"
"strings"
"sync"
)

Expand Down Expand Up @@ -456,6 +457,7 @@ func (t *handshakeTransport) sendKexInit() error {
io.ReadFull(rand.Reader, msg.Cookie[:])

isServer := len(t.hostKeys) > 0
firstKeyExchange := t.sessionID == nil
if isServer {
for _, k := range t.hostKeys {
// If k is an AlgorithmSigner, presume it supports all signature algorithms
Expand All @@ -474,16 +476,24 @@ func (t *handshakeTransport) sendKexInit() error {
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
}
}
// As a server we add ext-info-s to the KEX algorithms to indicate that we support
// the Extension Negotiation Mechanism. The ext-info-s indicator must be added only
// in the first key exchange. See RFC 8308, Section 2.1.
if firstKeyExchange {
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
msg.KexAlgos = append(msg.KexAlgos, extInfoServer)
}
} else {
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms

// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
// algorithms the server supports for public key authentication. See RFC
// 8308, Section 2.1.
if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
if firstKeyExchange {
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
msg.KexAlgos = append(msg.KexAlgos, extInfoClient)
}
}

Expand Down Expand Up @@ -615,7 +625,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return err
}

if t.sessionID == nil {
firstKeyExchange := t.sessionID == nil
if firstKeyExchange {
t.sessionID = result.H
}
result.SessionID = t.sessionID
Expand All @@ -632,6 +643,29 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return unexpectedMessageError(msgNewKeys, packet[0])
}

if !isClient {
// We're on the server side, if this is the first key exchange
// and the client sent the ext-info-c indicator, we send an SSH_MSG_EXT_INFO
// message with the server-sig-algs extension. See RFC 8308, Section 3.1.
if firstKeyExchange && contains(clientInit.KexAlgos, extInfoClient) {
extensions := map[string][]byte{}
extensions[extServerSigAlgs] = []byte(strings.Join(supportedServerSigAlgs, ","))

extInfo := &extInfoMsg{
NumExtensions: uint32(len(extensions)),
}
for k, v := range extensions {
extInfo.Payload = appendInt(extInfo.Payload, len(k))
extInfo.Payload = append(extInfo.Payload, k...)
extInfo.Payload = appendInt(extInfo.Payload, len(v))
extInfo.Payload = append(extInfo.Payload, v...)
}
if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
return err
}
}
}

return nil
}

Expand Down
16 changes: 15 additions & 1 deletion ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,24 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
// We just did the key change, so the session ID is established.
s.sessionID = s.transport.getSessionID()

// the client could send a SSH_MSG_EXT_INFO after the first SSH_MSG_NEWKEYS
// and so before SSH_MSG_SERVICE_REQUEST. See RFC 8308, Section 2.4.
var packet []byte
if packet, err = s.transport.readPacket(); err != nil {
return nil, err
}

if len(packet) > 0 && packet[0] == msgExtInfo {
// read SSH_MSG_EXT_INFO
if _, err := parseExtInfoMsg(packet); err != nil {
return nil, err
}
// read the next packet
if packet, err = s.transport.readPacket(); err != nil {
return nil, err
}
}

var serviceRequest serviceRequestMsg
if err = Unmarshal(packet, &serviceRequest); err != nil {
return nil, err
Expand All @@ -304,7 +317,8 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01,
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01:
return true
}
return false
Expand Down