From d158fb7b26005f2c0c14db77a4ef757a0c6475fe Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 20 Jul 2021 16:11:00 +0200 Subject: [PATCH 1/3] Refactor authentication server plugin mechanism This change implements a large refactor of the plugin authentication mechanism. This refactor needed for a number of reasons: - The current AuthServer interface is not well designed. Many of the AuthServer implementations implement at least one if not multiple methods as a panic() as they are unused or they implement something never really called (like the salt generation). This indicates that the current interface doesn't abstract things well. - Adding `caching_sha2_password` as a server side mechanism would make this interface problem worse. It would add even more methods that would have emply implementations / panic(). - It is impossible right now for a plugin to support multiple authentication mechanisms. The caching_sha2_password plugin allows for caching of passwords and could be used to improve for example the LDAP plugin. But we can't then still keep the `mysql_clear_password` as a fallback for such an improvement in the LDAP plugin. Based on these issues, I'm proposing this refactor. It allows for a number of things that also make auth server implementations easier. - An auth server can implement multiple auth methods. This allows for the above mentioned optimization for the LDAP plugin and many other improvements. Note that these improvements are not implemented here, but it opens up the opportunity. - There are helpers for all the already supported authentication mechanisms plus basic support for caching_sha2_password. The last only supports TLS connections or Unix sockets right now, since the full authentication protocol is much complicated. The value of the latter is also debatable in the context of Vitess. It depends on a TLS certificate and private key also for the plain text full auth cycle and in that case, TLS would already be available as well anyway. For clients that want caching_sha2_password, switching those to enable TLS seems like a much simpler solution. - With the helpers, specific interfaces are provided that plugins can implement. Those interfaces provide a simpler way than the complex Negotiate from before which might need to read packets etc. For these helpers, all reading / writing of the protocol is already dealt with and not something auth server plugin writers need to be concerned with. Given the above advantages, I think the current setup is a better design and better fits what we'd want to use for authentication plugins. Signed-off-by: Dirkjan Bussink --- go/mysql/auth_server.go | 568 +++++++++++++++----- go/mysql/auth_server_clientcert.go | 76 +-- go/mysql/auth_server_clientcert_test.go | 8 +- go/mysql/auth_server_none.go | 44 +- go/mysql/auth_server_static.go | 199 ++++--- go/mysql/auth_server_static_test.go | 8 +- go/mysql/auth_server_test.go | 55 ++ go/mysql/auth_server_vault.go | 134 ++--- go/mysql/client.go | 8 +- go/mysql/conn.go | 13 +- go/mysql/constants.go | 12 +- go/mysql/fakesqldb/server.go | 2 +- go/mysql/handshake_test.go | 6 +- go/mysql/ldapauthserver/auth_server_ldap.go | 54 +- go/mysql/query_benchmark_test.go | 2 +- go/mysql/server.go | 155 +++--- go/mysql/server_test.go | 123 ++++- 17 files changed, 1008 insertions(+), 459 deletions(-) create mode 100644 go/mysql/auth_server_test.go diff --git a/go/mysql/auth_server.go b/go/mysql/auth_server.go index b568537e16e..300a531203e 100644 --- a/go/mysql/auth_server.go +++ b/go/mysql/auth_server.go @@ -17,14 +17,14 @@ limitations under the License. package mysql import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/subtle" + "crypto/x509" "encoding/hex" "net" - "strings" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/proto/vtrpc" @@ -32,94 +32,217 @@ import ( ) // AuthServer is the interface that servers must implement to validate -// users and passwords. It has two modes: +// users and passwords. It needs to be able to return a list of AuthMethod +// interfaces which implement the supported auth methods for the server. +type AuthServer interface { + // AuthMethods returns a list of auth methods that are part of this + // interface. Building an authentication server usually means + // creating AuthMethod instances with the known helpers for the + // currently supported AuthMethod implementations. + // + // When a client connects, the server checks the list of auth methods + // available. If an auth method for the requested auth mechanism by the + // client is available, it will be used immediately. + // + // If there is no overlap between the provided auth methods and + // the one the client requests, the server will send back an + // auth switch request using the first provided AuthMethod in this list. + AuthMethods() []AuthMethod + + // Default auth method description that the auth server sends during + // the initial server handshake. This needs to be either `mysql_native_password` + // or `caching_sha2_password` as those are the only supported auth + // methods during the initial handshake. + // + // It's not needed to also support those methods in the AuthMethods(), + // in fact, if you want to only support for example clear text passwords, + // you must still return `mysql_native_password` or `caching_sha2_password` + // here and the auth switch protocol will be used to switch to clear text. + DefaultAuthMethodDescription() AuthMethodDescription +} + +// AuthMethod interface for concrete auth method implementations. +// When building an auth server, you usually don't implement these yourself +// but the helper methods to build AuthMethod instances should be used. +type AuthMethod interface { + // Name returns the auth method description for this implementation. + // This is the name that is sent as the auth plugin name during the + // Mysql authentication protocol handshake. + Name() AuthMethodDescription + + // HandleUser verifies if the current auth method can authenticate + // the given user with the current auth method. This can be useful + // for example if you only have a plain text of hashed password + // for specific users and not all users and auth method support + // depends on what you have. + HandleUser(conn *Conn, user string) bool + + // AuthPluginData generates the information for the auth plugin. + // This is included in for example the auth switch request. This + // is auth plugin specific and opaque to the Mysql handshake + // protocol. + AuthPluginData() ([]byte, error) + + // HandleAuthPluginData handles the returned auth plugin data from + // the client. The original data the server sent is also included + // which can include things like the salt for `mysql_native_password`. + // + // The remote address is provided for plugins that also want to + // do additional checks like IP based restrictions. + HandleAuthPluginData(conn *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) +} + +// UserValidator is an interface that allows checking if a specific +// user will work for an auth method. This interface is called by +// all the default helpers that create AuthMethod instances for +// the various supported Mysql authentication methods. +type UserValidator interface { + HandleUser(user string) bool +} + +// CacheState is a state that is returned by the UserEntryWithCacheHash +// method from the CachingStorage interface. This state is needed to indicate +// whether the authentication is accepted, rejected by the cache itself +// or if the cache can't fullfill the request. In that case it indicates +// that with AuthNeedMoreData. +type CacheState int + +const ( + // AuthRejected is used when the cache knows the request can be rejected. + AuthRejected CacheState = iota + // AuthAccepted is used when the cache knows the request can be accepted. + AuthAccepted + // AuthNeedMoreData is used when the cache doesn't know the answer and more data is needed. + AuthNeedMoreData +) + +// HashStorage describes an object that is suitable to retrieve user information +// based on the hashed authentication response for mysql_native_password. // -// 1. using salt the way MySQL native auth does it. In that case, the -// password is not sent in the clear, but the salt is used to hash the -// password both on the client and server side, and the result is sent -// and compared. +// In general, an implementation of this would use an internally stored password +// that is hashed twice with SHA1. // -// 2. sending the user / password in the clear (using MySQL Cleartext -// method). The server then gets access to both user and password, and -// can authenticate using any method. If SSL is not used, it means the -// password is sent in the clear. That may not be suitable for some -// use cases. -type AuthServer interface { - // AuthMethod returns the authentication method to use for the - // given user. If this returns MysqlNativePassword - // (mysql_native_password), then ValidateHash() will be - // called, and no further roundtrip with the client is - // expected. If anything else is returned, Negotiate() - // will be called on the connection, and the AuthServer - // needs to handle the packets. - AuthMethod(user string) (string, error) - - // Salt returns the salt to use for a connection. - // It should be 20 bytes of data. - // Most implementations should just use mysql.NewSalt(). - // (this is meant to support a plugin that would use an - // existing MySQL server as the source of auth, and just forward - // the salt generated by that server). - // Do not return zero bytes, as a known salt can be the source - // of a crypto attack. - Salt() ([]byte, error) - - // ValidateHash validates the data sent by the client matches - // what the server computes. It also returns the user data. - ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) - - // Negotiate is called if AuthMethod returns anything else - // than MysqlNativePassword. It is handed the connection after the - // AuthSwitchRequest packet is sent. - // - If the negotiation fails, it should just return an error - // (should be a SQLError if possible). - // The framework is responsible for writing the Error packet - // and closing the connection in that case. - // - If the negotiation works, it should return the Getter, - // and no error. The framework is responsible for writing the - // OK packet. - Negotiate(c *Conn, user string, remoteAddr net.Addr) (Getter, error) +// The VerifyHashedMysqlNativePassword helper method can be used to verify +// such a hash based on the salt and auth response provided here after retrieving +// the hashed password from the storage. +type HashStorage interface { + UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) } -// authServers is a registry of AuthServer implementations. -var authServers = make(map[string]AuthServer) +// PlainTextStorage describes an object that is suitable to retrieve user information +// based on the plain text password of a user. This can be obtained through various +// Mysql authentication methods, such as `mysql_clear_passwrd`, `dialog` or +// `caching_sha2_password` in the full authentication handshake case of the latter. +// +// This mechanism also would allow for picking your own password storage in the backend, +// such as BCrypt, SCrypt, PBKDF2 or Argon2 once the plain text is obtained. +// +// When comparing plain text passwords directly, please ensure to use `subtle.ConstantTimeCompare` +// to prevent timing based attacks on the password. +type PlainTextStorage interface { + UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (Getter, error) +} -// RegisterAuthServerImpl registers an implementations of AuthServer. -func RegisterAuthServerImpl(name string, authServer AuthServer) { - if _, ok := authServers[name]; ok { - log.Fatalf("AuthServer named %v already exists", name) +// CachingStorage describes an object that is suitable to retrieve user information +// based on a hashed value of the password. This applies to the `caching_sha2_password` +// authentication method. +// +// The cache would hash the password internally as `SHA256(SHA256(password))`. +// +// The VerifyHashedCachingSha2Password helper method can be used to verify +// such a hash based on the salt and auth response provided here after retrieving +// the hashed password from the cache. +type CachingStorage interface { + UserEntryWithCacheHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) +} + +// NewMysqlNativeAuthMethod will create a new AuthMethod that implements the +// `mysql_native_password` handshake. The caller will need to provide a storage +// object and validator that will be called during the handshake phase. +func NewMysqlNativeAuthMethod(layer HashStorage, validator UserValidator) AuthMethod { + authMethod := mysqlNativePasswordAuthMethod{ + storage: layer, + validator: validator, } - authServers[name] = authServer + return &authMethod } -// GetAuthServer returns an AuthServer by name, or log.Exitf. -func GetAuthServer(name string) AuthServer { - authServer, ok := authServers[name] - if !ok { - log.Exitf("no AuthServer name %v registered", name) +// NewMysqlClearAuthMethod will create a new AuthMethod that implements the +// `mysql_clear_password` handshake. The caller will need to provide a storage +// object for plain text passwords and validator that will be called during the +// handshake phase. +func NewMysqlClearAuthMethod(layer PlainTextStorage, validator UserValidator) AuthMethod { + authMethod := mysqlClearAuthMethod{ + storage: layer, + validator: validator, } - return authServer + return &authMethod } -// NewSalt returns a 20 character salt. -func NewSalt() ([]byte, error) { - salt := make([]byte, 20) - if _, err := rand.Read(salt); err != nil { - return nil, err +// Constants for the dialog plugin. +const ( + // Default message if no custom message + // is configured. This is used when the message + // is the empty string. + mysqlDialogDefaultMessage = "Enter password: " + + // Dialog plugin is similar to clear text, but can respond to multiple + // prompts in a row. This is not yet implemented. + // Follow questions should be prepended with a `cmd` byte: + // 0x02 - ordinary question + // 0x03 - last question + // 0x04 - password question + // 0x05 - last password + mysqlDialogAskPassword = 0x04 +) + +// NewMysqlDialogAuthMethod will create a new AuthMethod that implements the +// `dialog` handshake. The caller will need to provide a storage object for plain +// text passwords and validator that will be called during the handshake phase. +// The message given will be sent as part of the dialog. If the empty string is +// provided, the default message of "Enter password: " will be used. +func NewMysqlDialogAuthMethod(layer PlainTextStorage, validator UserValidator, msg string) AuthMethod { + if msg == "" { + msg = mysqlDialogDefaultMessage } - // Salt must be a legal UTF8 string. - for i := 0; i < len(salt); i++ { - salt[i] &= 0x7f - if salt[i] == '\x00' || salt[i] == '$' { - salt[i]++ - } + authMethod := mysqlDialogAuthMethod{ + storage: layer, + validator: validator, + msg: msg, } + return &authMethod +} - return salt, nil +// NewSha2CachingAuthMethod will create a new AuthMethod that implements the +// `caching_sha2_password` handshake. The caller will need to provide a cache +// object for the fast auth path and a plain text storage object that will +// be called if the return of the first layer indicates the full auth dance is +// needed. +// +// Right now we only support caching_sha2_password over TLS or a Unix socket. +// +// If TLS is not enabled, the client needs to encrypt it with the public +// key of the server. In that case, Vitess is already configured with +// a certificate anyway, so we recommend to use TLS if you want to use +// caching_sha2_password in that case instead of allowing the plain +// text fallback path here. +// +// This might change in the future if there's a good argument and implementation +// for allowing the plain text path here as well. +func NewSha2CachingAuthMethod(layer1 CachingStorage, layer2 PlainTextStorage, validator UserValidator) AuthMethod { + authMethod := mysqlCachingSha2AuthMethod{ + cache: layer1, + storage: layer2, + validator: validator, + } + return &authMethod } // ScrambleMysqlNativePassword computes the hash of the password using 4.1+ method. +// +// This can be used for example inside a `mysql_native_password` plugin implementation +// if the backend storage implements storage of plain text passwords. func ScrambleMysqlNativePassword(salt, password []byte) []byte { if len(password) == 0 { return nil @@ -148,34 +271,63 @@ func ScrambleMysqlNativePassword(salt, password []byte) []byte { return scramble } -func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword string) bool { - /* - SERVER: recv(reply) - hash_stage1=xor(reply, sha1(salt,hash)) - candidate_hash2=sha1(hash_stage1) - check(candidate_hash2==hash) - */ - if len(reply) == 0 { - return false +// DecodeMysqlNativePasswordHex decodes the standard format used by MySQL +// for 4.1 style password hashes. It drops the optionally leading * before +// decoding the rest as a hex encoded string. +func DecodeMysqlNativePasswordHex(hexEncodedPassword string) ([]byte, error) { + if hexEncodedPassword[0] == '*' { + hexEncodedPassword = hexEncodedPassword[1:] } + return hex.DecodeString(hexEncodedPassword) +} - if mysqlNativePassword == "" { +// VerifyHashedMysqlNativePassword verifies a client reply against a stored hash. +// +// This can be used for example inside a `mysql_native_password` plugin implementation +// if the backend storage where the stored password is a SHA1(SHA1(password)). +// +// All values here are non encoded byte slices, so if you store for example the double +// SHA1 of the password as hex encoded characters, you need to decode that first. +// See DecodeMysqlNativePasswordHex for a decoding helper for the standard encoding +// format of this hash used by MySQL. +func VerifyHashedMysqlNativePassword(reply, salt, hashedNativePassword []byte) bool { + if len(reply) == 0 || len(hashedNativePassword) == 0 { return false } - if strings.Contains(mysqlNativePassword, "*") { - mysqlNativePassword = mysqlNativePassword[1:] + // scramble = SHA1(salt+hash) + crypt := sha1.New() + crypt.Write(salt) + crypt.Write(hashedNativePassword) + scramble := crypt.Sum(nil) + + for i := range scramble { + scramble[i] ^= reply[i] } + hashStage1 := scramble - hash, err := hex.DecodeString(mysqlNativePassword) - if err != nil { + crypt.Reset() + crypt.Write(hashStage1) + candidateHash2 := crypt.Sum(nil) + + return subtle.ConstantTimeCompare(candidateHash2, hashedNativePassword) == 1 +} + +// VerifyHashedCachingSha2Password verifies a client reply against a stored hash. +// +// This can be used for example inside a `caching_sha2_password` plugin implementation +// if the cache storage uses password keys with SHA256(SHA256(password)). +// +// All values here are non encoded byte slices, so if you store for example the double +// SHA256 of the password as hex encoded characters, you need to decode that first. +func VerifyHashedCachingSha2Password(reply, salt, hashedCachingSha2Password []byte) bool { + if len(reply) == 0 || len(hashedCachingSha2Password) == 0 { return false } - // scramble = SHA1(salt+hash) - crypt := sha1.New() + crypt := sha256.New() + crypt.Write(hashedCachingSha2Password) crypt.Write(salt) - crypt.Write(hash) scramble := crypt.Sum(nil) // token = scramble XOR stage1Hash @@ -188,7 +340,7 @@ func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword s crypt.Write(hashStage1) candidateHash2 := crypt.Sum(nil) - return bytes.Equal(candidateHash2, hash) + return subtle.ConstantTimeCompare(candidateHash2, hashedCachingSha2Password) == 1 } // ScrambleCachingSha2Password computes the hash of the password using SHA256 as required by @@ -243,34 +395,207 @@ func EncryptPasswordWithPublicKey(salt []byte, password []byte, pub *rsa.PublicK return enc, nil } -// Constants for the dialog plugin. -const ( - mysqlDialogMessage = "Enter password: " +type mysqlNativePasswordAuthMethod struct { + storage HashStorage + validator UserValidator +} - // Dialog plugin is similar to clear text, but can respond to multiple - // prompts in a row. This is not yet implemented. - // Follow questions should be prepended with a `cmd` byte: - // 0x02 - ordinary question - // 0x03 - last question - // 0x04 - password question - // 0x05 - last password - mysqlDialogAskPassword = 0x04 -) +func (n *mysqlNativePasswordAuthMethod) Name() AuthMethodDescription { + return MysqlNativePassword +} -// authServerDialogSwitchData is a helper method to return the data -// needed in the AuthSwitchRequest packet for the dialog plugin -// to ask for a password. -func authServerDialogSwitchData() []byte { - result := make([]byte, len(mysqlDialogMessage)+2) +func (n *mysqlNativePasswordAuthMethod) HandleUser(conn *Conn, user string) bool { + return n.validator.HandleUser(user) +} + +func (n *mysqlNativePasswordAuthMethod) AuthPluginData() ([]byte, error) { + salt, err := newSalt() + if err != nil { + return nil, err + } + return append(salt, 0), nil +} + +func (n *mysqlNativePasswordAuthMethod) HandleAuthPluginData(conn *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) { + if serverAuthPluginData[len(serverAuthPluginData)-1] != 0x00 { + return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + salt := serverAuthPluginData[:len(serverAuthPluginData)-1] + return n.storage.UserEntryWithHash(conn.GetTLSClientCerts(), salt, user, clientAuthPluginData, remoteAddr) +} + +type mysqlClearAuthMethod struct { + storage PlainTextStorage + validator UserValidator +} + +func (n *mysqlClearAuthMethod) Name() AuthMethodDescription { + return MysqlClearPassword +} + +func (n *mysqlClearAuthMethod) HandleUser(conn *Conn, user string) bool { + return n.validator.HandleUser(user) +} + +func (n *mysqlClearAuthMethod) AuthPluginData() ([]byte, error) { + return nil, nil +} + +func (n *mysqlClearAuthMethod) HandleAuthPluginData(conn *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) { + password := "" + if len(clientAuthPluginData) > 0 { + password = string(clientAuthPluginData[:len(clientAuthPluginData)-1]) + } + + return n.storage.UserEntryWithPassword(conn.GetTLSClientCerts(), user, password, remoteAddr) +} + +type mysqlDialogAuthMethod struct { + storage PlainTextStorage + validator UserValidator + msg string +} + +func (n *mysqlDialogAuthMethod) Name() AuthMethodDescription { + return MysqlDialog +} + +func (n *mysqlDialogAuthMethod) HandleUser(conn *Conn, user string) bool { + return n.validator.HandleUser(user) +} + +func (n *mysqlDialogAuthMethod) AuthPluginData() ([]byte, error) { + result := make([]byte, len(n.msg)+2) result[0] = mysqlDialogAskPassword - writeNullString(result, 1, mysqlDialogMessage) - return result + writeNullString(result, 1, n.msg) + return result, nil +} + +func (n *mysqlDialogAuthMethod) HandleAuthPluginData(conn *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) { + return n.storage.UserEntryWithPassword(conn.GetTLSClientCerts(), user, string(clientAuthPluginData[:len(clientAuthPluginData)-1]), remoteAddr) +} + +type mysqlCachingSha2AuthMethod struct { + cache CachingStorage + storage PlainTextStorage + validator UserValidator +} + +func (n *mysqlCachingSha2AuthMethod) Name() AuthMethodDescription { + return CachingSha2Password +} + +func (n *mysqlCachingSha2AuthMethod) HandleUser(conn *Conn, user string) bool { + if !conn.TLSEnabled() && !conn.IsUnixSocket() { + return false + } + return n.validator.HandleUser(user) +} + +func (n *mysqlCachingSha2AuthMethod) AuthPluginData() ([]byte, error) { + salt, err := newSalt() + if err != nil { + return nil, err + } + return append(salt, 0), nil +} + +func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) { + if serverAuthPluginData[len(serverAuthPluginData)-1] != 0x00 { + return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + salt := serverAuthPluginData[:len(serverAuthPluginData)-1] + result, cacheState, err := n.cache.UserEntryWithCacheHash(c.GetTLSClientCerts(), salt, user, clientAuthPluginData, remoteAddr) + + if err != nil { + return nil, err + } + + if cacheState == AuthRejected { + return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + // If we get a result back from the cache that's valid, we can return that immediately. + if cacheState == AuthAccepted { + // We need to write a more data packet to indicate the + // handshake completed properly. This will be followed + // by a regular OK packet which the caller of this method will send. + data, pos := c.startEphemeralPacketWithHeader(2) + pos = writeByte(data, pos, AuthMoreDataPacket) + _ = writeByte(data, pos, CachingSha2FastAuth) + err = c.writeEphemeralPacket() + if err != nil { + return nil, err + } + return result, nil + } + + if !c.TLSEnabled() && !c.IsUnixSocket() { + return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + data, pos := c.startEphemeralPacketWithHeader(1) + pos = writeByte(data, pos, AuthMoreDataPacket) + writeByte(data, pos, CachingSha2FullAuth) + c.writeEphemeralPacket() + + password, err := readPacketPasswordString(c) + if err != nil { + return nil, err + } + + return n.storage.UserEntryWithPassword(c.GetTLSClientCerts(), user, password, remoteAddr) +} + +// authServers is a registry of AuthServer implementations. +var authServers = make(map[string]AuthServer) + +// RegisterAuthServer registers an implementations of AuthServer. +func RegisterAuthServer(name string, authServer AuthServer) { + if _, ok := authServers[name]; ok { + log.Fatalf("AuthServer named %v already exists", name) + } + authServers[name] = authServer +} + +// GetAuthServer returns an AuthServer by name, or log.Exitf. +func GetAuthServer(name string) AuthServer { + authServer, ok := authServers[name] + if !ok { + log.Exitf("no AuthServer name %v registered", name) + } + return authServer } -// AuthServerReadPacketString is a helper method to read a packet -// as a null terminated string. It is used by the mysql_clear_password -// and dialog plugins. -func AuthServerReadPacketString(c *Conn) (string, error) { +func newSalt() ([]byte, error) { + salt := make([]byte, 20) + if _, err := rand.Read(salt); err != nil { + return nil, err + } + + // Salt must be a legal UTF8 string. + for i := 0; i < len(salt); i++ { + salt[i] &= 0x7f + if salt[i] == '\x00' || salt[i] == '$' { + salt[i]++ + } + } + + return salt, nil +} + +func negotiateAuthMethod(conn *Conn, as AuthServer, user string, requestedAuth AuthMethodDescription) (AuthMethod, error) { + for _, m := range as.AuthMethods() { + if m.Name() == requestedAuth && m.HandleUser(conn, user) { + return m, nil + } + } + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "unknown auth method requested: %s", string(requestedAuth)) +} + +func readPacketPasswordString(c *Conn) (string, error) { // Read a packet, the password is the payload, as a // zero terminated string. data, err := c.ReadPacket() @@ -282,20 +607,3 @@ func AuthServerReadPacketString(c *Conn) (string, error) { } return string(data[:len(data)-1]), nil } - -// AuthServerNegotiateClearOrDialog will finish a negotiation based on -// the method type for the connection. Only supports -// MysqlClearPassword and MysqlDialog. -func AuthServerNegotiateClearOrDialog(c *Conn, method string) (string, error) { - switch method { - case MysqlClearPassword: - // The password is the next packet in plain text. - return AuthServerReadPacketString(c) - - case MysqlDialog: - return AuthServerReadPacketString(c) - - default: - return "", vterrors.Errorf(vtrpc.Code_INTERNAL, "unrecognized method: %v", method) - } -} diff --git a/go/mysql/auth_server_clientcert.go b/go/mysql/auth_server_clientcert.go index e860248efb4..04f6bb97986 100644 --- a/go/mysql/auth_server_clientcert.go +++ b/go/mysql/auth_server_clientcert.go @@ -17,6 +17,7 @@ limitations under the License. package mysql import ( + "crypto/x509" "flag" "fmt" "net" @@ -24,58 +25,71 @@ import ( "vitess.io/vitess/go/vt/log" ) -var clientcertAuthMethod = flag.String("mysql_clientcert_auth_method", MysqlClearPassword, "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") +var clientcertAuthMethod = flag.String("mysql_clientcert_auth_method", string(MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") +// AuthServerClientCert implements AuthServer which enforces client side certificates type AuthServerClientCert struct { - Method string + methods []AuthMethod + Method AuthMethodDescription } -// Init is public so it can be called from plugin_auth_clientcert.go (go/cmd/vtgate) +// InitAuthServerClientCert is public so it can be called from plugin_auth_clientcert.go (go/cmd/vtgate) func InitAuthServerClientCert() { if flag.CommandLine.Lookup("mysql_server_ssl_ca").Value.String() == "" { log.Info("Not configuring AuthServerClientCert because mysql_server_ssl_ca is empty") return } - if *clientcertAuthMethod != MysqlClearPassword && *clientcertAuthMethod != MysqlDialog { + if *clientcertAuthMethod != string(MysqlClearPassword) && *clientcertAuthMethod != string(MysqlDialog) { log.Exitf("Invalid mysql_clientcert_auth_method value: only support mysql_clear_password or dialog") } + + ascc := newAuthServerClientCert() + RegisterAuthServer("clientcert", ascc) +} + +func newAuthServerClientCert() *AuthServerClientCert { ascc := &AuthServerClientCert{ - Method: *clientcertAuthMethod, + Method: AuthMethodDescription(*clientcertAuthMethod), } - RegisterAuthServerImpl("clientcert", ascc) -} -// AuthMethod is part of the AuthServer interface. -func (ascc *AuthServerClientCert) AuthMethod(user string) (string, error) { - return ascc.Method, nil + var authMethod AuthMethod + switch AuthMethodDescription(*clientcertAuthMethod) { + case MysqlClearPassword: + authMethod = NewMysqlClearAuthMethod(ascc, ascc) + case MysqlDialog: + authMethod = NewMysqlDialogAuthMethod(ascc, ascc, "") + default: + log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog") + } + + ascc.methods = []AuthMethod{authMethod} + return ascc } -// Salt is not used for this plugin. -func (ascc *AuthServerClientCert) Salt() ([]byte, error) { - return NewSalt() +// AuthMethods returns the implement auth methods for the client +// certificate authentication setup. +func (asl *AuthServerClientCert) AuthMethods() []AuthMethod { + return asl.methods } -// ValidateHash is unimplemented. -func (ascc *AuthServerClientCert) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { - panic("unimplemented") +// DefaultAuthMethodDescription returns always MysqlNativePassword +// for the client certificate authentication setup. +func (asl *AuthServerClientCert) DefaultAuthMethodDescription() AuthMethodDescription { + return MysqlNativePassword } -// Negotiate is part of the AuthServer interface. -func (ascc *AuthServerClientCert) Negotiate(c *Conn, user string, remoteAddr net.Addr) (Getter, error) { - // This code depends on the fact that golang's tls server enforces client cert verification. - // Note that the -mysql_server_ssl_ca flag must be set in order for the vtgate to accept client certs. - // If not set, the vtgate will effectively deny all incoming mysql connections, since they will all lack certificates. - // For more info, check out go/vt/vtttls/vttls.go - certs := c.GetTLSClientCerts() - if len(certs) == 0 { - return nil, fmt.Errorf("no client certs for connection ID %v", c.ConnectionID) - } +// HandleUser is part of the UserValidator interface. We +// handle any user here since we don't check up front. +func (asl *AuthServerClientCert) HandleUser(user string) bool { + return true +} - if _, err := AuthServerReadPacketString(c); err != nil { - return nil, err +// UserEntryWithPassword is part of the PlaintextStorage interface +func (asl *AuthServerClientCert) UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (Getter, error) { + if len(userCerts) == 0 { + return nil, fmt.Errorf("no client certs for connection") } - - commonName := certs[0].Subject.CommonName + commonName := userCerts[0].Subject.CommonName if user != commonName { return nil, fmt.Errorf("MySQL connection username '%v' does not match client cert common name '%v'", user, commonName) @@ -83,6 +97,6 @@ func (ascc *AuthServerClientCert) Negotiate(c *Conn, user string, remoteAddr net return &StaticUserData{ username: commonName, - groups: certs[0].DNSNames, + groups: userCerts[0].DNSNames, }, nil } diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 0922858212e..468f675e559 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -35,9 +35,7 @@ const clientCertUsername = "Client Cert" func TestValidCert(t *testing.T) { th := &testHandler{} - authServer := &AuthServerClientCert{ - Method: MysqlClearPassword, - } + authServer := newAuthServerClientCert() // Create the listener, so we can get its host. l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) @@ -120,9 +118,7 @@ func TestValidCert(t *testing.T) { func TestNoCert(t *testing.T) { th := &testHandler{} - authServer := &AuthServerClientCert{ - Method: MysqlClearPassword, - } + authServer := newAuthServerClientCert() // Create the listener, so we can get its host. l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) diff --git a/go/mysql/auth_server_none.go b/go/mysql/auth_server_none.go index c560238303b..39baa4b357b 100644 --- a/go/mysql/auth_server_none.go +++ b/go/mysql/auth_server_none.go @@ -17,6 +17,7 @@ limitations under the License. package mysql import ( + "crypto/x509" "net" querypb "vitess.io/vitess/go/vt/proto/query" @@ -27,32 +28,43 @@ import ( // With this config, you can connect to a local vtgate using // the following command line: 'mysql -P port -h ::'. // It only uses MysqlNativePassword method. -type AuthServerNone struct{} +type AuthServerNone struct { + methods []AuthMethod +} -// AuthMethod is part of the AuthServer interface. -// We always return MysqlNativePassword. -func (a *AuthServerNone) AuthMethod(user string) (string, error) { - return MysqlNativePassword, nil +// AuthMethods returns the list of registered auth methods +// implemented by this auth server. +func (a *AuthServerNone) AuthMethods() []AuthMethod { + return a.methods } -// Salt makes salt -func (a *AuthServerNone) Salt() ([]byte, error) { - return NewSalt() +// DefaultAuthMethodDescription returns MysqlNativePassword as the default +// authentication method for the auth server implementation. +func (a *AuthServerNone) DefaultAuthMethodDescription() AuthMethodDescription { + return MysqlNativePassword } -// ValidateHash validates hash -func (a *AuthServerNone) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { - return &NoneGetter{}, nil +// HandleUser validates if this user can use this auth method +func (a *AuthServerNone) HandleUser(user string) bool { + return true } -// Negotiate is part of the AuthServer interface. -// It will never be called. -func (a *AuthServerNone) Negotiate(c *Conn, user string, remotAddr net.Addr) (Getter, error) { - panic("Negotiate should not be called as AuthMethod returned mysql_native_password") +// UserEntryWithHash validates the user if it exists and returns the information. +// Always accepts any user. +func (a *AuthServerNone) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { + return &NoneGetter{}, nil } func init() { - RegisterAuthServerImpl("none", &AuthServerNone{}) + a := NewAuthServerNone() + RegisterAuthServer("none", a) +} + +// NewAuthServerNone returns an empty auth server. Always accepts all clients. +func NewAuthServerNone() *AuthServerNone { + a := &AuthServerNone{} + a.methods = []AuthMethod{NewMysqlNativeAuthMethod(a, a)} + return a } // NoneGetter holds the empty string diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index ef95febb292..496cf6c8577 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -18,6 +18,8 @@ package mysql import ( "bytes" + "crypto/subtle" + "crypto/x509" "encoding/json" "flag" "io/ioutil" @@ -46,14 +48,9 @@ const ( // AuthServerStatic implements AuthServer using a static configuration. type AuthServerStatic struct { + methods []AuthMethod file, jsonConfig string reloadInterval time.Duration - // method can be set to: - // - MysqlNativePassword - // - MysqlClearPassword - // - MysqlDialog - // It defaults to MysqlNativePassword. - method string // This mutex helps us prevent data races between the multiple updates of entries. mu sync.Mutex // entries contains the users, passwords and user data. @@ -109,7 +106,7 @@ func RegisterAuthServerStaticFromParams(file, jsonConfig string, reloadInterval if len(authServerStatic.entries) <= 0 { log.Exitf("Failed to populate entries from file: %v", file) } - RegisterAuthServerImpl("static", authServerStatic) + RegisterAuthServer("static", authServerStatic) } // NewAuthServerStatic returns a new empty AuthServerStatic. @@ -118,14 +115,137 @@ func NewAuthServerStatic(file, jsonConfig string, reloadInterval time.Duration) file: file, jsonConfig: jsonConfig, reloadInterval: reloadInterval, - method: MysqlNativePassword, entries: make(map[string][]*AuthServerStaticEntry), } + + a.methods = []AuthMethod{NewMysqlNativeAuthMethod(a, a)} + a.reload() a.installSignalHandlers() return a } +// NewAuthServerStaticWithAuthMethodDescription returns a new empty AuthServerStatic +// but with support for a different auth method. Mostly used for testing purposes. +func NewAuthServerStaticWithAuthMethodDescription(file, jsonConfig string, reloadInterval time.Duration, authMethodDescription AuthMethodDescription) *AuthServerStatic { + a := &AuthServerStatic{ + file: file, + jsonConfig: jsonConfig, + reloadInterval: reloadInterval, + entries: make(map[string][]*AuthServerStaticEntry), + } + + var authMethod AuthMethod + switch authMethodDescription { + case CachingSha2Password: + authMethod = NewSha2CachingAuthMethod(a, a, a) + case MysqlNativePassword: + authMethod = NewMysqlNativeAuthMethod(a, a) + case MysqlClearPassword: + authMethod = NewMysqlClearAuthMethod(a, a) + case MysqlDialog: + authMethod = NewMysqlDialogAuthMethod(a, a, "") + } + + a.methods = []AuthMethod{authMethod} + + a.reload() + a.installSignalHandlers() + return a +} + +// HandleUser is part of the Validator interface. We +// handle any user here since we don't check up front. +func (a *AuthServerStatic) HandleUser(user string) bool { + return true +} + +// UserEntryWithPassword implements password lookup based on a plain +// text password that is negotiated with the client. +func (a *AuthServerStatic) UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (Getter, error) { + a.mu.Lock() + entries, ok := a.entries[user] + a.mu.Unlock() + + if !ok { + return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + for _, entry := range entries { + // Validate the password. + if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare([]byte(password), []byte(entry.Password)) == 1 { + return &StaticUserData{entry.UserData, entry.Groups}, nil + } + } + return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) +} + +// UserEntryWithHash implements password lookup based on a +// mysql_native_password hash that is negotiated with the client. +func (a *AuthServerStatic) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { + a.mu.Lock() + entries, ok := a.entries[user] + a.mu.Unlock() + + if !ok { + return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + for _, entry := range entries { + if entry.MysqlNativePassword != "" { + hash, err := DecodeMysqlNativePasswordHex(entry.MysqlNativePassword) + if err != nil { + return &StaticUserData{entry.UserData, entry.Groups}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash) + if matchSourceHost(remoteAddr, entry.SourceHost) && isPass { + return &StaticUserData{entry.UserData, entry.Groups}, nil + } + } else { + computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) + // Validate the password. + if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 { + return &StaticUserData{entry.UserData, entry.Groups}, nil + } + } + } + return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) +} + +// UserEntryWithCacheHash implements password lookup based on a +// caching_sha2_password hash that is negotiated with the client. +func (a *AuthServerStatic) UserEntryWithCacheHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) { + a.mu.Lock() + entries, ok := a.entries[user] + a.mu.Unlock() + + if !ok { + return &StaticUserData{}, AuthRejected, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + for _, entry := range entries { + computedAuthResponse := ScrambleCachingSha2Password(salt, []byte(entry.Password)) + + // Validate the password. + if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 { + return &StaticUserData{entry.UserData, entry.Groups}, AuthAccepted, nil + } + } + return &StaticUserData{}, AuthRejected, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) +} + +// AuthMethods returns the AuthMethod instances this auth server can handle. +func (a *AuthServerStatic) AuthMethods() []AuthMethod { + return a.methods +} + +// DefaultAuthMethodDescription returns the default auth method in the handshake which +// is MysqlNativePassword for this auth server. +func (a *AuthServerStatic) DefaultAuthMethodDescription() AuthMethodDescription { + return MysqlNativePassword +} + func (a *AuthServerStatic) reload() { jsonBytes := []byte(a.jsonConfig) if a.file != "" { @@ -217,69 +337,6 @@ func validateConfig(config map[string][]*AuthServerStaticEntry) error { return nil } -// AuthMethod is part of the AuthServer interface. -func (a *AuthServerStatic) AuthMethod(user string) (string, error) { - return a.method, nil -} - -// Salt is part of the AuthServer interface. -func (a *AuthServerStatic) Salt() ([]byte, error) { - return NewSalt() -} - -// ValidateHash is part of the AuthServer interface. -func (a *AuthServerStatic) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { - a.mu.Lock() - entries, ok := a.entries[user] - a.mu.Unlock() - - if !ok { - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) - } - - for _, entry := range entries { - if entry.MysqlNativePassword != "" { - isPass := isPassScrambleMysqlNativePassword(authResponse, salt, entry.MysqlNativePassword) - if matchSourceHost(remoteAddr, entry.SourceHost) && isPass { - return &StaticUserData{entry.UserData, entry.Groups}, nil - } - } else { - computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) - // Validate the password. - if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) { - return &StaticUserData{entry.UserData, entry.Groups}, nil - } - } - } - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) -} - -// Negotiate is part of the AuthServer interface. -// It will be called if method is anything else than MysqlNativePassword. -// We only recognize MysqlClearPassword and MysqlDialog here. -func (a *AuthServerStatic) Negotiate(c *Conn, user string, remoteAddr net.Addr) (Getter, error) { - // Finish the negotiation. - password, err := AuthServerNegotiateClearOrDialog(c, a.method) - if err != nil { - return nil, err - } - - a.mu.Lock() - entries, ok := a.entries[user] - a.mu.Unlock() - - if !ok { - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) - } - for _, entry := range entries { - // Validate the password. - if matchSourceHost(remoteAddr, entry.SourceHost) && entry.Password == password { - return &StaticUserData{entry.UserData, entry.Groups}, nil - } - } - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) -} - func matchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool { // Legacy support, there was not matcher defined default to true if targetSourceHost == "" { diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index d5e43f312fb..0c943debb32 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -87,13 +87,13 @@ func TestValidateHashGetter(t *testing.T) { ip := net.ParseIP("127.0.0.1") addr := &net.IPAddr{IP: ip, Zone: ""} - salt, err := NewSalt() + salt, err := newSalt() if err != nil { t.Fatalf("error generating salt: %v", err) } scrambled := ScrambleMysqlNativePassword(salt, []byte("password")) - getter, err := auth.ValidateHash(salt, "mysql_user", scrambled, addr) + getter, err := auth.UserEntryWithHash(nil, salt, "mysql_user", scrambled, addr) if err != nil { t.Fatalf("error validating password: %v", err) } @@ -265,13 +265,13 @@ func TestStaticPasswords(t *testing.T) { for _, c := range tests { t.Run(fmt.Sprintf("%s-%s", c.user, c.password), func(t *testing.T) { - salt, err := NewSalt() + salt, err := newSalt() if err != nil { t.Fatalf("error generating salt: %v", err) } scrambled := ScrambleMysqlNativePassword(salt, []byte(c.password)) - _, err = auth.ValidateHash(salt, c.user, scrambled, addr) + _, err = auth.UserEntryWithHash(nil, salt, c.user, scrambled, addr) if c.success { if err != nil { diff --git a/go/mysql/auth_server_test.go b/go/mysql/auth_server_test.go new file mode 100644 index 00000000000..14457da8ba4 --- /dev/null +++ b/go/mysql/auth_server_test.go @@ -0,0 +1,55 @@ +/* +copyright 2020 The Vitess Authors. + +licensed under the apache license, version 2.0 (the "license"); +you may not use this file except in compliance with the license. +you may obtain a copy of the license at + + http://www.apache.org/licenses/license-2.0 + +unless required by applicable law or agreed to in writing, software +distributed under the license is distributed on an "as is" basis, +without warranties or conditions of any kind, either express or implied. +see the license for the specific language governing permissions and +limitations under the license. +*/ + +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVerifyHashedMysqlNativePassword(t *testing.T) { + salt := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + password := "secret" + + // Double SHA1 of "secret" + passwordHash := []byte{0x38, 0x81, 0x21, 0x9d, 0x08, 0x7d, 0xd9, 0xc6, 0x34, 0x37, 0x3f, 0xd3, + 0x3d, 0xfa, 0x33, 0xa2, 0xcb, 0x6b, 0xfc, 0x6c, 0x52, 0x0b, 0x64, 0xb8, + 0xbb, 0x60, 0xef, 0x2c, 0xeb, 0x53, 0x4a, 0xe7} + + reply := ScrambleCachingSha2Password(salt, []byte(password)) + + assert.True(t, VerifyHashedCachingSha2Password(reply, salt, passwordHash), "password hash mismatch") + + passwordHash[0] = 0x00 + assert.False(t, VerifyHashedCachingSha2Password(reply, salt, passwordHash), "password hash match") +} + +func TestVerifyHashedCachingSha2Password(t *testing.T) { + salt := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + password := "secret" + + // Double SHA1 of "secret" + passwordHash := []byte{0x14, 0xe6, 0x55, 0x67, 0xab, 0xdb, 0x51, 0x35, 0xd0, 0xcf, 0xd9, 0xa7, + 0x0b, 0x30, 0x32, 0xc1, 0x79, 0xa4, 0x9e, 0xe7} + + reply := ScrambleMysqlNativePassword(salt, []byte(password)) + assert.True(t, VerifyHashedMysqlNativePassword(reply, salt, passwordHash), "password hash mismatch") + + passwordHash[0] = 0x00 + assert.False(t, VerifyHashedMysqlNativePassword(reply, salt, passwordHash), "password hash match") +} diff --git a/go/mysql/auth_server_vault.go b/go/mysql/auth_server_vault.go index 25fe6611a40..b0a8c1fe49a 100644 --- a/go/mysql/auth_server_vault.go +++ b/go/mysql/auth_server_vault.go @@ -17,7 +17,8 @@ limitations under the License. package mysql import ( - "bytes" + "crypto/subtle" + "crypto/x509" "flag" "fmt" "io/ioutil" @@ -48,16 +49,11 @@ var ( // AuthServerVault implements AuthServer with a config loaded from Vault. type AuthServerVault struct { - mu sync.Mutex - // method can be set to: - // - MysqlNativePassword - // - MysqlClearPassword - // - MysqlDialog - // It defaults to MysqlNativePassword. - method string + methods []AuthMethod + mu sync.Mutex // users, passwords and user data - // We use the same JSON format as for -mysql_auth_server_static - // Acts as a cache for the in-Vault data + // We use the same JSON format as for -mysql_auth_server_static + // Acts as a cache for the in-Vault data entries map[string][]*AuthServerStaticEntry vaultCacheExpireTicker *time.Ticker vaultClient *vaultapi.Client @@ -86,7 +82,7 @@ func registerAuthServerVault(addr string, timeout time.Duration, caCertPath stri if err != nil { log.Exitf("%s", err) } - RegisterAuthServerImpl("vault", authServerVault) + RegisterAuthServer("vault", authServerVault) } func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) (*AuthServerVault, error) { @@ -140,15 +136,66 @@ func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, p vaultClient: client, vaultPath: path, vaultTTL: ttl, - method: MysqlNativePassword, entries: make(map[string][]*AuthServerStaticEntry), } + authMethodNative := NewMysqlNativeAuthMethod(a, a) + a.methods = []AuthMethod{authMethodNative} + a.reloadVault() a.installSignalHandlers() return a, nil } +// AuthMethods returns the list of registered auth methods +// implemented by this auth server. +func (a *AuthServerVault) AuthMethods() []AuthMethod { + return a.methods +} + +// DefaultAuthMethodDescription returns MysqlNativePassword as the default +// authentication method for the auth server implementation. +func (a *AuthServerVault) DefaultAuthMethodDescription() AuthMethodDescription { + return MysqlNativePassword +} + +// HandleUser is part of the Validator interface. We +// handle any user here since we don't check up front. +func (a *AuthServerVault) HandleUser(user string) bool { + return true +} + +// UserEntryWithHash is called when mysql_native_password is used. +func (a *AuthServerVault) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { + a.mu.Lock() + userEntries, ok := a.entries[user] + a.mu.Unlock() + + if !ok { + return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + + for _, entry := range userEntries { + if entry.MysqlNativePassword != "" { + hash, err := DecodeMysqlNativePasswordHex(entry.MysqlNativePassword) + if err != nil { + return &StaticUserData{entry.UserData, entry.Groups}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } + isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash) + if matchSourceHost(remoteAddr, entry.SourceHost) && isPass { + return &StaticUserData{entry.UserData, entry.Groups}, nil + } + } else { + computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) + // Validate the password. + if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 { + return &StaticUserData{entry.UserData, entry.Groups}, nil + } + } + } + return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) +} + func (a *AuthServerVault) setTTLTicker(ttl time.Duration) { a.mu.Lock() defer a.mu.Unlock() @@ -224,69 +271,6 @@ func (a *AuthServerVault) close() { } } -// AuthMethod is part of the AuthServer interface. -func (a *AuthServerVault) AuthMethod(user string) (string, error) { - return a.method, nil -} - -// Salt is part of the AuthServer interface. -func (a *AuthServerVault) Salt() ([]byte, error) { - return NewSalt() -} - -// ValidateHash is part of the AuthServer interface. -func (a *AuthServerVault) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) { - a.mu.Lock() - userEntries, ok := a.entries[user] - a.mu.Unlock() - - if !ok { - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) - } - - for _, entry := range userEntries { - if entry.MysqlNativePassword != "" { - isPass := isPassScrambleMysqlNativePassword(authResponse, salt, entry.MysqlNativePassword) - if matchSourceHost(remoteAddr, entry.SourceHost) && isPass { - return &StaticUserData{entry.UserData, entry.Groups}, nil - } - } else { - computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) - // Validate the password. - if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) { - return &StaticUserData{entry.UserData, entry.Groups}, nil - } - } - } - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) -} - -// Negotiate is part of the AuthServer interface. -// It will be called if method is anything else than MysqlNativePassword. -// We only recognize MysqlClearPassword and MysqlDialog here. -func (a *AuthServerVault) Negotiate(c *Conn, user string, remoteAddr net.Addr) (Getter, error) { - // Finish the negotiation. - password, err := AuthServerNegotiateClearOrDialog(c, a.method) - if err != nil { - return nil, err - } - - a.mu.Lock() - userEntries, ok := a.entries[user] - a.mu.Unlock() - - if !ok { - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) - } - for _, entry := range userEntries { - // Validate the password. - if matchSourceHost(remoteAddr, entry.SourceHost) && entry.Password == password { - return &StaticUserData{entry.UserData, entry.Groups}, nil - } - } - return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) -} - // We ignore most errors here, to allow us to retry cleanly // or ignore the cases where the input is not passed by file, but via env func readFromFile(filePath string) (string, error) { diff --git a/go/mysql/client.go b/go/mysql/client.go index b63ce692058..dee5d8e4833 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -483,7 +483,7 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) // 5.6.2 that don't have a null terminated string. authPluginName = string(data[pos : len(data)-1]) } - c.authPluginName = authPluginName + c.authPluginName = AuthMethodDescription(authPluginName) } return capabilities, authPluginData, nil @@ -617,7 +617,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ } // Assume native client during response - pos = writeNullString(data, pos, c.authPluginName) + pos = writeNullString(data, pos, string(c.authPluginName)) // Sanity-check the length. if pos != len(data) { @@ -734,7 +734,7 @@ func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error { } } -func parseAuthSwitchRequest(data []byte) (string, []byte, error) { +func parseAuthSwitchRequest(data []byte) (AuthMethodDescription, []byte, error) { pos := 1 pluginName, pos, ok := readNullString(data, pos) if !ok { @@ -746,7 +746,7 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) { if len(salt) > 20 { salt = salt[:20] } - return pluginName, salt, nil + return AuthMethodDescription(pluginName), salt, nil } // requestPublicKey requests a public key from the server diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 834cf60afca..0c557d143bb 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -91,7 +91,7 @@ type Conn struct { // authPluginName is the name of server's authentication plugin. // It is set during the initial handshake. - authPluginName string + authPluginName AuthMethodDescription // schemaName is the default database name to use. It is set // during handshake, and by ComInitDb packets. Both client and @@ -1499,6 +1499,17 @@ func (c *Conn) GetTLSClientCerts() []*x509.Certificate { return nil } +// TLSEnabled returns true if this connection is using TLS. +func (c *Conn) TLSEnabled() bool { + return c.Capabilities&CapabilityClientSSL > 0 +} + +// IsUnixSocket returns true if this connection is over a Unix socket. +func (c *Conn) IsUnixSocket() bool { + _, ok := c.listener.listener.(*net.UnixListener) + return ok +} + // GetRawConn returns the raw net.Conn for nefarious purposes. func (c *Conn) GetRawConn() net.Conn { return c.conn diff --git a/go/mysql/constants.go b/go/mysql/constants.go index cd18e49b266..99cf5ff6273 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -34,20 +34,24 @@ const ( protocolVersion = 10 ) +// AuthMethodDescription is the type for different supported and +// implemented authentication methods. +type AuthMethodDescription string + // Supported auth forms. const ( // MysqlNativePassword uses a salt and transmits a hash on the wire. - MysqlNativePassword = "mysql_native_password" + MysqlNativePassword = AuthMethodDescription("mysql_native_password") // MysqlClearPassword transmits the password in the clear. - MysqlClearPassword = "mysql_clear_password" + MysqlClearPassword = AuthMethodDescription("mysql_clear_password") // CachingSha2Password uses a salt and transmits a SHA256 hash on the wire. - CachingSha2Password = "caching_sha2_password" + CachingSha2Password = AuthMethodDescription("caching_sha2_password") // MysqlDialog uses the dialog plugin on the client side. // It transmits data in the clear. - MysqlDialog = "dialog" + MysqlDialog = AuthMethodDescription("dialog") ) // Capability flags. diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 863ee2bbd91..065c3400ecf 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -173,7 +173,7 @@ func New(t testing.TB) *DB { db.Handler = db - authServer := &mysql.AuthServerNone{} + authServer := mysql.NewAuthServerNone() // Start listening. db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false) diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index 35f3224d630..8acc543a6fb 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -38,8 +38,7 @@ import ( func TestClearTextClientAuth(t *testing.T) { th := &testHandler{} - authServer := NewAuthServerStatic("", "", 0) - authServer.method = MysqlClearPassword + authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) authServer.entries["user1"] = []*AuthServerStaticEntry{ {Password: "password1"}, } @@ -96,7 +95,7 @@ func TestClearTextClientAuth(t *testing.T) { func TestSSLConnection(t *testing.T) { th := &testHandler{} - authServer := NewAuthServerStatic("", "", 0) + authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) authServer.entries["user1"] = []*AuthServerStaticEntry{ {Password: "password1"}, } @@ -156,7 +155,6 @@ func TestSSLConnection(t *testing.T) { // Make sure clear text auth works over SSL. t.Run("ClearText", func(t *testing.T) { - authServer.method = MysqlClearPassword testSSLConnectionClearText(t, params) }) } diff --git a/go/mysql/ldapauthserver/auth_server_ldap.go b/go/mysql/ldapauthserver/auth_server_ldap.go index 172b20335c6..727b28b5b14 100644 --- a/go/mysql/ldapauthserver/auth_server_ldap.go +++ b/go/mysql/ldapauthserver/auth_server_ldap.go @@ -17,6 +17,7 @@ limitations under the License. package ldapauthserver import ( + "crypto/x509" "encoding/json" "flag" "fmt" @@ -37,19 +38,19 @@ import ( var ( ldapAuthConfigFile = flag.String("mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.") ldapAuthConfigString = flag.String("mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.") - ldapAuthMethod = flag.String("mysql_ldap_auth_method", mysql.MysqlClearPassword, "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") + ldapAuthMethod = flag.String("mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") ) // AuthServerLdap implements AuthServer with an LDAP backend type AuthServerLdap struct { Client ServerConfig - Method string User string Password string GroupQuery string UserDnPattern string RefreshSeconds int64 + methods []mysql.AuthMethod } // Init is public so it can be called from plugin_auth_ldap.go (go/cmd/vtgate) @@ -62,13 +63,13 @@ func Init() { log.Infof("Both mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are non-empty, can only use one.") return } - if *ldapAuthMethod != mysql.MysqlClearPassword && *ldapAuthMethod != mysql.MysqlDialog { + + if *ldapAuthMethod != string(mysql.MysqlClearPassword) && *ldapAuthMethod != string(mysql.MysqlDialog) { log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog") } ldapAuthServer := &AuthServerLdap{ Client: &ClientImpl{}, ServerConfig: ServerConfig{}, - Method: *ldapAuthMethod, } data := []byte(*ldapAuthConfigString) @@ -82,31 +83,42 @@ func Init() { if err := json.Unmarshal(data, ldapAuthServer); err != nil { log.Exitf("Error parsing AuthServerLdap config: %v", err) } - mysql.RegisterAuthServerImpl("ldap", ldapAuthServer) + + var authMethod mysql.AuthMethod + switch mysql.AuthMethodDescription(*ldapAuthMethod) { + case mysql.MysqlClearPassword: + authMethod = mysql.NewMysqlClearAuthMethod(ldapAuthServer, ldapAuthServer) + case mysql.MysqlDialog: + authMethod = mysql.NewMysqlDialogAuthMethod(ldapAuthServer, ldapAuthServer, "") + default: + log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog") + } + + ldapAuthServer.methods = []mysql.AuthMethod{authMethod} + mysql.RegisterAuthServer("ldap", ldapAuthServer) } -// AuthMethod is part of the AuthServer interface. -func (asl *AuthServerLdap) AuthMethod(user string) (string, error) { - return asl.Method, nil +// AuthMethods returns the list of registered auth methods +// implemented by this auth server. +func (asl *AuthServerLdap) AuthMethods() []mysql.AuthMethod { + return asl.methods } -// Salt will be unused in AuthServerLdap. -func (asl *AuthServerLdap) Salt() ([]byte, error) { - return mysql.NewSalt() +// DefaultAuthMethodDescription returns MysqlNativePassword as the default +// authentication method for the auth server implementation. +func (asl *AuthServerLdap) DefaultAuthMethodDescription() mysql.AuthMethodDescription { + return mysql.MysqlNativePassword } -// ValidateHash is unimplemented for AuthServerLdap. -func (asl *AuthServerLdap) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) { - panic("unimplemented") +// HandleUser is part of the Validator interface. We +// handle any user here since we don't check up front. +func (asl *AuthServerLdap) HandleUser(user string) bool { + return true } -// Negotiate is part of the AuthServer interface. -func (asl *AuthServerLdap) Negotiate(c *mysql.Conn, user string, remoteAddr net.Addr) (mysql.Getter, error) { - // Finish the negotiation. - password, err := mysql.AuthServerNegotiateClearOrDialog(c, asl.Method) - if err != nil { - return nil, err - } +// UserEntryWithPassword is part of the PlaintextStorage interface +// and called after the password is sent by the client. +func (asl *AuthServerLdap) UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) { return asl.validate(user, password) } diff --git a/go/mysql/query_benchmark_test.go b/go/mysql/query_benchmark_test.go index 155d809675a..7b9a8c2d8f7 100644 --- a/go/mysql/query_benchmark_test.go +++ b/go/mysql/query_benchmark_test.go @@ -37,7 +37,7 @@ const benchmarkQueryPrefix = "benchmark " func benchmarkQuery(b *testing.B, threads int, query string) { th := &testHandler{} - authServer := &AuthServerNone{} + authServer := NewAuthServerNone() lCfg := ListenerConfig{ Protocol: "tcp", diff --git a/go/mysql/server.go b/go/mysql/server.go index e6de1925974..7397d2deeb3 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -317,7 +317,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti defer connCount.Add(-1) // First build and send the server handshake packet. - salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig.Load() != nil) + serverAuthPluginData, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig.Load() != nil) if err != nil { if err != io.EOF { log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err) @@ -335,7 +335,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } return } - user, authMethod, authResponse, err := l.parseClientHandshakePacket(c, true, response) + user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response) if err != nil { log.Errorf("Cannot parse client handshake response from %s: %v", c, err) return @@ -343,7 +343,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti c.recycleReadPacket() - if c.Capabilities&CapabilityClientSSL > 0 { + if c.TLSEnabled() { // SSL was enabled. We need to re-read the auth packet. response, err = c.readEphemeralPacket() if err != nil { @@ -352,7 +352,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } // Returns copies of the data, so we can recycle the buffer. - user, authMethod, authResponse, err = l.parseClientHandshakePacket(c, false, response) + user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response) if err != nil { log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err) return @@ -376,89 +376,59 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } // See what auth method the AuthServer wants to use for that user. - authServerMethod, err := l.authServer.AuthMethod(user) + negotiatedAuthMethod, err := negotiateAuthMethod(c, l.authServer, user, clientAuthMethod) if err != nil { - c.writeErrorPacketFromError(err) - return - } - - // Compare with what the client sent back. - switch { - case authServerMethod == MysqlNativePassword && authMethod == MysqlNativePassword: - // Both server and client want to use MysqlNativePassword: - // the negotiation can be completed right away, using the - // ValidateHash() method. - userData, err := l.authServer.ValidateHash(salt, user, authResponse, conn.RemoteAddr()) - if err != nil { - log.Warningf("Error authenticating user using MySQL native password: %v", err) - c.writeErrorPacketFromError(err) - return + // The client will disconnect if it doesn't understand + // the first auth method that we send, so we only have to send the + // first one that we allow for the user. + for _, m := range l.authServer.AuthMethods() { + if m.HandleUser(c, user) { + negotiatedAuthMethod = m + break + } } - c.User = user - c.UserData = userData - - case authServerMethod == MysqlNativePassword: - // The server really wants to use MysqlNativePassword, - // but the client returned a result for something else. - salt, err := l.authServer.Salt() - if err != nil { - return - } - // The binary protocol requires padding with 0 - data := append(salt, byte(0x00)) - if err := c.writeAuthSwitchRequest(MysqlNativePassword, data); err != nil { - log.Errorf("Error writing auth switch packet for %s: %v", c, err) + if negotiatedAuthMethod == nil { + c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "No authentication methods available for authentication.") return } - response, err := c.readEphemeralPacket() - if err != nil { - log.Errorf("Error reading auth switch response for %s: %v", c, err) - return + if negotiatedAuthMethod.Name() == MysqlClearPassword || negotiatedAuthMethod.Name() == MysqlDialog { + if !l.AllowClearTextWithoutTLS.Get() && !c.TLSEnabled() { + c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "Cannot use clear text authentication over non-SSL connections.") + return + } } - c.recycleReadPacket() - userData, err := l.authServer.ValidateHash(salt, user, response, conn.RemoteAddr()) + serverAuthPluginData, err = negotiatedAuthMethod.AuthPluginData() if err != nil { - log.Warningf("Error authenticating user using MySQL native password: %v", err) - c.writeErrorPacketFromError(err) - return - } - c.User = user - c.UserData = userData - - default: - // The server wants to use something else, re-negotiate. - - // The negotiation happens in clear text. Let's check we can. - if !l.AllowClearTextWithoutTLS.Get() && c.Capabilities&CapabilityClientSSL == 0 { - c.writeErrorPacket(CRServerHandshakeErr, SSUnknownSQLState, "Cannot use clear text authentication over non-SSL connections.") + log.Errorf("Error generating auth switch packet for %s: %v", c, err) return } - // Switch our auth method to what the server wants. - // Dialog plugin expects an AskPassword prompt. - var data []byte - if authServerMethod == MysqlDialog { - data = authServerDialogSwitchData() - } - if err := c.writeAuthSwitchRequest(authServerMethod, data); err != nil { + if err := c.writeAuthSwitchRequest(string(negotiatedAuthMethod.Name()), serverAuthPluginData); err != nil { log.Errorf("Error writing auth switch packet for %s: %v", c, err) return } - // Then hand over the rest of the negotiation to the - // auth server. - userData, err := l.authServer.Negotiate(c, user, conn.RemoteAddr()) + clientAuthResponse, err = c.readEphemeralPacket() if err != nil { - c.writeErrorPacketFromError(err) + log.Errorf("Error reading auth switch response for %s: %v", c, err) return } - c.User = user - c.UserData = userData + c.recycleReadPacket() } + userData, err := negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr()) + if err != nil { + log.Warningf("Error authenticating user %s using: %s", user, negotiatedAuthMethod.Name()) + c.writeErrorPacketFromError(err) + return + } + + c.User = user + c.UserData = userData + if c.User != "" { connCountPerUser.Add(c.User, 1) defer connCountPerUser.Add(c.User, -1) @@ -536,11 +506,23 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en capabilities |= CapabilityClientSSL } + // Grab the default auth method. This can only be either + // mysql_native_password or caching_sha2_password. Both + // need the salt as well to be present too. + // + // Any other auth method will cause clients to throw a + // handshake error. + authMethod := authServer.DefaultAuthMethodDescription() + + if authMethod != MysqlNativePassword && authMethod != CachingSha2Password { + authMethod = MysqlNativePassword + } + length := 1 + // protocol version lenNullString(serverVersion) + 4 + // connection ID - 8 + // first part of salt data + 8 + // first part of plugin auth data 1 + // filler byte 2 + // capability flags (lower 2 bytes) 1 + // character set @@ -549,7 +531,7 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en 1 + // length of auth plugin data 10 + // reserved (0) 13 + // auth-plugin-data - lenNullString(MysqlNativePassword) // auth-plugin-name + lenNullString(string(authMethod)) // auth-plugin-name data, pos := c.startEphemeralPacketWithHeader(length) @@ -562,13 +544,17 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en // Add connectionID in. pos = writeUint32(data, pos, c.ConnectionID) - // Generate the salt, put 8 bytes in. - salt, err := authServer.Salt() + // Generate the salt as the plugin data. Will be reused + // later on if no auth method switch happens and the real + // auth method is also mysql_native_password or caching_sha2_password. + pluginData, err := newSalt() if err != nil { return nil, err } + // Plugin data is always defined as having a trailing NULL + pluginData = append(pluginData, 0) - pos += copy(data[pos:], salt[:8]) + pos += copy(data[pos:], pluginData[:8]) // One filler byte, always 0. pos = writeByte(data, pos, 0) @@ -593,12 +579,11 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en pos = writeZeroes(data, pos, 10) // Second part of auth plugin data. - pos += copy(data[pos:], salt[8:]) - data[pos] = 0 - pos++ + pos += copy(data[pos:], pluginData[8:]) - // Copy authPluginName. We always start with mysql_native_password. - pos = writeNullString(data, pos, MysqlNativePassword) + // Copy authPluginName. We always start with the first + // registered auth method name. + pos = writeNullString(data, pos, string(authMethod)) // Sanity check. if pos != len(data) { @@ -615,13 +600,13 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en return nil, err } - return salt, nil + return pluginData, nil } // parseClientHandshakePacket parses the handshake sent by the client. // Returns the username, auth method, auth data, error. // The original data is not pointed at, and can be freed. -func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, string, []byte, error) { +func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) { pos := 0 // Client flags, 4 bytes. @@ -724,15 +709,15 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by // authMethod (with default) authMethod := MysqlNativePassword if clientFlags&CapabilityClientPluginAuth != 0 { - authMethod, pos, ok = readNullString(data, pos) + var authMethodStr string + authMethodStr, pos, ok = readNullString(data, pos) if !ok { return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod") } - } - - // The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password - if authMethod == "" { - authMethod = MysqlNativePassword + // The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password + if authMethodStr != "" { + authMethod = AuthMethodDescription(authMethodStr) + } } // Decode connection attributes send by the client @@ -742,7 +727,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by } } - return username, authMethod, authResponse, nil + return username, AuthMethodDescription(authMethod), authResponse, nil } func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) { diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index a322314a5a2..fa1fe9a1985 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -34,6 +34,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/test/utils" vtenv "vitess.io/vitess/go/vt/env" "vitess.io/vitess/go/vt/tlstest" "vitess.io/vitess/go/vt/vterrors" @@ -694,12 +695,11 @@ func TestServerStats(t *testing.T) { func TestClearTextServer(t *testing.T) { th := &testHandler{} - authServer := NewAuthServerStatic("", "", 0) + authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) authServer.entries["user1"] = []*AuthServerStaticEntry{{ Password: "password1", UserData: "userData1", }} - authServer.method = MysqlClearPassword defer authServer.close() l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) require.NoError(t, err) @@ -768,12 +768,11 @@ func TestClearTextServer(t *testing.T) { func TestDialogServer(t *testing.T) { th := &testHandler{} - authServer := NewAuthServerStatic("", "", 0) + authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlDialog) authServer.entries["user1"] = []*AuthServerStaticEntry{{ Password: "password1", UserData: "userData1", }} - authServer.method = MysqlDialog defer authServer.close() l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) require.NoError(t, err) @@ -987,6 +986,120 @@ func TestTLSRequired(t *testing.T) { } } +func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { + th := &testHandler{} + + authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) + authServer.entries["user1"] = []*AuthServerStaticEntry{ + {Password: "password1"}, + } + defer authServer.close() + + // Create the listener, so we can get its host. + l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) + if err != nil { + t.Fatalf("NewListener failed: %v", err) + } + defer l.Close() + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port + + // Create the certs. + root, err := ioutil.TempDir("", "TestSSLConnection") + if err != nil { + t.Fatalf("TempDir failed: %v", err) + } + defer os.RemoveAll(root) + tlstest.CreateCA(root) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + + // Create the server with TLS config. + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem"), + "", + tls.VersionTLS12) + if err != nil { + t.Fatalf("TLSServerConfig failed: %v", err) + } + l.TLSConfig.Store(serverConfig) + go func() { + l.Accept() + }() + + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + // SSL flags. + Flags: CapabilityClientSSL, + SslCa: path.Join(root, "ca-cert.pem"), + SslCert: path.Join(root, "client-cert.pem"), + SslKey: path.Join(root, "client-key.pem"), + ServerName: "server.example.com", + } + + // Connection should fail, as server requires SSL for caching_sha2_password. + ctx := context.Background() + + conn, err := Connect(ctx, params) + if err != nil { + t.Fatalf("unexpected connection error: %v", err) + } + defer conn.Close() + + // Run a 'select rows' command with results. + result, err := conn.ExecuteFetch("select rows", 10000, true) + if err != nil { + t.Fatalf("ExecuteFetch failed: %v", err) + } + utils.MustMatch(t, result, selectRowsResult) + + // Send a ComQuit to avoid the error message on the server side. + conn.writeComQuit() +} + +func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { + th := &testHandler{} + + authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) + authServer.entries["user1"] = []*AuthServerStaticEntry{ + {Password: "password1"}, + } + defer authServer.close() + + // Create the listener. + l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) + if err != nil { + t.Fatalf("NewListener failed: %v", err) + } + defer l.Close() + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port + go func() { + l.Accept() + }() + + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + } + + // Connection should fail, as server requires SSL for caching_sha2_password. + ctx := context.Background() + _, err = Connect(ctx, params) + if err == nil || !strings.Contains(err.Error(), "No authentication methods available for authentication") { + t.Fatalf("unexpected connection error: %v", err) + } +} + func checkCountForTLSVer(t *testing.T, version string, expected int64) { connCounts := connCountByTLSVer.Counts() count, ok := connCounts[version] @@ -1254,7 +1367,7 @@ func TestServerFlush(t *testing.T) { th := &testHandler{} - l, err := NewListener("tcp", ":0", &AuthServerNone{}, th, 0, 0, false) + l, err := NewListener("tcp", ":0", NewAuthServerNone(), th, 0, 0, false) require.NoError(t, err) defer l.Close() go l.Accept() From a74e77d42b32b4739d2417ca85c567b5098de76e Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 21 Jul 2021 15:33:14 +0200 Subject: [PATCH 2/3] Add locking around authServers global Signed-off-by: Dirkjan Bussink --- go/mysql/auth_server.go | 8 ++++++++ go/mysql/auth_server_static_test.go | 2 ++ 2 files changed, 10 insertions(+) diff --git a/go/mysql/auth_server.go b/go/mysql/auth_server.go index 300a531203e..4fb75af960f 100644 --- a/go/mysql/auth_server.go +++ b/go/mysql/auth_server.go @@ -25,6 +25,7 @@ import ( "crypto/x509" "encoding/hex" "net" + "sync" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/proto/vtrpc" @@ -552,8 +553,13 @@ func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string, // authServers is a registry of AuthServer implementations. var authServers = make(map[string]AuthServer) +// mu is used to lock access to authServers +var mu sync.Mutex + // RegisterAuthServer registers an implementations of AuthServer. func RegisterAuthServer(name string, authServer AuthServer) { + mu.Lock() + defer mu.Unlock() if _, ok := authServers[name]; ok { log.Fatalf("AuthServer named %v already exists", name) } @@ -562,6 +568,8 @@ func RegisterAuthServer(name string, authServer AuthServer) { // GetAuthServer returns an AuthServer by name, or log.Exitf. func GetAuthServer(name string) AuthServer { + mu.Lock() + defer mu.Unlock() authServer, ok := authServers[name] if !ok { log.Exitf("no AuthServer name %v registered", name) diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index 0c943debb32..6a0170d4921 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -150,6 +150,8 @@ func TestStaticConfigHUP(t *testing.T) { hupTest(t, aStatic, tmpFile, oldStr, "str2") hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal + mu.Lock() + defer mu.Unlock() // delete registered Auth server for auth := range authServers { delete(authServers, auth) From 67fc05782d7ceb348a0eb9cecc9a973cadfc128d Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Fri, 23 Jul 2021 11:04:10 +0200 Subject: [PATCH 3/3] Small refactor of auth state handling and comment improvement Signed-off-by: Dirkjan Bussink --- go/mysql/auth_server.go | 45 +++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/go/mysql/auth_server.go b/go/mysql/auth_server.go index 4fb75af960f..e86f08a66c2 100644 --- a/go/mysql/auth_server.go +++ b/go/mysql/auth_server.go @@ -50,10 +50,10 @@ type AuthServer interface { // auth switch request using the first provided AuthMethod in this list. AuthMethods() []AuthMethod - // Default auth method description that the auth server sends during - // the initial server handshake. This needs to be either `mysql_native_password` - // or `caching_sha2_password` as those are the only supported auth - // methods during the initial handshake. + // DefaultAuthMethodDescription returns the auth method that the auth server + // sends during the initial server handshake. This needs to be either + // `mysql_native_password` or `caching_sha2_password` as those are the only + // supported auth methods during the initial handshake. // // It's not needed to also support those methods in the AuthMethods(), // in fact, if you want to only support for example clear text passwords, @@ -514,12 +514,10 @@ func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string, return nil, err } - if cacheState == AuthRejected { + switch cacheState { + case AuthRejected: return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) - } - - // If we get a result back from the cache that's valid, we can return that immediately. - if cacheState == AuthAccepted { + case AuthAccepted: // We need to write a more data packet to indicate the // handshake completed properly. This will be followed // by a regular OK packet which the caller of this method will send. @@ -531,23 +529,26 @@ func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string, return nil, err } return result, nil - } + case AuthNeedMoreData: + if !c.TLSEnabled() && !c.IsUnixSocket() { + return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) + } - if !c.TLSEnabled() && !c.IsUnixSocket() { - return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) - } + data, pos := c.startEphemeralPacketWithHeader(1) + pos = writeByte(data, pos, AuthMoreDataPacket) + writeByte(data, pos, CachingSha2FullAuth) + c.writeEphemeralPacket() - data, pos := c.startEphemeralPacketWithHeader(1) - pos = writeByte(data, pos, AuthMoreDataPacket) - writeByte(data, pos, CachingSha2FullAuth) - c.writeEphemeralPacket() + password, err := readPacketPasswordString(c) + if err != nil { + return nil, err + } - password, err := readPacketPasswordString(c) - if err != nil { - return nil, err + return n.storage.UserEntryWithPassword(c.GetTLSClientCerts(), user, password, remoteAddr) + default: + // Somehow someone returned an unknown state, let's error with access denied. + return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user) } - - return n.storage.UserEntryWithPassword(c.GetTLSClientCerts(), user, password, remoteAddr) } // authServers is a registry of AuthServer implementations.