From 36879566a0a1dc474d238f74fe3e3f86b8d8f2a8 Mon Sep 17 00:00:00 2001 From: Kristin Davidson Date: Mon, 1 Nov 2021 20:59:34 +0000 Subject: [PATCH] ssh: add support for extension negotiation (rfc 8308) This adds support for SSH extension negotiation via SSH_MSG_EXT_INFO as defined in RFC 8308 [1]. It also adds support for the `server-sig-algs` extension on both the client and server sides. [1] https://datatracker.ietf.org/doc/html/rfc8308 Fixes golang/go#49269 Change-Id: Iab374d1254ec83eabdb5f433b95ff39a1a540cc3 GitHub-Last-Rev: 29bf9ec043429558b99ff88e27422ccd0873e235 GitHub-Pull-Request: golang/crypto#197 --- ssh/client_auth.go | 47 ++++++- ssh/client_auth_test.go | 128 +++++++++++++++++++ ssh/common.go | 51 ++++++++ ssh/common_test.go | 22 ++++ ssh/handshake.go | 48 +++++++ ssh/handshake_test.go | 269 ++++++++++++++++++++++++++++++++++++++-- ssh/messages.go | 89 +++++++++++++ ssh/messages_test.go | 64 +++++++--- ssh/server.go | 45 +++++-- 9 files changed, 719 insertions(+), 44 deletions(-) diff --git a/ssh/client_auth.go b/ssh/client_auth.go index c611aeb684..e864a139a1 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -25,13 +25,31 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error { if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { return err } - packet, err := c.transport.readPacket() - if err != nil { - return err - } + var serviceAccept serviceAcceptMsg - if err := Unmarshal(packet, &serviceAccept); err != nil { - return err +readAcceptLoop: + for { + packet, err := c.transport.readPacket() + if err != nil { + return err + } + + switch packet[0] { + case msgExtInfo: + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + return err + } + c.transport.extensions = extInfo.Extensions + continue + case msgServiceAccept: + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + break readAcceptLoop + default: + return fmt.Errorf("ssh: unexpected message received") + } } // during the authentication phase the client first attempts the "none" method @@ -337,6 +355,14 @@ func handleAuthResponse(c packetConn) (authResult, []string, error) { } switch packet[0] { + case msgExtInfo: + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + return authFailure, nil, err + } + if transport, ok := c.(*handshakeTransport); ok { + transport.extensions = extInfo.Extensions + } case msgUserAuthBanner: if err := handleBannerResponse(c, packet); err != nil { return authFailure, nil, err @@ -420,6 +446,15 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe // like handleAuthResponse, but with less options. switch packet[0] { + case msgExtInfo: + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + return authFailure, nil, err + } + if transport, ok := c.(*handshakeTransport); ok { + transport.extensions = extInfo.Extensions + } + continue case msgUserAuthBanner: if err := handleBannerResponse(c, packet); err != nil { return authFailure, nil, err diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index 63a8e22487..9d392c21dd 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -117,6 +117,75 @@ func TestClientAuthPublicKey(t *testing.T) { } } +func TestClientAuthPublicKeyExtensions(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsUserAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + serverConfig := &ServerConfig{ + PublicKeyCallback: certChecker.Authenticate, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + go newServer(c1, serverConfig) + clientConn, _, _, err := NewClientConn(c2, "", &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + + conn, ok := clientConn.(*connection) + if !ok { + t.Fatalf("conn is not a *connection") + } + + rawServerSigAlgs, ok := conn.transport.extensions[ExtServerSigAlgs] + if !ok { + t.Fatalf("did not receive server-sig-algs extension") + } + + serverSigAlgs := strings.Split(string(rawServerSigAlgs), ",") + if len(serverSigAlgs) == 0 { + t.Fatalf("did not receive any server-sig-algs") + } + + for _, expectedAlg := range supportedSigAlgs() { + hasAlg := false + for _, receivedAlg := range serverSigAlgs { + if receivedAlg == expectedAlg { + hasAlg = true + break + } + } + if !hasAlg { + t.Errorf("server-sig-algs did not have expected alg: %s", expectedAlg) + } + } +} + func TestAuthMethodPassword(t *testing.T) { config := &ClientConfig{ User: "testuser", @@ -131,6 +200,65 @@ func TestAuthMethodPassword(t *testing.T) { } } +func TestClientAuthPasswordExtensions(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + go newServer(c1, serverConfig) + clientConn, _, _, err := NewClientConn(c2, "", &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + + conn, ok := clientConn.(*connection) + if !ok { + t.Fatalf("conn is not a *connection") + } + + rawServerSigAlgs, ok := conn.transport.extensions[ExtServerSigAlgs] + if !ok { + t.Fatalf("did not receive server-sig-algs extension") + } + + serverSigAlgs := strings.Split(string(rawServerSigAlgs), ",") + if len(serverSigAlgs) == 0 { + t.Fatalf("did not receive any server-sig-algs") + } + + for _, expectedAlg := range supportedSigAlgs() { + hasAlg := false + for _, receivedAlg := range serverSigAlgs { + if receivedAlg == expectedAlg { + hasAlg = true + break + } + } + if !hasAlg { + t.Errorf("server-sig-algs did not have expected alg: %s", expectedAlg) + } + } +} + func TestAuthMethodFallback(t *testing.T) { var passwordCalled bool config := &ClientConfig{ diff --git a/ssh/common.go b/ssh/common.go index 5ae2275744..4a20113315 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -24,6 +24,22 @@ const ( serviceSSH = "ssh-connection" ) +// These are string constants related to extensions and extension negotiation +const ( + extInfoServer = "ext-info-s" + extInfoClient = "ext-info-c" + + ExtServerSigAlgs = "server-sig-algs" + // extDelayCompression = "delay-compression" + // extNoFlowControl = "no-flow-control" + // extElevation = "elevation" +) + +// defaultExtensions lists extensions enabled by default. +var defaultExtensions = []string{ + ExtServerSigAlgs, +} + // supportedCiphers lists ciphers we support but might not recommend. var supportedCiphers = []string{ "aes128-ctr", "aes192-ctr", "aes256-ctr", @@ -108,6 +124,18 @@ var hashFuncs = map[string]crypto.Hash{ CertAlgoECDSA521v01: crypto.SHA512, } +// supportedSigAlgs returns a slice of algorithms supported for pubkey authentication +// in no particular order. +func supportedSigAlgs() []string { + // TODO(kxd) I'm not sure if hashFuncs is the best place to get this set but it seemed + // like a sensible first step. Should this be a curated list? + var serverSigAlgs []string + for k := range hashFuncs { + serverSigAlgs = append(serverSigAlgs, k) + } + return serverSigAlgs +} + // unexpectedMessageError results when the SSH message that we received didn't // match what we wanted. func unexpectedMessageError(expected, got uint8) error { @@ -130,6 +158,16 @@ func findCommon(what string, client []string, server []string) (common string, e return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) } +// hasString returns true if string "a" is in slice of strings "x", false otherwise. +func hasString(a string, x []string) bool { + for _, s := range x { + if a == s { + return true + } + } + return false +} + // directionAlgorithms records algorithm choices in one direction (either read or write) type directionAlgorithms struct { Cipher string @@ -165,6 +203,11 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) if err != nil { return + } else if result.kex == extInfoClient || result.kex == extInfoServer { + // According to RFC8308 section 2.2 if either the client or server extension signal + // is chosen as the kex algorithm the parties must disconnect. + // chosen + return result, fmt.Errorf("ssh: invalid kex algorithm chosen") } result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) @@ -238,6 +281,10 @@ type Config struct { // The allowed MAC algorithms. If unspecified then a sensible default // is used. MACs []string + + // A list of enabled extensions. If unspecified then a sensible + // default is used + Extensions []string } // SetDefaults sets sensible values for unset fields in config. This is @@ -267,6 +314,10 @@ func (c *Config) SetDefaults() { c.MACs = supportedMACs } + if c.Extensions == nil { + c.Extensions = defaultExtensions + } + if c.RekeyThreshold == 0 { // cipher specific default } else if c.RekeyThreshold < minRekeyThreshold { diff --git a/ssh/common_test.go b/ssh/common_test.go index 96744dcf0f..00e1fb5c92 100644 --- a/ssh/common_test.go +++ b/ssh/common_test.go @@ -110,6 +110,28 @@ func TestFindAgreedAlgorithms(t *testing.T) { wantErr: true, }, + testcase{ + name: "server ext info kex chosen", + serverIn: kexInitMsg{ + KexAlgos: []string{extInfoServer}, + }, + clientIn: kexInitMsg{ + KexAlgos: []string{extInfoServer}, + }, + wantErr: true, + }, + + testcase{ + name: "client ext info kex chosen", + serverIn: kexInitMsg{ + KexAlgos: []string{extInfoClient}, + }, + clientIn: kexInitMsg{ + KexAlgos: []string{extInfoClient}, + }, + wantErr: true, + }, + testcase{ name: "client decides cipher", serverIn: kexInitMsg{ diff --git a/ssh/handshake.go b/ssh/handshake.go index 05ad49c364..2b51168760 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -11,6 +11,7 @@ import ( "io" "log" "net" + "strings" "sync" ) @@ -94,6 +95,16 @@ type handshakeTransport struct { // The session ID or nil if first kex did not complete yet. sessionID []byte + + // True if the other side has signaled support for extensions. + extInfoSupported bool + // True if the first ext info message has been sent immediately following + // SSH_MSG_NEWKEYS, false otherwise. + extInfoSent bool + + // extensions is a map of extensions received from the other side as + // part of extension negotiation. + extensions map[string][]byte } type pendingKex struct { @@ -109,6 +120,7 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, incoming: make(chan []byte, chanSize), requestKex: make(chan struct{}, 1), startKex: make(chan *pendingKex, 1), + extensions: map[string][]byte{}, config: config, } @@ -467,8 +479,14 @@ func (t *handshakeTransport) sendKexInit() error { msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo) } } + if hasString(ExtServerSigAlgs, t.config.Extensions) { + msg.KexAlgos = append(msg.KexAlgos, extInfoServer) + } } else { msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + if hasString(ExtServerSigAlgs, t.config.Extensions) { + msg.KexAlgos = append(msg.KexAlgos, extInfoClient) + } } packet := Marshal(msg) @@ -615,6 +633,36 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { return unexpectedMessageError(msgNewKeys, packet[0]) } + // Handle the ext-info-X extension signal from RFC8308 section 2. + // First, see if the other side supports us sending SSH_MSG_EXT_INFO to them. + if isClient { + // We're on the client side, see if the server sent the extension signal + t.extInfoSupported = hasString(extInfoServer, serverInit.KexAlgos) + } else { + // We're on the server side, see if the client sent the extension signal + t.extInfoSupported = hasString(extInfoClient, clientInit.KexAlgos) + } + if t.extInfoSupported && !t.extInfoSent && len(t.config.Extensions) > 0 { + // The other side supports ext info, an ext info message hasn't been sent this session, + // and we have at least one extension enabled, so send an SSH_MSG_EXT_INFO message. + extensions := map[string][]byte{} + if !isClient { + if hasString(ExtServerSigAlgs, t.config.Extensions) { + // We're the server, the client supports SSH_MSG_EXT_INFO and server-sig-algs + // is enabled. Prepare the server-sig-algos extension message to send. + extensions[ExtServerSigAlgs] = []byte(strings.Join(supportedSigAlgs(), ",")) + } + } + extInfo := extInfoMsg{ + NumExtensions: uint32(len(extensions)), + Extensions: extensions, + } + if err := t.conn.writePacket(Marshal(&extInfo)); err != nil { + return err + } + t.extInfoSent = true + } + return nil } diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index 02fbe83821..1d44ca637c 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -88,7 +88,7 @@ func addNoiseTransport(t keyingTransport) keyingTransport { // handshakePair creates two handshakeTransports connected with each // other. If the noise argument is true, both transports will try to // confuse the other side by sending ignore and debug messages. -func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { +func handshakePair(serverConf *ServerConfig, clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { a, b, err := netPipe() if err != nil { return nil, nil, err @@ -107,7 +107,6 @@ func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *h v := []byte("version") client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) - serverConf := &ServerConfig{} serverConf.AddHostKey(testSigners["ecdsa"]) serverConf.AddHostKey(testSigners["rsa"]) serverConf.SetDefaults() @@ -132,9 +131,15 @@ func TestHandshakeBasic(t *testing.T) { waitCall: make(chan int, 10), called: make(chan int, 10), } - checker.waitCall <- 1 - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + + serverConf := &ServerConfig{} + serverConf.Extensions = []string{} + + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.Extensions = []string{} + + trC, trS, err := handshakePair(serverConf, clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -260,9 +265,12 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) { called: make(chan int, 10), waitCall: nil, } + serverConf := &ServerConfig{} + serverConf.Extensions = []string{} clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.Extensions = []string{} clientConf.RekeyThreshold = 500 - trC, trS, err := handshakePair(clientConf, "addr", false) + trC, trS, err := handshakePair(serverConf, clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -325,12 +333,13 @@ func TestHandshakeAutoRekeyRead(t *testing.T) { called: make(chan int, 2), waitCall: nil, } + serverConf := &ServerConfig{} clientConf := &ClientConfig{ HostKeyCallback: sync.Check, } clientConf.RekeyThreshold = 500 - trC, trS, err := handshakePair(clientConf, "addr", false) + trC, trS, err := handshakePair(serverConf, clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -496,7 +505,16 @@ func TestDisconnect(t *testing.T) { t.Skip("see golang.org/issue/7237") } checker := &testChecker{} - trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + // TODO(kxd) Would it make more sense for this test to disable extensions? Then there + // wouldn't be a need to check for the ext-info packet explicitly in the test. I've + // waffled on this because then it's not testing the case of a disconnect right after + // connecting with extensions, but maybe that's not needed/not the intent of the test + // anyway? + serverConf := &ServerConfig{} + clientConf := &ClientConfig{ + HostKeyCallback: checker.Check, + } + trC, trS, err := handshakePair(serverConf, clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -504,6 +522,14 @@ func TestDisconnect(t *testing.T) { defer trC.Close() defer trS.Close() + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgExtInfo { + t.Errorf("got packet %v, want packet type %d", packet, msgExtInfo) + } + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) errMsg := &disconnectMsg{ Reason: 42, @@ -512,9 +538,9 @@ func TestDisconnect(t *testing.T) { trC.writePacket(Marshal(errMsg)) trC.writePacket([]byte{msgRequestSuccess, 0, 0}) - packet, err := trS.readPacket() + packet, err = trS.readPacket() if err != nil { - t.Fatalf("readPacket 1: %v", err) + t.Fatalf("readPacket 2: %v", err) } if packet[0] != msgRequestSuccess { t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) @@ -522,25 +548,26 @@ func TestDisconnect(t *testing.T) { _, err = trS.readPacket() if err == nil { - t.Errorf("readPacket 2 succeeded") + t.Errorf("readPacket 3 succeeded") } else if !reflect.DeepEqual(err, errMsg) { t.Errorf("got error %#v, want %#v", err, errMsg) } _, err = trS.readPacket() if err == nil { - t.Errorf("readPacket 3 succeeded") + t.Errorf("readPacket 4 succeeded") } } func TestHandshakeRekeyDefault(t *testing.T) { + serverConf := &ServerConfig{} clientConf := &ClientConfig{ Config: Config{ Ciphers: []string{"aes128-ctr"}, }, HostKeyCallback: InsecureIgnoreHostKey(), } - trC, trS, err := handshakePair(clientConf, "addr", false) + trC, trS, err := handshakePair(serverConf, clientConf, "addr", false) if err != nil { t.Fatalf("handshakePair: %v", err) } @@ -560,3 +587,221 @@ func TestHandshakeRekeyDefault(t *testing.T) { t.Errorf("got rekey after %dG write, want 64G", wgb) } } + +func TestHandshakeExtInfoSignal(t *testing.T) { + t.Run("no ext info signal from server when disabled", func(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + + v := []byte("version") + + serverConf := &ServerConfig{} + serverConf.SetDefaults() + serverConf.Extensions = []string{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + server := newServerTransport(trS, v, v, serverConf) + + defer trC.Close() + defer server.Close() + + data, err := trC.readPacket() + if err != nil { + t.Fatalf("could not read client packet: %v", err) + } + + if data[0] != msgKexInit { + t.Fatalf("client didn't receive kexinit") + } + var serverInit kexInitMsg + if err = Unmarshal(data, &serverInit); err != nil { + t.Errorf("could not unmarshal kexinit: %v", err) + } + + if hasString(extInfoServer, serverInit.KexAlgos) { + t.Errorf("server sent ext info signal when extension was disabled") + } + }) + t.Run("ext info signal from server when enabled", func(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + + v := []byte("version") + + serverConf := &ServerConfig{} + serverConf.SetDefaults() + serverConf.Extensions = []string{ExtServerSigAlgs} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + server := newServerTransport(trS, v, v, serverConf) + + defer trC.Close() + defer server.Close() + + data, err := trC.readPacket() + if err != nil { + t.Fatalf("could not read client packet: %v", err) + } + + if data[0] != msgKexInit { + t.Fatalf("client didn't receive kexinit") + } + var serverInit kexInitMsg + if err = Unmarshal(data, &serverInit); err != nil { + t.Errorf("could not unmarshal kexinit: %v", err) + } + + if !hasString(extInfoServer, serverInit.KexAlgos) { + t.Errorf("server did not send ext info signal") + } + }) + t.Run("no ext info signal from client when disabled", func(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + + v := []byte("version") + + clientConf := &ClientConfig{} + clientConf.SetDefaults() + clientConf.Extensions = []string{} + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + defer client.Close() + defer trS.Close() + + data, err := trS.readPacket() + if err != nil { + t.Fatalf("could not read server packet: %v", err) + } + + if data[0] != msgKexInit { + t.Fatalf("server didn't receive kexinit") + } + var clientInit kexInitMsg + if err = Unmarshal(data, &clientInit); err != nil { + t.Errorf("could not unmarshal kexinit: %v", err) + } + + if hasString(extInfoClient, clientInit.KexAlgos) { + t.Errorf("client sent ext info signal when extension was disabled") + } + }) + t.Run("ext info signal from client when enabled", func(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + + v := []byte("version") + + clientConf := &ClientConfig{} + clientConf.SetDefaults() + clientConf.Extensions = []string{ExtServerSigAlgs} + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + defer client.Close() + defer trS.Close() + + data, err := trS.readPacket() + if err != nil { + t.Fatalf("could not read server packet: %v", err) + } + + if data[0] != msgKexInit { + t.Fatalf("server didn't receive kexinit") + } + var clientInit kexInitMsg + if err = Unmarshal(data, &clientInit); err != nil { + t.Errorf("could not unmarshal kexinit: %v", err) + } + + if !hasString(extInfoClient, clientInit.KexAlgos) { + t.Errorf("client did not send ext info signal") + } + }) +} + +func TestHandshakeExtInfoAfterNewkeys(t *testing.T) { + serverConf := &ServerConfig{ + Config: Config{ + Extensions: []string{ + ExtServerSigAlgs, + }, + }, + } + clientConf := &ClientConfig{ + Config: Config{ + Extensions: []string{ + ExtServerSigAlgs, + }, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + trC, trS, err := handshakePair(serverConf, clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet, err := trC.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + t.Fatalf("unable to unmarshal packet: %v", err) + } + + if rawServerAlgs, ok := extInfo.Extensions[ExtServerSigAlgs]; ok { + expectedAlgs := supportedSigAlgs() + serverAlgs := strings.Split(string(rawServerAlgs), ",") + + if len(expectedAlgs) != len(serverAlgs) { + t.Errorf("expected %d server-sig-algs but received %d", + len(expectedAlgs), len(serverAlgs)) + } + + for _, expected := range expectedAlgs { + foundAlg := false + for _, found := range serverAlgs { + if expected == found { + foundAlg = true + break + } + } + if !foundAlg { + t.Errorf("could not find expected alg '%s' in server-sig-algs", expected) + } + } + } else { + t.Error("could not find expected server-sig-algs extension") + } +} diff --git a/ssh/messages.go b/ssh/messages.go index ac41a4168b..ecb2d81240 100644 --- a/ssh/messages.go +++ b/ssh/messages.go @@ -141,6 +141,15 @@ type serviceAcceptMsg struct { Service string `sshtype:"6"` } +// See RFC 4253, section 10. +const msgExtInfo = 7 + +// See RFC 8308, section 2.3. +type extInfoMsg struct { + NumExtensions uint32 `sshtype:"7"` + Extensions map[string][]byte `sshlen:"NumExtensions"` +} + // See RFC 4252, section 5. const msgUserAuthRequest = 50 @@ -467,6 +476,56 @@ func Unmarshal(data []byte, out interface{}) error { default: return fieldError(structType, i, "slice of unsupported type") } + case reflect.Map: + if t.Key().Kind() != reflect.String { + return fieldError(structType, i, "map of unsupported key type") + } + + lenTagVal := structType.Field(i).Tag.Get("sshlen") + if lenTagVal == "" { + return fieldError(structType, i, "unable to determine data length") + } + lenVal := v.FieldByName(lenTagVal) + + mapLen := 0 + switch lenVal.Kind() { + case reflect.Uint32: + mapLen = int(lenVal.Uint()) + default: + return fieldError(structType, i, "len field of unsupported type") + } + + newMap := reflect.MakeMapWithSize(t, mapLen) + for j := 0; j < mapLen; j++ { + var k, v []byte + if k, data, ok = parseString(data); !ok { + return errShortRead + } + + switch t.Elem().Kind() { + case reflect.String: + if v, data, ok = parseString(data); !ok { + return errShortRead + } + + newMap.SetMapIndex(reflect.ValueOf(string(k)), reflect.ValueOf(string(v))) + case reflect.Slice: + sliceElem := t.Elem() + switch sliceElem.Elem().Kind() { + case reflect.Uint8: + if v, data, ok = parseString(data); !ok { + return errShortRead + } + newMap.SetMapIndex(reflect.ValueOf(string(k)), reflect.ValueOf(v)) + default: + return fieldError(structType, i, "map of unsupported element type") + } + default: + return fieldError(structType, i, "map of unsupported element type") + } + } + + field.Set(newMap) case reflect.Ptr: if t == bigIntType { var n *big.Int @@ -556,6 +615,33 @@ func marshalStruct(out []byte, msg interface{}) []byte { default: panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) } + case reflect.Map: + for _, k := range field.MapKeys() { + switch k.Kind() { + case reflect.String: + out = appendInt(out, len(k.String())) + out = append(out, k.String()...) + + v := field.MapIndex(k) + switch v.Kind() { + case reflect.String: + out = appendInt(out, len(v.String())) + out = append(out, v.String()...) + case reflect.Slice: + switch v.Type().Elem().Kind() { + case reflect.Uint8: + out = appendInt(out, len(v.Bytes())) + out = append(out, v.Bytes()...) + default: + panic(fmt.Sprintf("map of unknown element type in field %d: %T", i, field.Interface())) + } + default: + panic(fmt.Sprintf("map of unknown element type in field %d: %T", i, field.Interface())) + } + default: + panic(fmt.Sprintf("map of unknown key type in field %d: %T", i, field.Interface())) + } + } case reflect.Ptr: if t == bigIntType { var n *big.Int @@ -782,6 +868,8 @@ func decode(packet []byte) (interface{}, error) { msg = new(serviceRequestMsg) case msgServiceAccept: msg = new(serviceAcceptMsg) + case msgExtInfo: + msg = new(extInfoMsg) case msgKexInit: msg = new(kexInitMsg) case msgKexDHInit: @@ -843,6 +931,7 @@ var packetTypeNames = map[byte]string{ msgDisconnect: "disconnectMsg", msgServiceRequest: "serviceRequestMsg", msgServiceAccept: "serviceAcceptMsg", + msgExtInfo: "extInfoMsg", msgKexInit: "kexInitMsg", msgKexDHInit: "kexDHInitMsg", msgKexDHReply: "kexDHReplyMsg", diff --git a/ssh/messages_test.go b/ssh/messages_test.go index e79076412a..600470c942 100644 --- a/ssh/messages_test.go +++ b/ssh/messages_test.go @@ -34,16 +34,20 @@ func TestIntLength(t *testing.T) { } type msgAllTypes struct { - Bool bool `sshtype:"21"` - Array [16]byte - Uint64 uint64 - Uint32 uint32 - Uint8 uint8 - String string - Strings []string - Bytes []byte - Int *big.Int - Rest []byte `ssh:"rest"` + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + StringMapLen uint32 + StringMap map[string]string `sshlen:"StringMapLen"` + Bytes []byte + BytesMapLen uint32 + BytesMap map[string][]byte `sshlen:"BytesMapLen"` + Int *big.Int + Rest []byte `ssh:"rest"` } func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { @@ -53,8 +57,12 @@ func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) m.Uint8 = uint8(rand.Intn(1 << 8)) - m.String = string(m.Array[:]) + m.String = randomString(rand) m.Strings = randomNameList(rand) + m.StringMap = randomStringMap(rand) + m.StringMapLen = uint32(len(m.StringMap)) + m.BytesMap = randomBytesMap(rand) + m.BytesMapLen = uint32(len(m.BytesMap)) m.Bytes = m.Array[:] m.Int = randomInt(rand) m.Rest = m.Array[:] @@ -212,15 +220,39 @@ func randomBytes(out []byte, rand *rand.Rand) { } } +func randomString(rand *rand.Rand) string { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + return string(s) +} + func randomNameList(rand *rand.Rand) []string { ret := make([]string, rand.Int31()&15) for i := range ret { - s := make([]byte, 1+(rand.Int31()&15)) - for j := range s { - s[j] = 'a' + uint8(rand.Int31()&15) - } - ret[i] = string(s) + ret[i] = randomString(rand) + } + return ret +} + +func randomStringMap(rand *rand.Rand) map[string]string { + mapSize := rand.Int31() & 15 + ret := make(map[string]string, mapSize) + for i := 0; i < int(mapSize); i++ { + ret[randomString(rand)] = randomString(rand) + } + return ret +} +func randomBytesMap(rand *rand.Rand) map[string][]byte { + mapSize := rand.Int31() & 15 + ret := make(map[string][]byte, mapSize) + for i := 0; i < int(mapSize); i++ { + mapData := make([]byte, rand.Int31()&15) + randomBytes(mapData, rand) + ret[randomString(rand)] = mapData } + return ret } diff --git a/ssh/server.go b/ssh/server.go index 6a58e12089..6d97a0a23e 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -255,18 +255,35 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) // We just did the key change, so the session ID is established. s.sessionID = s.transport.getSessionID() - var packet []byte - if packet, err = s.transport.readPacket(); err != nil { - return nil, err - } - var serviceRequest serviceRequestMsg - if err = Unmarshal(packet, &serviceRequest); err != nil { - return nil, err - } - if serviceRequest.Service != serviceUserAuth { - return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") +readServiceRequestLoop: + for { + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + switch packet[0] { + case msgExtInfo: + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + return nil, err + } + s.transport.extensions = extInfo.Extensions + continue + case msgServiceRequest: + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + break readServiceRequestLoop + default: + return nil, fmt.Errorf("ssh: unexpected message received") + } } + serviceAccept := serviceAcceptMsg{ Service: serviceUserAuth, } @@ -658,6 +675,14 @@ userAuthLoop: } } + // TODO(kxd) This is where the second ext info message was being sent as per + // RFC8308 section 2.4, but OpenSSH 8.8's client doesn't seem to like it. It + // thinks it's an unexpected packet for some reason. It seems like all the + // server implementations I've been able to find so far don't actually send + // this message, so it's probably fine to not have it. There aren't currently + // any extension values you'd want to hide from anonymous users at this point + // anyway, as far as I'm aware. + if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { return nil, err }