Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

privilege, session, server: consistently map user login to identity (#30204) #30450

Merged
merged 3 commits into from
Feb 21, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions executor/coprocessor.go
Original file line number Diff line number Diff line change
@@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec
Username: dagReq.User.UserName,
Hostname: dagReq.User.UserHost,
}
authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost)
if success {
authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false)
if success && pm.GetAuthWithoutVerification(authName, authHost) {
h.sctx.GetSessionVars().User.AuthUsername = authName
h.sctx.GetSessionVars().User.AuthHostname = authHost
h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost)
9 changes: 7 additions & 2 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
@@ -59,10 +59,15 @@ type Manager interface {
RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool

// ConnectionVerification verifies user privilege for connection.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool)
// Requires exact match on user name and host name.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool

// GetAuthWithoutVerification uses to get auth name without verification.
GetAuthWithoutVerification(user, host string) (string, string, bool)
// Requires exact match on user name and host name.
GetAuthWithoutVerification(user, host string) bool

// MatchIdentity matches an identity
MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool)

// DBIsVisible returns true is the database is visible to current user.
DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool
48 changes: 46 additions & 2 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
@@ -848,6 +849,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType {
// See https://dev.mysql.com/doc/refman/5.7/en/account-names.html
func (record *baseRecord) hostMatch(s string) bool {
if record.hostIPNet == nil {
if record.Host == "localhost" && net.ParseIP(s).IsLoopback() {
return true
}
return false
}
ip := net.ParseIP(s).To4()
@@ -890,14 +894,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool {
return stringutil.DoMatchBytes(str, patChars, patTypes)
}

// connectionVerification verifies the connection have access to TiDB server.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
// matchIdentity finds an identity to match a user + host
// using the correct rules according to MySQL.
func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, host) {
return record
}
}

// If skip-name resolve is not enabled, and the host is not localhost
// we can fallback and try to resolve with all addrs that match.
// TODO: this is imported from previous code in session.Auth(), and can be improved in future.
if !skipNameResolve && host != variable.DefHostname {
addrs, err := net.LookupAddr(host)
if err != nil {
logutil.BgLogger().Warn(
"net.LookupAddr returned an error during auth check",
zap.String("host", host),
zap.Error(err),
)
return nil
}
for _, addr := range addrs {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, addr) {
return record
}
}
}
}
return nil
}

// connectionVerification verifies the username + hostname according to exact
// match from the mysql.user privilege table. call matchIdentity() first if you
// do not have an exact match yet.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
records, exists := p.UserMap[user]
if exists {
for i := 0; i < len(records); i++ {
record := &records[i]
if record.Host == host { // exact match
return record
}
}
}
return nil
}

28 changes: 18 additions & 10 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
@@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) {
return "", errors.New("Failed to get plugin for user")
}

// MatchIdentity implements the Manager interface.
func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) {
if SkipWithGrant {
return user, host, true
}
mysqlPriv := p.Handle.Get()
record := mysqlPriv.matchIdentity(user, host, skipNameResolve)
if record != nil {
return record.User, record.Host, true
}
return "", "", false
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
@@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string
return
}

u = record.User
h = record.Host
p.user = user
p.host = h
p.host = record.Host
success = true
return
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
@@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
return
}

u = record.User
h = record.Host

globalPriv := mysqlPriv.matchGlobalPriv(user, host)
if globalPriv != nil {
if !p.checkSSL(globalPriv, tlsState) {
@@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
// empty password
if len(pwd) == 0 && len(authentication) == 0 {
p.user = user
p.host = h
p.host = record.Host
success = true
return
}
@@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
}

p.user = user
p.host = h
p.host = record.Host
success = true
return
}
83 changes: 56 additions & 27 deletions server/conn.go
Original file line number Diff line number Diff line change
@@ -211,6 +211,9 @@ func (cc *clientConn) String() string {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
// https://bugs.mysql.com/bug.php?id=93044
func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) {
failpoint.Inject("FakeAuthSwitch", func() {
failpoint.Return([]byte(plugin), nil)
})
enclen := 1 + len(plugin) + 1 + len(cc.salt) + 1
data := cc.alloc.AllocWithLen(4, enclen)
data = append(data, mysql.AuthSwitchRequest) // switch request
@@ -708,40 +711,29 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con

func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error {
if resp.Capability&mysql.ClientPluginAuth > 0 {
newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin)
newAuth, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err))
return err
}
if len(newAuth) > 0 {
resp.Auth = newAuth
}

switch resp.AuthPlugin {
case mysql.AuthCachingSha2Password:
resp.Auth, err = cc.authSha(ctx)
if err != nil {
return err
}
case mysql.AuthNativePassword:
Comment on lines 724 to 725
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional. It was removed from master, I believe because this is handled in cc.checkAuthPlugin instead. The code could be cleaned up slightly, but that's for another PR.

case mysql.AuthSocket:
default:
logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin))
}
} else {
// MySQL 5.1 and older clients don't support authentication plugins.
logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client")
if cc.ctx == nil {
err := cc.openSession()
if err != nil {
return err
}
}
userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost})
_, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
return err
}
if userplugin != mysql.AuthNativePassword && userplugin != "" {
return errNotSupportedAuthMode
}
Comment on lines -732 to -744
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are also based on master, and checkAuthPlugin does these checks.

resp.AuthPlugin = mysql.AuthNativePassword
}
return nil
@@ -845,7 +837,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
}

// Check if the Authentication Plugin of the server, client and user configuration matches
func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) {
func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) {
// Open a context unless this was done before.
if cc.ctx == nil {
err := cc.openSession()
@@ -854,22 +846,54 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) (
}
}

userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost})
authData := resp.Auth
hasPassword := "YES"
if len(authData) == 0 {
hasPassword = "NO"
}
host, _, err := cc.PeerHost(hasPassword)
if err != nil {
return nil, err
}
// Find the identity of the user based on username and peer host.
identity, err := cc.ctx.MatchIdentity(cc.user, host)
if err != nil {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
// Get the plugin for the identity.
userplugin, err := cc.ctx.AuthPluginForUser(identity)
if err != nil {
logutil.Logger(ctx).Warn("Failed to get authentication method for user",
zap.String("user", cc.user), zap.String("host", host))
}
failpoint.Inject("FakeUser", func(val failpoint.Value) {
userplugin = val.(string)
})
if userplugin == mysql.AuthSocket {
*authPlugin = mysql.AuthSocket
if !cc.isUnixSocket {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
resp.AuthPlugin = mysql.AuthSocket
user, err := user.LookupId(fmt.Sprint(cc.socketCredUID))
if err != nil {
return nil, err
}
return []byte(user.Username), nil
}
if len(userplugin) == 0 {
logutil.Logger(ctx).Warn("No user plugin set, assuming MySQL Native Password",
zap.String("user", cc.user), zap.String("host", cc.peerHost))
*authPlugin = mysql.AuthNativePassword
// No user plugin set, assuming MySQL Native Password
// This happens if the account doesn't exist or if the account doesn't have
// a password set.
if resp.AuthPlugin != mysql.AuthNativePassword {
if resp.Capability&mysql.ClientPluginAuth > 0 {
resp.AuthPlugin = mysql.AuthNativePassword
authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword)
if err != nil {
return nil, err
}
return authData, nil
}
}
return nil, nil
}

@@ -878,13 +902,18 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) (
// or if the authentication method send by the server doesn't match the authentication
// method send by the client (*authPlugin) then we need to switch the authentication
// method to match the one configured for that specific user.
if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) {
authData, err := cc.authSwitchRequest(ctx, userplugin)
if err != nil {
return nil, err
if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) {
if resp.Capability&mysql.ClientPluginAuth > 0 {
authData, err := cc.authSwitchRequest(ctx, userplugin)
if err != nil {
return nil, err
}
resp.AuthPlugin = userplugin
return authData, nil
} else if userplugin != mysql.AuthNativePassword {
// MySQL 5.1 and older don't support authentication plugins yet
return nil, errNotSupportedAuthMode
}
*authPlugin = userplugin
return authData, nil
}

return nil, nil
189 changes: 182 additions & 7 deletions server/conn_test.go
Original file line number Diff line number Diff line change
@@ -894,8 +894,6 @@ func TestShowErrors(t *testing.T) {
}

func TestHandleAuthPlugin(t *testing.T) {
t.Parallel()

store, clean := testkit.CreateMockStore(t)
defer clean()

@@ -905,25 +903,202 @@ func TestHandleAuthPlugin(t *testing.T) {
drv := NewTiDBDriver(store)
srv, err := NewServer(cfg, drv)
require.NoError(t, err)
ctx := context.Background()

tk := testkit.NewTestKit(t, store)
tk.MustExec("CREATE USER unativepassword")
defer func() {
tk.MustExec("DROP USER unativepassword")
}()

// 5.7 or newer client trying to authenticate with mysql_native_password
cc := &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
collation: mysql.DefaultCollationID,
server: srv,
user: "root",
server: srv,
user: "unativepassword",
}
ctx := context.Background()
resp := handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthNativePassword,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)

// 8.0 or newer client trying to authenticate with caching_sha2_password
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)"))
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthCachingSha2Password,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// MySQL 5.1 or older client, without authplugin support
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)

// === Target account has mysql_native_password ===
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"mysql_native_password\")"))

// 5.7 or newer client trying to authenticate with mysql_native_password
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)"))
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthNativePassword,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// 8.0 or newer client trying to authenticate with caching_sha2_password
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)"))
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthCachingSha2Password,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// MySQL 5.1 or older client, without authplugin support
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser"))

// === Target account has caching_sha2_password ===
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"caching_sha2_password\")"))

// 5.7 or newer client trying to authenticate with mysql_native_password
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)"))
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthNativePassword,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

resp.Capability = mysql.ClientProtocol41
// 8.0 or newer client trying to authenticate with caching_sha2_password
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)"))
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthCachingSha2Password,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// MySQL 5.1 or older client, without authplugin support
cc = &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.Error(t, err)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser"))
}
1 change: 1 addition & 0 deletions server/http_handler_test.go
Original file line number Diff line number Diff line change
@@ -483,6 +483,7 @@ func (ts *basicHTTPHandlerTestSuite) startServer(c *C) {
cfg.Port = 0
cfg.Status.StatusPort = 0
cfg.Status.ReportStatus = true
cfg.Socket = ""

server, err := NewServer(cfg, ts.tidbdrv)
c.Assert(err, IsNil)
1 change: 1 addition & 0 deletions server/plan_replayer_test.go
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ func TestDumpPlanReplayerAPI(t *testing.T) {
client := newTestServerClient()
cfg := newTestConfig()
cfg.Port = client.port
cfg.Socket = ""
cfg.Status.StatusPort = client.statusPort
cfg.Status.ReportStatus = true

1 change: 1 addition & 0 deletions server/statistics_handler_serial_test.go
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ func TestDumpStatsAPI(t *testing.T) {
client := newTestServerClient()
cfg := newTestConfig()
cfg.Port = client.port
cfg.Socket = ""
cfg.Status.StatusPort = client.statusPort
cfg.Status.ReportStatus = true

102 changes: 36 additions & 66 deletions session/session.go
Original file line number Diff line number Diff line change
@@ -24,7 +24,6 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"runtime/pprof"
"runtime/trace"
"strconv"
@@ -146,6 +145,7 @@ type Session interface {
Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool
AuthWithoutVerification(user *auth.UserIdentity) bool
AuthPluginForUser(user *auth.UserIdentity) (string, error)
MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error)
ShowProcess() *util.ProcessInfo
// Return the information of the txn current running
TxnInfo() *txninfo.TxnInfo
@@ -2211,91 +2211,61 @@ func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) {
return authplugin, nil
}

// Auth validates a user using an authentication string and salt.
// If the password fails, it will keep trying other users until exhausted.
// This means it can not be refactored to use MatchIdentity yet.
func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool {
pm := privilege.GetPrivilegeManager(s)

// Check IP or localhost.
var success bool
user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState)
if success {
authUser, err := s.MatchIdentity(user.Username, user.Hostname)
if err != nil {
return false
}
if pm.ConnectionVerification(authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) {
user.AuthUsername = authUser.Username
user.AuthHostname = authUser.Hostname
s.sessionVars.User = user
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname)
return true
} else if user.Hostname == variable.DefHostname {
return false
}
return false
}

// Check Hostname.
for _, addr := range s.getHostByIP(user.Hostname) {
u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState)
if success {
s.sessionVars.User = &auth.UserIdentity{
Username: user.Username,
Hostname: addr,
AuthUsername: u,
AuthHostname: h,
}
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h)
return true
}
// MatchIdentity finds the matching username + password in the MySQL privilege tables
// for a username + hostname, since MySQL can have wildcards.
func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) {
pm := privilege.GetPrivilegeManager(s)
var success bool
var skipNameResolve bool
var user = &auth.UserIdentity{}
varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve)
if err == nil && variable.TiDBOptOn(varVal) {
skipNameResolve = true
}
return false
user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve)
if success {
return user, nil
}
// This error will not be returned to the user, access denied will be instead
return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost)
}

// AuthWithoutVerification is required by the ResetConnection RPC
func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool {
pm := privilege.GetPrivilegeManager(s)

// Check IP or localhost.
var success bool
user.AuthUsername, user.AuthHostname, success = pm.GetAuthWithoutVerification(user.Username, user.Hostname)
if success {
authUser, err := s.MatchIdentity(user.Username, user.Hostname)
if err != nil {
return false
}
if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) {
user.AuthUsername = authUser.Username
user.AuthHostname = authUser.Hostname
s.sessionVars.User = user
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname)
return true
} else if user.Hostname == variable.DefHostname {
return false
}

// Check Hostname.
for _, addr := range s.getHostByIP(user.Hostname) {
u, h, success := pm.GetAuthWithoutVerification(user.Username, addr)
if success {
s.sessionVars.User = &auth.UserIdentity{
Username: user.Username,
Hostname: addr,
AuthUsername: u,
AuthHostname: h,
}
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h)
return true
}
}
return false
}

func (s *session) getHostByIP(ip string) []string {
if ip == "127.0.0.1" {
return []string{variable.DefHostname}
}
skipNameResolve, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve)
if err == nil && variable.TiDBOptOn(skipNameResolve) {
return []string{ip} // user wants to skip name resolution
}
addrs, err := net.LookupAddr(ip)
if err != nil {
// These messages can be noisy.
// See: https://github.com/pingcap/tidb/pull/13989
logutil.BgLogger().Debug(
"net.LookupAddr returned an error during auth check",
zap.String("ip", ip),
zap.Error(err),
)
return []string{ip}
}
return addrs
}

// RefreshVars implements the sessionctx.Context interface.
func (s *session) RefreshVars(ctx context.Context) error {
pruneMode, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBPartitionPruneMode)
45 changes: 45 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ import (
"context"
"flag"
"fmt"
"net"
"os"
"path"
"runtime"
@@ -691,6 +692,50 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) {
c.Assert(terror.ErrorEqual(err, variable.ErrUnknownTimeZone), IsTrue)
}

func (s *testSessionSuite) TestMatchIdentity(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("CREATE USER `useridentity`@`%`")
tk.MustExec("CREATE USER `useridentity`@`localhost`")
tk.MustExec("CREATE USER `useridentity`@`192.168.1.1`")
tk.MustExec("CREATE USER `useridentity`@`example.com`")

// The MySQL matching rule is most specific to least specific.
// So if I log in from 192.168.1.1 I should match that entry always.
identity, err := tk.Se.MatchIdentity("useridentity", "192.168.1.1")
c.Assert(err, IsNil)
c.Assert(identity.Username, Equals, "useridentity")
c.Assert(identity.Hostname, Equals, "192.168.1.1")

// If I log in from localhost, I should match localhost
identity, err = tk.Se.MatchIdentity("useridentity", "localhost")
c.Assert(err, IsNil)
c.Assert(identity.Username, Equals, "useridentity")
c.Assert(identity.Hostname, Equals, "localhost")

// If I log in from 192.168.1.2 I should match wildcard.
identity, err = tk.Se.MatchIdentity("useridentity", "192.168.1.2")
c.Assert(err, IsNil)
c.Assert(identity.Username, Equals, "useridentity")
c.Assert(identity.Hostname, Equals, "%")

identity, err = tk.Se.MatchIdentity("useridentity", "127.0.0.1")
c.Assert(err, IsNil)
c.Assert(identity.Username, Equals, "useridentity")
c.Assert(identity.Hostname, Equals, "localhost")

// This uses the lookup of example.com to get an IP address.
// We then login with that IP address, but expect it to match the example.com
// entry in the privileges table (by reverse lookup).
ips, err := net.LookupHost("example.com")
c.Assert(err, IsNil)
identity, err = tk.Se.MatchIdentity("useridentity", ips[0])
c.Assert(err, IsNil)
c.Assert(identity.Username, Equals, "useridentity")
// FIXME: we *should* match example.com instead
// as long as skip-name-resolve is not set (DEFAULT)
c.Assert(identity.Hostname, Equals, "%")
}

func (s *testSessionSuite) TestGetSysVariables(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)