diff --git a/go/mysql/auth_server.go b/go/mysql/auth_server.go index 5df328e1896..b568537e16e 100644 --- a/go/mysql/auth_server.go +++ b/go/mysql/auth_server.go @@ -19,7 +19,9 @@ package mysql import ( "bytes" "crypto/rand" + "crypto/rsa" "crypto/sha1" + "crypto/sha256" "encoding/hex" "net" "strings" @@ -117,8 +119,8 @@ func NewSalt() ([]byte, error) { return salt, nil } -// ScramblePassword computes the hash of the password using 4.1+ method. -func ScramblePassword(salt, password []byte) []byte { +// ScrambleMysqlNativePassword computes the hash of the password using 4.1+ method. +func ScrambleMysqlNativePassword(salt, password []byte) []byte { if len(password) == 0 { return nil } @@ -189,6 +191,58 @@ func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword s return bytes.Equal(candidateHash2, hash) } +// ScrambleCachingSha2Password computes the hash of the password using SHA256 as required by +// caching_sha2_password plugin for "fast" authentication +func ScrambleCachingSha2Password(salt []byte, password []byte) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA256(password) + crypt := sha256.New() + crypt.Write(password) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA256(SHA256(stage1Hash) + salt) + crypt.Reset() + crypt.Write(stage1) + innerHash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(innerHash) + crypt.Write(salt) + scramble := crypt.Sum(nil) + + // token = stage1Hash XOR scrambleHash + for i := range stage1 { + stage1[i] ^= scramble[i] + } + + return stage1 +} + +// EncryptPasswordWithPublicKey obfuscates the password and encrypts it with server's public key as required by +// caching_sha2_password plugin for "full" authentication +func EncryptPasswordWithPublicKey(salt []byte, password []byte, pub *rsa.PublicKey) ([]byte, error) { + if len(password) == 0 { + return nil, nil + } + + buffer := make([]byte, len(password)+1) + copy(buffer, password) + for i := range buffer { + buffer[i] ^= salt[i%len(salt)] + } + + sha1Hash := sha1.New() + enc, err := rsa.EncryptOAEP(sha1Hash, rand.Reader, pub, buffer, nil) + if err != nil { + return nil, err + } + + return enc, nil +} + // Constants for the dialog plugin. const ( mysqlDialogMessage = "Enter password: " diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index acb90678f54..ef95febb292 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -244,7 +244,7 @@ func (a *AuthServerStatic) ValidateHash(salt []byte, user string, authResponse [ return &StaticUserData{entry.UserData, entry.Groups}, nil } } else { - computedAuthResponse := ScramblePassword(salt, []byte(entry.Password)) + computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) // Validate the password. if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) { return &StaticUserData{entry.UserData, entry.Groups}, nil diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index f195b4c72d1..d5e43f312fb 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -92,7 +92,7 @@ func TestValidateHashGetter(t *testing.T) { t.Fatalf("error generating salt: %v", err) } - scrambled := ScramblePassword(salt, []byte("password")) + scrambled := ScrambleMysqlNativePassword(salt, []byte("password")) getter, err := auth.ValidateHash(salt, "mysql_user", scrambled, addr) if err != nil { t.Fatalf("error validating password: %v", err) @@ -270,7 +270,7 @@ func TestStaticPasswords(t *testing.T) { t.Fatalf("error generating salt: %v", err) } - scrambled := ScramblePassword(salt, []byte(c.password)) + scrambled := ScrambleMysqlNativePassword(salt, []byte(c.password)) _, err = auth.ValidateHash(salt, c.user, scrambled, addr) if c.success { diff --git a/go/mysql/auth_server_vault.go b/go/mysql/auth_server_vault.go index 7f72cd72e19..25fe6611a40 100644 --- a/go/mysql/auth_server_vault.go +++ b/go/mysql/auth_server_vault.go @@ -251,7 +251,7 @@ func (a *AuthServerVault) ValidateHash(salt []byte, user string, authResponse [] return &StaticUserData{entry.UserData, entry.Groups}, nil } } else { - computedAuthResponse := ScramblePassword(salt, []byte(entry.Password)) + computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) // Validate the password. if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) { return &StaticUserData{entry.UserData, entry.Groups}, nil diff --git a/go/mysql/client.go b/go/mysql/client.go index c9fd577b190..fdb75a67076 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -17,7 +17,10 @@ limitations under the License. package mysql import ( + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/pem" "fmt" "net" "strconv" @@ -235,6 +238,7 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { return err } c.fillFlavor(params) + c.salt = salt // Sanity check. if capabilities&CapabilityClientProtocol41 == 0 { @@ -291,7 +295,12 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { } // Password encryption. - scrambledPassword := ScramblePassword(salt, []byte(params.Pass)) + var scrambledPassword []byte + if c.authPluginName == CachingSha2Password { + scrambledPassword = ScrambleCachingSha2Password(salt, []byte(params.Pass)) + } else { + scrambledPassword = ScrambleMysqlNativePassword(salt, []byte(params.Pass)) + } // Client Session Tracking Capability. if params.Flags&CapabilityClientSessionTrack == CapabilityClientSessionTrack { @@ -310,54 +319,8 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { } // Read the server response. - response, err := c.readPacket() - if err != nil { - return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) - } - switch response[0] { - case OKPacket: - // OK packet, we are authenticated. Save the user, keep going. - c.User = params.Uname - case AuthSwitchRequestPacket: - // Server is asking to use a different auth method. We - // only support cleartext plugin. - pluginName, salt, err := parseAuthSwitchRequest(response) - if err != nil { - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err) - } - - if pluginName == MysqlClearPassword { - // Write the cleartext password packet. - if err := c.writeClearTextPassword(params); err != nil { - return err - } - } else if pluginName == MysqlNativePassword { - // Write the mysql_native_password packet. - if err := c.writeMysqlNativePassword(params, salt); err != nil { - return err - } - } else { - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", pluginName) - } - - // Wait for OK packet. - response, err = c.readPacket() - if err != nil { - return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) - } - switch response[0] { - case OKPacket: - // OK packet, we are authenticated. Save the user, keep going. - c.User = params.Uname - case ErrPacket: - return ParseErrorPacket(response) - default: - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) - } - case ErrPacket: - return ParseErrorPacket(response) - default: - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) + if err := c.handleAuthResponse(params); err != nil { + return err } // If the server didn't support DbName in its handshake, set @@ -515,10 +478,7 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) // 5.6.2 that don't have a null terminated string. authPluginName = string(data[pos : len(data)-1]) } - - if authPluginName != MysqlNativePassword { - return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: only support %v auth plugin name, but got %v", MysqlNativePassword, authPluginName) - } + c.authPluginName = authPluginName } return capabilities, authPluginData, nil @@ -603,7 +563,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ lenNullString(params.Uname) + // length of scrambled password is handled below. len(scrambledPassword) + - 21 + // "mysql_native_password" string. + len(c.authPluginName) + 1 // terminating zero. // Add the DB name if the server supports it. @@ -652,7 +612,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ } // Assume native client during response - pos = writeNullString(data, pos, MysqlNativePassword) + pos = writeNullString(data, pos, c.authPluginName) // Sanity-check the length. if pos != len(data) { @@ -665,6 +625,110 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ return nil } +// handleAuthResponse parses server's response after client sends the password for authentication +// and handles next steps for AuthSwitchRequestPacket and AuthMoreDataPacket. +func (c *Conn) handleAuthResponse(params *ConnParams) error { + response, err := c.readPacket() + if err != nil { + return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + } + + switch response[0] { + case OKPacket: + // OK packet, we are authenticated. Save the user, keep going. + c.User = params.Uname + case AuthSwitchRequestPacket: + // Server is asking to use a different auth method + if err = c.handleAuthSwitchPacket(params, response); err != nil { + return err + } + case AuthMoreDataPacket: + // Server is requesting more data - maybe un-scrambled password + if err := c.handleAuthMoreDataPacket(response[1], params); err != nil { + return err + } + case ErrPacket: + return ParseErrorPacket(response) + default: + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) + } + + return nil +} + +// handleAuthSwitchPacket scrambles password for the plugin requested by the server and retries authentication +func (c *Conn) handleAuthSwitchPacket(params *ConnParams, response []byte) error { + var err error + var salt []byte + c.authPluginName, salt, err = parseAuthSwitchRequest(response) + if err != nil { + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err) + } + if salt != nil { + c.salt = salt + } + switch c.authPluginName { + case MysqlClearPassword: + if err := c.writeClearTextPassword(params); err != nil { + return err + } + case MysqlNativePassword: + scrambledPassword := ScrambleMysqlNativePassword(c.salt, []byte(params.Pass)) + if err := c.writeScrambledPassword(scrambledPassword); err != nil { + return err + } + case CachingSha2Password: + scrambledPassword := ScrambleCachingSha2Password(c.salt, []byte(params.Pass)) + if err := c.writeScrambledPassword(scrambledPassword); err != nil { + return err + } + default: + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", c.authPluginName) + } + + // The response could be an OKPacket, AuthMoreDataPacket or ErrPacket + return c.handleAuthResponse(params) +} + +// handleAuthMoreDataPacket handles response of CachingSha2Password authentication and sends full password to the +// server if requested +func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error { + switch data { + case CachingSha2FastAuth: + // User credentials are verified using the cache ("Fast" path). + // Next packet should be an OKPacket + return c.handleAuthResponse(params) + case CachingSha2FullAuth: + // User credentials are not cached, we have to exchange full password. + if c.Capabilities&CapabilityClientSSL > 0 || params.UnixSocket != "" { + // If we are using an SSL connection or Unix socket, write clear text password + if err := c.writeClearTextPassword(params); err != nil { + return err + } + } else { + // If we are not using an SSL connection or Unix socket, we have to fetch a public key + // from the server to encrypt password + pub, err := c.requestPublicKey() + if err != nil { + return err + } + // Encrypt password with public key + enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub) + if err != nil { + return vterrors.Errorf(vtrpc.Code_INTERNAL, "error encrypting password with public key: %v", err) + } + // Write encrypted password + if err := c.writeScrambledPassword(enc); err != nil { + return err + } + } + // Next packet should either be an OKPacket or ErrPacket + return c.handleAuthResponse(params) + default: + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse AuthMoreDataPacket: %v", data) + } +} + func parseAuthSwitchRequest(data []byte) (string, []byte, error) { pos := 1 pluginName, pos, ok := readNullString(data, pos) @@ -680,6 +744,34 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) { return pluginName, salt, nil } +// requestPublicKey requests a public key from the server +func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) { + // get public key from server + data, pos := c.startEphemeralPacketWithHeader(1) + data[pos] = 0x02 + if err := c.writeEphemeralPacket(); err != nil { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "error sending public key request packet: %v", err) + } + + response, err := c.readPacket() + if err != nil { + return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + } + + // Server should respond with a AuthMoreDataPacket containing the public key + if response[0] != AuthMoreDataPacket { + return nil, ParseErrorPacket(response) + } + + block, _ := pem.Decode(response[1:]) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to parse public key from server: %v", err) + } + + return pub.(*rsa.PublicKey), nil +} + // writeClearTextPassword writes the clear text password. // Returns a SQLError. func (c *Conn) writeClearTextPassword(params *ConnParams) error { @@ -693,15 +785,14 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error { return c.writeEphemeralPacket() } -// writeMysqlNativePassword writes the encrypted mysql_native_password format +// writeScrambledPassword writes the encrypted mysql_native_password format // Returns a SQLError. -func (c *Conn) writeMysqlNativePassword(params *ConnParams, salt []byte) error { - scrambledPassword := ScramblePassword(salt, []byte(params.Pass)) +func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error { data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword)) pos += copy(data[pos:], scrambledPassword) // Sanity check. if pos != len(data) { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building MysqlNativePassword packet: got %v bytes expected %v", pos, len(data)) + return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data)) } return c.writeEphemeralPacket() } diff --git a/go/mysql/conn.go b/go/mysql/conn.go index e8edb7cd1f5..346517c6009 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -86,6 +86,13 @@ type Conn struct { // fields, this is set to an empty array (but not nil). fields []*querypb.Field + // salt is sent by the server during initial handshake to be used for authentication + salt []byte + + // authPluginName is the name of server's authentication plugin. + // It is set during the initial handshake. + authPluginName string + // schemaName is the default database name to use. It is set // during handshake, and by ComInitDb packets. Both client and // servers maintain it. This member is private because it's diff --git a/go/mysql/constants.go b/go/mysql/constants.go index a6c1ed7dbf2..a77569ad0f2 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -36,6 +36,9 @@ const ( // MysqlClearPassword transmits the password in the clear. MysqlClearPassword = "mysql_clear_password" + // CachingSha2Password uses a salt and transmits a SHA256 hash on the wire. + CachingSha2Password = "caching_sha2_password" + // MysqlDialog uses the dialog plugin on the client side. // It transmits data in the clear. MysqlDialog = "dialog" @@ -230,9 +233,6 @@ const ( // EOFPacket is the header of the EOF packet. EOFPacket = 0xfe - // AuthSwitchRequestPacket is used to switch auth method. - AuthSwitchRequestPacket = 0xfe - // ErrPacket is the header of the error packet. ErrPacket = 0xff @@ -240,6 +240,21 @@ const ( NullValue = 0xfb ) +// Auth packet types +const ( + // AuthMoreDataPacket is sent when server requires more data to authenticate + AuthMoreDataPacket = 0x01 + + // CachingSha2FastAuth is sent before OKPacket when server authenticates using cache + CachingSha2FastAuth = 0x03 + + // CachingSha2FullAuth is sent when server requests un-scrambled password to authenticate + CachingSha2FullAuth = 0x04 + + // AuthSwitchRequestPacket is used to switch auth method. + AuthSwitchRequestPacket = 0xfe +) + // Error codes for client-side errors. // Originally found in include/mysql/errmsg.h and // https://dev.mysql.com/doc/refman/5.7/en/error-messages-client.html diff --git a/go/mysql/endtoend/client_test.go b/go/mysql/endtoend/client_test.go index ca5637f04cb..feb723c5947 100644 --- a/go/mysql/endtoend/client_test.go +++ b/go/mysql/endtoend/client_test.go @@ -316,3 +316,44 @@ func TestSessionTrackGTIDs(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, qr.SessionStateChanges) } + +func TestCachingSha2Password(t *testing.T) { + ctx := context.Background() + + // connect as an existing user to create a user account with caching_sha2_password + params := connParams + conn, err := mysql.Connect(ctx, ¶ms) + expectNoError(t, err) + defer conn.Close() + + qr, err := conn.ExecuteFetch(`select true from information_schema.PLUGINS where PLUGIN_NAME='caching_sha2_password' and PLUGIN_STATUS='ACTIVE'`, 1, false) + if err != nil { + t.Errorf("select true from information_schema.PLUGINS failed: %v", err) + } + + if len(qr.Rows) != 1 { + t.Skip("Server does not support caching_sha2_password plugin") + } + + // create a user using caching_sha2_password password + if _, err = conn.ExecuteFetch(`create user 'sha2user'@'localhost' identified with caching_sha2_password by 'password';`, 0, false); err != nil { + t.Fatalf("Create user with caching_sha2_password failed: %v", err) + } + conn.Close() + + // connect as sha2user + params.Uname = "sha2user" + params.Pass = "password" + params.DbName = "information_schema" + conn, err = mysql.Connect(ctx, ¶ms) + expectNoError(t, err) + defer conn.Close() + + if qr, err = conn.ExecuteFetch(`select user()`, 1, true); err != nil { + t.Fatalf("select user() failed: %v", err) + } + + if len(qr.Rows) != 1 || qr.Rows[0][0].ToString() != "sha2user@localhost" { + t.Errorf("Logged in user is not sha2user") + } +}