diff --git a/ssh/client_auth.go b/ssh/client_auth.go index 409b5ea1d4..a531b4c81e 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -35,23 +35,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error { // RFC 8308, Section 2.4. extensions := make(map[string][]byte) if len(packet) > 0 && packet[0] == msgExtInfo { - var extInfo extInfoMsg - if err := Unmarshal(packet, &extInfo); err != nil { + extensions, err = parseExtInfoMsg(packet) + if err != nil { return err } - payload := extInfo.Payload - for i := uint32(0); i < extInfo.NumExtensions; i++ { - name, rest, ok := parseString(payload) - if !ok { - return parseError(msgExtInfo) - } - value, rest, ok := parseString(rest) - if !ok { - return parseError(msgExtInfo) - } - extensions[string(name)] = value - payload = rest - } packet, err = c.transport.readPacket() if err != nil { return err diff --git a/ssh/common.go b/ssh/common.go index 33a88e1518..39cbbf734e 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -146,6 +146,31 @@ func parseError(tag uint8) error { return fmt.Errorf("ssh: parse error in message type %d", tag) } +// parseExtInfoMsg returns the extensions from an extInfoMsg packet. +// packet must be an already validated extInfoMsg +func parseExtInfoMsg(packet []byte) (map[string][]byte, error) { + extensions := make(map[string][]byte) + var extInfo extInfoMsg + + if err := Unmarshal(packet, &extInfo); err != nil { + return nil, err + } + payload := extInfo.Payload + for i := uint32(0); i < extInfo.NumExtensions; i++ { + name, rest, ok := parseString(payload) + if !ok { + return nil, parseError(msgExtInfo) + } + value, rest, ok := parseString(rest) + if !ok { + return nil, parseError(msgExtInfo) + } + extensions[string(name)] = value + payload = rest + } + return extensions, nil +} + func findCommon(what string, client []string, server []string) (common string, err error) { for _, c := range client { for _, s := range server { diff --git a/ssh/server.go b/ssh/server.go index 700657746f..20c09e58b4 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -265,25 +265,9 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) if len(packet) > 0 && packet[0] == msgExtInfo { // read SSH_MSG_EXT_INFO - var extInfo extInfoMsg - extensions := make(map[string][]byte) - if err := Unmarshal(packet, &extInfo); err != nil { + if _, err := parseExtInfoMsg(packet); err != nil { return nil, err } - payload := extInfo.Payload - for i := uint32(0); i < extInfo.NumExtensions; i++ { - name, rest, ok := parseString(payload) - if !ok { - return nil, parseError(msgExtInfo) - } - value, rest, ok := parseString(rest) - if !ok { - return nil, parseError(msgExtInfo) - } - extensions[string(name)] = value - payload = rest - } - // read the next packet if packet, err = s.transport.readPacket(); err != nil { return nil, err