Skip to content

Commit

Permalink
Merge branch 'fulghum/auth-refactor' into fulghum-588631ab
Browse files Browse the repository at this point in the history
  • Loading branch information
fulghum committed Dec 9, 2024
2 parents edbd0d7 + b963680 commit e99d9e2
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 50 deletions.
48 changes: 44 additions & 4 deletions enginetest/queries/priv_auth_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,15 +692,55 @@ var UserPrivTests = []UserPrivilegeTest{
},
},
{
Query: "SELECT Host, User FROM mysql.user;",
Query: "SELECT Host, User, Plugin, length(authentication_string) > 0 FROM mysql.user order by User;",
Expected: []sql.Row{
{"localhost", "root"},
{"localhost", "testuser2"},
{"127.0.0.1", "testuser"},
{"localhost", "root", "mysql_native_password", false},
{"127.0.0.1", "testuser", "mysql_native_password", false},
// testuser2 was inserted directly into the table, so it uses the column default
// from the plugin field – caching_sha2_password
{"localhost", "testuser2", "caching_sha2_password", false},
},
},
},
},
{
Name: "User creation with auth plugin specified: mysql_native_password",
SetUpScript: []string{
"CREATE USER testuser1@`127.0.0.1` identified with mysql_native_password by 'pass1';",
"CREATE USER testuser2@`127.0.0.1` identified with 'mysql_native_password';",
},
Assertions: []UserPrivilegeTestAssertion{
{
Query: "select user, host, plugin, authentication_string from mysql.user where user='testuser1';",
Expected: []sql.Row{{"testuser1", "127.0.0.1", "mysql_native_password", "*22A99BA288DB55E8E230679259740873101CD636"}},
},
{
Query: "select user, host, plugin, authentication_string from mysql.user where user='testuser2';",
Expected: []sql.Row{{"testuser2", "127.0.0.1", "mysql_native_password", ""}},
},
},
},
{
Name: "User creation with auth plugin specified: caching_sha2_password",
SetUpScript: []string{
"CREATE USER testuser1@`127.0.0.1` identified with caching_sha2_password by 'pass1';",
"CREATE USER testuser2@`127.0.0.1` identified with 'caching_sha2_password';",
},
Assertions: []UserPrivilegeTestAssertion{
{
// caching_sha2_password auth uses a random salt to create the authentication
// string. Since it's not a consistent value during each test run, we just sanity
// check the first bytes of metadata (digest type, iterations) in the auth string.
Query: "select user, host, plugin, authentication_string like '$A$005$%' from mysql.user where user='testuser1';",
Expected: []sql.Row{{"testuser1", "127.0.0.1", "caching_sha2_password", true}},
},
{
Query: "select user, host, plugin, authentication_string from mysql.user where user='testuser2';",
Expected: []sql.Row{{"testuser2", "127.0.0.1", "caching_sha2_password", ""}},
},
},
},

{
Name: "Dynamic privilege support",
SetUpScript: []string{
Expand Down
45 changes: 33 additions & 12 deletions sql/mysql_db/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ import (
"github.com/dolthub/go-mysql-server/sql"
)

// DefaultAuthMethod specifies the MySQL auth protocol (e.g. mysql_native_password,
// caching_sha2_password) that is used by default. When the auth server advertises
// what auth protocol it prefers, as part of the auth handshake, it is controlled
// by this constant. When a new user is created, if no auth plugin is specified, this
// auth method will be used.
const DefaultAuthMethod = mysql.MysqlNativePassword

// authServer implements the mysql.AuthServer interface. It exposes configured AuthMethod implementations
// that the auth framework in Vitess uses to negotiate authentication with a client. By default, authServer
// configures support for the mysql_native_password auth plugin, as well as an extensible auth method, built
Expand All @@ -42,10 +49,10 @@ var _ mysql.AuthServer = (*authServer)(nil)
// mysql_native_password support, as well as an extensible auth method, built on the mysql_clear_password auth
// method, that allows integrators to extend authentication to allow additional schemes.
func newAuthServer(db *MySQLDb) *authServer {
// The native password auth method allows auth over the mysql_native_password protocol
// mysql_native_password auth support
nativePasswordAuthMethod := mysql.NewMysqlNativeAuthMethod(
&nativePasswordHashStorage{db: db},
&nativePasswordUserValidator{db: db})
newUserValidator(db, mysql.MysqlNativePassword))

// TODO: Add CachingSha2Password AuthMethod

Expand All @@ -67,7 +74,7 @@ func (as *authServer) AuthMethods() []mysql.AuthMethod {

// DefaultAuthMethodDescription implements the mysql.AuthServer interface.
func (db *authServer) DefaultAuthMethodDescription() mysql.AuthMethodDescription {
return mysql.MysqlNativePassword
return DefaultAuthMethod
}

// extendedAuthPlainTextStorage implements the mysql.PlainTextStorage interface and plugs into
Expand Down Expand Up @@ -205,8 +212,8 @@ func (nphs *nativePasswordHashStorage) UserEntryWithHash(_ []*x509.Certificate,
if userEntry == nil || userEntry.Locked {
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
if len(userEntry.Password) > 0 {
if !validateMysqlNativePassword(authResponse, salt, userEntry.Password) {
if len(userEntry.AuthString) > 0 {
if !validateMysqlNativePassword(authResponse, salt, userEntry.AuthString) {
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
} else if len(authResponse) > 0 {
Expand All @@ -218,18 +225,32 @@ func (nphs *nativePasswordHashStorage) UserEntryWithHash(_ []*x509.Certificate,
return sql.MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil
}

// nativePasswordUserValidator implements the mysql.UserValidator interface and plugs into the mysql_native_password
// auth method in Vitess. This implementation is called by the native password auth method to determine if a specific
// user and remote address can connect to this server via the mysql_native_password auth protocol.
type nativePasswordUserValidator struct {
// userValidator implements the mysql.UserValidator interface. It looks up a user and host from the
// associated mysql database (|db|) and validates that a user entry exists and that it is configured
// for the specified authentication plugin (|authMethod|).
type userValidator struct {
// db is the mysql database that contains user information
db *MySQLDb

// authMethod is the name of the auth plugin for which this validator will
// validate users.
authMethod mysql.AuthMethodDescription
}

var _ mysql.UserValidator = (*nativePasswordUserValidator)(nil)
var _ mysql.UserValidator = (*userValidator)(nil)

// newUserValidator creates a new userValidator instance, configured to use |db| to look up user
// entries and validate that they have the specified auth plugin (|authMethod|) configured.
func newUserValidator(db *MySQLDb, authMethod mysql.AuthMethodDescription) *userValidator {
return &userValidator{
db: db,
authMethod: authMethod,
}
}

// HandleUser implements the mysql.UserValidator interface and verifies if the mysql_native_password auth method
// can be used for the specified |user| at the specified |remoteAddr|.
func (uv *nativePasswordUserValidator) HandleUser(user string, remoteAddr net.Addr) bool {
func (uv *userValidator) HandleUser(user string, remoteAddr net.Addr) bool {
// If the mysql database is not enabled, then we don't have user information, so
// go ahead and return true without trying to look up the user in the db.
if !uv.db.Enabled() {
Expand All @@ -251,7 +272,7 @@ func (uv *nativePasswordUserValidator) HandleUser(user string, remoteAddr net.Ad
}
userEntry := db.GetUser(rd, user, host, false)

return userEntry != nil && (userEntry.Plugin == "" || userEntry.Plugin == string(mysql.MysqlNativePassword))
return userEntry != nil && userEntry.Plugin == string(uv.authMethod)
}

// extractHostAddress extracts the host address from |addr|, checking to see if it is a unix socket, and if
Expand Down
4 changes: 2 additions & 2 deletions sql/mysql_db/mysql_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,8 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
if userEntry == nil || userEntry.Locked {
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
if len(userEntry.Password) > 0 {
if !validateMysqlNativePassword(authResponse, salt, userEntry.Password) {
if len(userEntry.AuthString) > 0 {
if !validateMysqlNativePassword(authResponse, salt, userEntry.AuthString) {
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
} else if len(authResponse) > 0 { // password is nil or empty, therefore no password is set
Expand Down
2 changes: 1 addition & 1 deletion sql/mysql_db/mysql_db_load.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func LoadUser(serialUser *serial.User) *User {
Host: string(serialUser.Host()),
PrivilegeSet: *privilegeSet,
Plugin: string(serialUser.Plugin()),
Password: string(serialUser.Password()),
AuthString: string(serialUser.Password()),
PasswordLastChanged: time.Unix(serialUser.PasswordLastChanged(), 0),
Locked: serialUser.Locked(),
Attributes: attributes,
Expand Down
4 changes: 2 additions & 2 deletions sql/mysql_db/mysql_db_serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func serializeUser(b *flatbuffers.Builder, users []*User) flatbuffers.UOffsetT {
host := b.CreateString(user.Host)
privilegeSet := serializePrivilegeSet(b, &user.PrivilegeSet)
plugin := b.CreateString(user.Plugin)
password := b.CreateString(user.Password)
authString := b.CreateString(user.AuthString)
attributes := serializeAttributes(b, user.Attributes)
identity := b.CreateString(user.Identity)

Expand All @@ -178,7 +178,7 @@ func serializeUser(b *flatbuffers.Builder, users []*User) flatbuffers.UOffsetT {
serial.UserAddHost(b, host)
serial.UserAddPrivilegeSet(b, privilegeSet)
serial.UserAddPlugin(b, plugin)
serial.UserAddPassword(b, password)
serial.UserAddPassword(b, authString)
serial.UserAddPasswordLastChanged(b, user.PasswordLastChanged.Unix())
serial.UserAddLocked(b, user.Locked)
serial.UserAddAttributes(b, attributes)
Expand Down
8 changes: 4 additions & 4 deletions sql/mysql_db/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type User struct {
Host string
PrivilegeSet PrivilegeSet
Plugin string
Password string
AuthString string
PasswordLastChanged time.Time
Locked bool
Attributes *string
Expand All @@ -56,7 +56,7 @@ func UserToRow(ctx *sql.Context, u *User) (sql.Row, error) {
row[userTblColIndex_User] = u.User
row[userTblColIndex_Host] = u.Host
row[userTblColIndex_plugin] = u.Plugin
row[userTblColIndex_authentication_string] = u.Password
row[userTblColIndex_authentication_string] = u.AuthString
row[userTblColIndex_password_last_changed] = u.PasswordLastChanged
row[userTblColIndex_identity] = u.Identity
if u.Locked {
Expand Down Expand Up @@ -87,7 +87,7 @@ func UserFromRow(ctx *sql.Context, row sql.Row) (*User, error) {
Host: row[userTblColIndex_Host].(string),
PrivilegeSet: UserRowToPrivSet(ctx, row),
Plugin: row[userTblColIndex_plugin].(string),
Password: row[userTblColIndex_authentication_string].(string),
AuthString: row[userTblColIndex_authentication_string].(string),
PasswordLastChanged: passwordLastChanged,
Locked: row[userTblColIndex_account_locked].(uint16) == 2,
Attributes: attributes,
Expand Down Expand Up @@ -117,7 +117,7 @@ func UserEquals(left, right *User) bool {
if left.User != right.User ||
left.Host != right.Host ||
left.Plugin != right.Plugin ||
left.Password != right.Password ||
left.AuthString != right.AuthString ||
left.Identity != right.Identity ||
!left.PasswordLastChanged.Equal(right.PasswordLastChanged) ||
left.Locked != right.Locked ||
Expand Down
4 changes: 2 additions & 2 deletions sql/mysql_db/user_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,13 @@ func init() {
}
}

func addSuperUser(ed *Editor, username string, host string, password string) {
func addSuperUser(ed *Editor, username string, host string, authString string) {
ed.PutUser(&User{
User: username,
Host: host,
PrivilegeSet: NewPrivilegeSetWithAllPrivileges(),
Plugin: "mysql_native_password",
Password: password,
AuthString: authString,
PasswordLastChanged: time.Unix(1, 0).UTC(),
Locked: false,
Attributes: nil,
Expand Down
2 changes: 1 addition & 1 deletion sql/mysql_db/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestUserJson(t *testing.T) {
Host: "localhost",
PrivilegeSet: NewPrivilegeSet(),
Plugin: "mysql_native_password",
Password: "*2470C0C06DEE42FD1618BB99005ADCA2EC9D1E19",
AuthString: "*2470C0C06DEE42FD1618BB99005ADCA2EC9D1E19",
PasswordLastChanged: time.Unix(184301, 0),
Locked: false,
Attributes: nil,
Expand Down
73 changes: 63 additions & 10 deletions sql/plan/create_user_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"encoding/hex"
"fmt"
"strings"

"github.com/dolthub/vitess/go/mysql"
)

// UserName represents either a user or role name.
Expand Down Expand Up @@ -47,8 +49,9 @@ func (un *UserName) String(quote string) string {
type Authentication interface {
// Plugin returns the name of the plugin that this authentication represents.
Plugin() string
// Password returns the value to insert into the database as the password.
Password() string
// AuthString returns the authentication string that encodes the password
// in an obscured or hashed form specific to an auth plugin's protocol.
AuthString() (string, error)
}

// AuthenticatedUser represents a user with the relevant methods of authentication.
Expand Down Expand Up @@ -89,20 +92,69 @@ type PasswordOptions struct {
LockTime *int64
}

// CachingSha2PasswordAuthentication implements the Authentication interface for the
// caching_sha2_password auth plugin.
type CachingSha2PasswordAuthentication struct {
password string
authorizationStringBytes []byte
}

var _ Authentication = (*CachingSha2PasswordAuthentication)(nil)

// NewCachingSha2PasswordAuthentication creates a new CachingSha2PasswordAuthentication
// instance that will obscure the specified |password| when the AuthString() method is called.
func NewCachingSha2PasswordAuthentication(password string) Authentication {
return CachingSha2PasswordAuthentication{password: password}
}

// Plugin implements the Authentication interface.
func (a CachingSha2PasswordAuthentication) Plugin() string {
return string(mysql.CachingSha2Password)
}

// AuthString implements the Authentication interface.
func (a CachingSha2PasswordAuthentication) AuthString() (string, error) {
// We cache the computed authorization data, since it's expensive to compute
// and we must return the same authorization data on multiple calls, not
// generate new auth data with a new salt.
if a.authorizationStringBytes != nil {
return string(a.authorizationStringBytes), nil
}

// If there is no password, caching_sha2_password uses an empty auth string
if a.password == "" {
return "", nil
}

salt, err := mysql.NewSalt()
if err != nil {
return "", err
}

authorizationStringBytes, err := mysql.SerializeCachingSha2PasswordAuthString(
a.password, salt, mysql.DefaultCachingSha2PasswordHashIterations)
if err != nil {
return "", err
}
a.authorizationStringBytes = authorizationStringBytes

return string(a.authorizationStringBytes), nil
}

// AuthenticationMysqlNativePassword is an authentication type that represents "mysql_native_password".
type AuthenticationMysqlNativePassword string

var _ Authentication = AuthenticationMysqlNativePassword("")

// Plugin implements the interface Authentication.
func (a AuthenticationMysqlNativePassword) Plugin() string {
return "mysql_native_password"
return string(mysql.MysqlNativePassword)
}

// Password implements the interface Authentication.
func (a AuthenticationMysqlNativePassword) Password() string {
// AuthString implements the interface Authentication.
func (a AuthenticationMysqlNativePassword) AuthString() (string, error) {
if len(a) == 0 {
return ""
return "", nil
}
// native = sha1(sha1(password))
hash := sha1.New()
Expand All @@ -111,7 +163,7 @@ func (a AuthenticationMysqlNativePassword) Password() string {
hash.Reset()
hash.Write(s1)
s2 := hash.Sum(nil)
return "*" + strings.ToUpper(hex.EncodeToString(s2))
return "*" + strings.ToUpper(hex.EncodeToString(s2)), nil
}

// NewDefaultAuthentication returns the given password with the default
Expand All @@ -137,9 +189,10 @@ func (a AuthenticationOther) Plugin() string {
return a.plugin
}

func (a AuthenticationOther) Password() string {
// AuthString implements the interface Authentication.
func (a AuthenticationOther) AuthString() (string, error) {
if a.password == "" {
return a.identity
return a.identity, nil
}
return string(a.password)
return a.password, nil
}
9 changes: 8 additions & 1 deletion sql/planbuilder/priv.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strconv"
"strings"

"github.com/dolthub/vitess/go/mysql"
ast "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -153,8 +154,14 @@ func (b *Builder) buildAuthenticatedUser(user ast.AccountWithAuth) plan.Authenti
}
if user.Auth1 != nil {
authUser.Identity = user.Auth1.Identity
if user.Auth1.Plugin == "mysql_native_password" && len(user.Auth1.Password) > 0 {
if user.Auth1.Password == "" && user.Auth1.Identity != "" {
// If an identity has been specified, instead of a password, then use the auth details
// directly, without an Authentication implementation that would obscure the password.
authUser.Auth1 = plan.NewOtherAuthentication(user.Auth1.Password, user.Auth1.Plugin, user.Auth1.Identity)
} else if user.Auth1.Plugin == string(mysql.MysqlNativePassword) {
authUser.Auth1 = plan.AuthenticationMysqlNativePassword(user.Auth1.Password)
} else if user.Auth1.Plugin == string(mysql.CachingSha2Password) {
authUser.Auth1 = plan.NewCachingSha2PasswordAuthentication(user.Auth1.Password)
} else if len(user.Auth1.Plugin) > 0 {
authUser.Auth1 = plan.NewOtherAuthentication(user.Auth1.Password, user.Auth1.Plugin, user.Auth1.Identity)
} else {
Expand Down
Loading

0 comments on commit e99d9e2

Please sign in to comment.