From 9e5c7a0b1a3a34a8782ead0a05ea63ae84d1a973 Mon Sep 17 00:00:00 2001
From: Santamaura
Date: Wed, 3 Aug 2022 16:18:34 -0400
Subject: [PATCH 1/2] sql, server: regulate access to remaining observability features

This change controls access to various observability features based on
system privileges, including the following:
- admin ui databases/tables/schema endpoints require admin or VIEWACTIVITY
- EXPERIMENTAL_AUDIT requires admin or MODIFYCLUSTERSETTING
- sql login requires not having NOSQLLOGIN or the equivalent role option

Resolves: #83848, #83863, #83862

Release note (security update): Change requirements to access some
observability features. Databases/tables/schema endpoints for admin ui
require admin or VIEWACTIVITY. EXPERIMENTAL_AUDIT requires admin or
MODIFYCLUSTERSETTING. SQL login requires not having NOSQLLOGIN or the
equivalent role option.
---
 pkg/server/admin.go                           | 22 ++++++++-
 pkg/server/admin_test.go                      |  9 ++++
 pkg/server/index_usage_stats.go               |  2 +-
 pkg/sql/alter_table.go                        | 25 +++++++++-
 .../logictest/testdata/logic_test/alter_table |  2 +-
 pkg/sql/sessioninit/BUILD.bazel               |  2 +
 pkg/sql/sessioninit/cache.go                  |  7 ++-
 pkg/sql/sessioninit/cache_test.go             | 47 ++++++++++++++++---
 pkg/sql/user.go                               | 43 ++++++++++++++++-
 9 files changed, 143 insertions(+), 16 deletions(-)

diff --git a/pkg/server/admin.go b/pkg/server/admin.go
index 0f10702fa5dc..529a650beb9f 100644
--- a/pkg/server/admin.go
+++ b/pkg/server/admin.go
@@ -248,6 +248,10 @@ func (s *adminServer) Databases(
 		return nil, serverError(ctx, err)
 	}
 
+	if err := s.requireViewActivityPermission(ctx); err != nil {
+		return nil, err
+	}
+
 	r, err := s.databasesHelper(ctx, req, sessionUser, 0, 0)
 	return r, maybeHandleNotFoundError(ctx, err)
 }
@@ -314,6 +318,10 @@ func (s *adminServer) DatabaseDetails(
 		return nil, serverError(ctx, err)
 	}
 
+	if err := s.requireViewActivityPermission(ctx); err != nil {
+		return nil, err
+	}
+
 	r, err := s.databaseDetailsHelper(ctx, req, userName)
 	return r, maybeHandleNotFoundError(ctx, err)
 }
@@ -678,6 +686,10 @@ func (s *adminServer) TableDetails(
 		return nil, serverError(ctx, err)
 	}
 
+	if err := s.requireViewActivityPermission(ctx); err != nil {
+		return nil, err
+	}
+
 	r, err := s.tableDetailsHelper(ctx, req, userName)
 	return r, maybeHandleNotFoundError(ctx, err)
 }
@@ -1073,7 +1085,13 @@ func (s *adminServer) TableStats(
 	ctx context.Context, req *serverpb.TableStatsRequest,
 ) (*serverpb.TableStatsResponse, error) {
 	ctx = s.server.AnnotateCtx(ctx)
-	userName, err := s.requireAdminUser(ctx)
+
+	userName, err := userFromContext(ctx)
+	if err != nil {
+		return nil, serverError(ctx, err)
+	}
+
+	err = s.requireViewActivityPermission(ctx)
 	if err != nil {
 		// NB: not using serverError() here since the priv checker
 		// already returns a proper gRPC error status.
@@ -1103,7 +1121,7 @@ func (s *adminServer) NonTableStats(
 	ctx context.Context, req *serverpb.NonTableStatsRequest,
 ) (*serverpb.NonTableStatsResponse, error) {
 	ctx = s.server.AnnotateCtx(ctx)
-	if _, err := s.requireAdminUser(ctx); err != nil {
+	if err := s.requireViewActivityPermission(ctx); err != nil {
 		// NB: not using serverError() here since the priv checker
 		// already returns a proper gRPC error status.
 		return nil, err
diff --git a/pkg/server/admin_test.go b/pkg/server/admin_test.go
index 2927e5f76b7f..15aa9b4a063f 100644
--- a/pkg/server/admin_test.go
+++ b/pkg/server/admin_test.go
@@ -386,6 +386,15 @@ func TestAdminAPIDatabases(t *testing.T) {
 	if _, err := db.Exec(query); err != nil {
 		t.Fatal(err)
 	}
+	// Non-admins now also require VIEWACTIVITY.
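+	// This mirrors what an operator would run by hand; the statement
+	// assembled below is, e.g., GRANT SYSTEM VIEWACTIVITY TO <user>.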
+	query = fmt.Sprintf(
+		"GRANT SYSTEM %s TO %s",
+		"VIEWACTIVITY",
+		authenticatedUserNameNoAdmin().SQLIdentifier(),
+	)
+	if _, err := db.Exec(query); err != nil {
+		t.Fatal(err)
+	}
 
 	for _, tc := range []struct {
 		expectedDBs []string
diff --git a/pkg/server/index_usage_stats.go b/pkg/server/index_usage_stats.go
index 5ef5978705fc..6d7442d9fa09 100644
--- a/pkg/server/index_usage_stats.go
+++ b/pkg/server/index_usage_stats.go
@@ -194,7 +194,7 @@ func (s *statusServer) TableIndexStats(
 	ctx = propagateGatewayMetadata(ctx)
 	ctx = s.AnnotateCtx(ctx)
 
-	if err := s.privilegeChecker.requireViewActivityOrViewActivityRedactedPermission(ctx); err != nil {
+	if err := s.privilegeChecker.requireViewActivityPermission(ctx); err != nil {
 		return nil, err
 	}
 	return getTableIndexUsageStats(ctx, req, s.sqlServer.pgServer.SQLServer.GetLocalIndexStatistics(),
diff --git a/pkg/sql/alter_table.go b/pkg/sql/alter_table.go
index dc3db15d0862..a3dba6806dc0 100644
--- a/pkg/sql/alter_table.go
+++ b/pkg/sql/alter_table.go
@@ -18,6 +18,7 @@ import (
 	"sort"
 	"time"
 
+	"github.com/cockroachdb/cockroach/pkg/clusterversion"
 	"github.com/cockroachdb/cockroach/pkg/jobs"
 	"github.com/cockroachdb/cockroach/pkg/keys"
 	"github.com/cockroachdb/cockroach/pkg/security/username"
@@ -34,6 +35,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
 	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
 	"github.com/cockroachdb/cockroach/pkg/sql/privilege"
+	"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
@@ -43,6 +45,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/sql/stats"
 	"github.com/cockroachdb/cockroach/pkg/sql/storageparam"
 	"github.com/cockroachdb/cockroach/pkg/sql/storageparam/tablestorageparam"
+	"github.com/cockroachdb/cockroach/pkg/sql/syntheticprivilege"
 	"github.com/cockroachdb/cockroach/pkg/sql/types"
 	"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
 	"github.com/cockroachdb/cockroach/pkg/util/log/eventpb"
@@ -907,10 +910,28 @@ func (p *planner) setAuditMode(
 	p.curPlan.auditEvents = append(p.curPlan.auditEvents,
 		auditEvent{desc: desc, writing: true})
 
-	// We require root for now. Later maybe use a different permission?
-	if err := p.RequireAdminRole(ctx, "change auditing settings on a table"); err != nil {
+	// Requires admin or MODIFYCLUSTERSETTING as of 22.2.
+	hasAdmin, err := p.HasAdminRole(ctx)
+	if err != nil {
 		return false, err
 	}
+	if !hasAdmin {
+		// Check for the system privilege first; otherwise, fall back to the role option.
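+		// The version gate below decides which path applies: the synthetic
+		// system privileges table only exists once the cluster version has
+		// reached SystemPrivilegesTable (a 22.2 gate), so on older or
+		// mixed-version clusters the role option is the source of truth.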
+		hasModify := false
+		if p.ExecCfg().Settings.Version.IsActive(ctx, clusterversion.SystemPrivilegesTable) {
+			hasModify = p.CheckPrivilege(ctx, syntheticprivilege.GlobalPrivilegeObject, privilege.MODIFYCLUSTERSETTING) == nil
+		}
+		if !hasModify {
+			hasModify, err = p.HasRoleOption(ctx, roleoption.MODIFYCLUSTERSETTING)
+			if err != nil {
+				return false, err
+			}
+			if !hasModify {
+				return false, pgerror.Newf(pgcode.InsufficientPrivilege,
+					"only users with admin or %s system privilege are allowed to change audit settings on a table", privilege.MODIFYCLUSTERSETTING.String())
+			}
+		}
+	}
 
 	telemetry.Inc(sqltelemetry.SchemaSetAuditModeCounter(auditMode.TelemetryName()))
 
diff --git a/pkg/sql/logictest/testdata/logic_test/alter_table b/pkg/sql/logictest/testdata/logic_test/alter_table
index 0796fa5832c8..f3d11aea170f 100644
--- a/pkg/sql/logictest/testdata/logic_test/alter_table
+++ b/pkg/sql/logictest/testdata/logic_test/alter_table
@@ -850,7 +850,7 @@ statement ok
 ALTER TABLE audit ADD COLUMN y INT
 
 # But not the audit settings.
-statement error change auditing settings on a table
+statement error pq: only users with admin or MODIFYCLUSTERSETTING system privilege are allowed to change audit settings on a table
 ALTER TABLE audit EXPERIMENTAL_AUDIT SET OFF
 
 user root
diff --git a/pkg/sql/sessioninit/BUILD.bazel b/pkg/sql/sessioninit/BUILD.bazel
index e21e60ae0fc7..741fa3ce1b61 100644
--- a/pkg/sql/sessioninit/BUILD.bazel
+++ b/pkg/sql/sessioninit/BUILD.bazel
@@ -42,9 +42,11 @@ go_test(
         "//pkg/security/securitytest",
         "//pkg/security/username",
         "//pkg/server",
+        "//pkg/settings/cluster",
        "//pkg/sql",
         "//pkg/sql/catalog/descpb",
         "//pkg/sql/catalog/descs",
+        "//pkg/sql/sessiondatapb",
         "//pkg/sql/sqlutil",
         "//pkg/testutils/serverutils",
         "//pkg/testutils/sqlutils",
diff --git a/pkg/sql/sessioninit/cache.go b/pkg/sql/sessioninit/cache.go
index 79a04a366ca8..c18c56eff6f2 100644
--- a/pkg/sql/sessioninit/cache.go
+++ b/pkg/sql/sessioninit/cache.go
@@ -114,10 +114,13 @@ func (a *Cache) GetAuthInfo(
 		ctx context.Context,
 		ie sqlutil.InternalExecutor,
 		username username.SQLUsername,
+		makePlanner func(opName string) (interface{}, func()),
+		settings *cluster.Settings,
 	) (AuthInfo, error),
+	makePlanner func(opName string) (interface{}, func()),
 ) (aInfo AuthInfo, err error) {
 	if !CacheEnabled.Get(&settings.SV) {
-		return readFromSystemTables(ctx, ie, username)
+		return readFromSystemTables(ctx, ie, username, makePlanner, settings)
 	}
 
 	var usersTableDesc catalog.TableDescriptor
@@ -164,7 +167,7 @@ func (a *Cache) GetAuthInfo(
 	val, err := a.loadValueOutsideOfCache(
 		ctx, fmt.Sprintf("authinfo-%s-%d-%d", username.Normalized(), usersTableVersion, roleOptionsTableVersion),
 		func(loadCtx context.Context) (interface{}, error) {
-			return readFromSystemTables(loadCtx, ie, username)
+			return readFromSystemTables(loadCtx, ie, username, makePlanner, settings)
 		})
 	if err != nil {
 		return aInfo, err
diff --git a/pkg/sql/sessioninit/cache_test.go b/pkg/sql/sessioninit/cache_test.go
index a719d6fe5823..337b5095f5ae 100644
--- a/pkg/sql/sessioninit/cache_test.go
+++ b/pkg/sql/sessioninit/cache_test.go
@@ -19,9 +19,11 @@ import (
 
 	"github.com/cockroachdb/cockroach/pkg/base"
 	"github.com/cockroachdb/cockroach/pkg/security/username"
+	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
 	"github.com/cockroachdb/cockroach/pkg/sql"
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/descs"
+	"github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb"
"github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/sessioninit" "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -74,18 +76,30 @@ func TestCacheInvalidation(t *testing.T) { return settings, didReadFromSystemTable, err } getAuthInfoFromCache := func() (sessioninit.AuthInfo, bool, error) { + makePlanner := func(opName string) (interface{}, func()) { + return sql.NewInternalPlanner( + opName, + execCfg.DB.NewTxn(ctx, opName), + username.RootUserName(), + &sql.MemoryMetrics{}, + s.ExecutorConfig().(*sql.ExecutorConfig), + sessiondatapb.SessionData{}, + ) + } didReadFromSystemTable := false + settings := s.ClusterSettings() aInfo, err := execCfg.SessionInitCache.GetAuthInfo( ctx, - s.ClusterSettings(), + settings, s.InternalExecutor().(sqlutil.InternalExecutor), s.DB(), s.CollectionFactory().(*descs.CollectionFactory), username.TestUserName(), - func(ctx context.Context, ie sqlutil.InternalExecutor, userName username.SQLUsername) (sessioninit.AuthInfo, error) { + func(ctx context.Context, ie sqlutil.InternalExecutor, userName username.SQLUsername, makePlanner func(opName string) (interface{}, func()), settings *cluster.Settings) (sessioninit.AuthInfo, error) { didReadFromSystemTable = true return sessioninit.AuthInfo{}, nil - }) + }, + makePlanner) return aInfo, didReadFromSystemTable, err } @@ -202,6 +216,7 @@ func TestCacheSingleFlight(t *testing.T) { ctx := context.Background() s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(ctx) + execCfg := s.ExecutorConfig().(sql.ExecutorConfig) settings := s.ExecutorConfig().(sql.ExecutorConfig).Settings ie := s.InternalExecutor().(sqlutil.InternalExecutor) c := s.ExecutorConfig().(sql.ExecutorConfig).SessionInitCache @@ -219,18 +234,32 @@ func TestCacheSingleFlight(t *testing.T) { wgFirstGetAuthInfoCallInProgress.Add(1) wgForTestComplete.Add(3) + makePlanner := func(opName string) (interface{}, func()) { + return sql.NewInternalPlanner( + opName, + execCfg.DB.NewTxn(ctx, opName), + username.RootUserName(), + &sql.MemoryMetrics{}, + s.ExecutorConfig().(*sql.ExecutorConfig), + sessiondatapb.SessionData{}, + ) + } + go func() { didReadFromSystemTable := false _, err := c.GetAuthInfo(ctx, settings, ie, s.DB(), s.ExecutorConfig().(sql.ExecutorConfig).CollectionFactory, testuser, func( ctx context.Context, ie sqlutil.InternalExecutor, userName username.SQLUsername, + makePlanner func(opName string) (interface{}, func()), + settings *cluster.Settings, ) (sessioninit.AuthInfo, error) { wgFirstGetAuthInfoCallInProgress.Done() wgForConcurrentReadWrite.Wait() didReadFromSystemTable = true return sessioninit.AuthInfo{}, nil - }) + }, + makePlanner) require.NoError(t, err) require.True(t, didReadFromSystemTable) wgForTestComplete.Done() @@ -249,10 +278,13 @@ func TestCacheSingleFlight(t *testing.T) { ctx context.Context, ie sqlutil.InternalExecutor, userName username.SQLUsername, + makePlanner func(opName string) (interface{}, func()), + settings *cluster.Settings, ) (sessioninit.AuthInfo, error) { didReadFromSystemTable = true return sessioninit.AuthInfo{}, nil - }) + }, + makePlanner) require.NoError(t, err) require.False(t, didReadFromSystemTable) wgForTestComplete.Done() @@ -270,10 +302,13 @@ func TestCacheSingleFlight(t *testing.T) { ctx context.Context, ie sqlutil.InternalExecutor, userName username.SQLUsername, + makePlanner func(opName string) (interface{}, func()), + settings 
 		) (sessioninit.AuthInfo, error) {
 			didReadFromSystemTable = true
 			return sessioninit.AuthInfo{}, nil
-		})
+		},
+			makePlanner)
 		require.NoError(t, err)
 		require.True(t, didReadFromSystemTable)
 
diff --git a/pkg/sql/user.go b/pkg/sql/user.go
index 0cd96fa86e90..9d03beeb4646 100644
--- a/pkg/sql/user.go
+++ b/pkg/sql/user.go
@@ -14,20 +14,25 @@ import (
 	"context"
 	"time"
 
+	"github.com/cockroachdb/cockroach/pkg/clusterversion"
 	"github.com/cockroachdb/cockroach/pkg/keys"
 	"github.com/cockroachdb/cockroach/pkg/kv"
 	"github.com/cockroachdb/cockroach/pkg/security"
 	"github.com/cockroachdb/cockroach/pkg/security/password"
 	"github.com/cockroachdb/cockroach/pkg/security/username"
 	"github.com/cockroachdb/cockroach/pkg/settings"
+	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/descs"
 	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
 	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
+	"github.com/cockroachdb/cockroach/pkg/sql/privilege"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
 	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
+	"github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb"
 	"github.com/cockroachdb/cockroach/pkg/sql/sessioninit"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqlutil"
+	"github.com/cockroachdb/cockroach/pkg/sql/syntheticprivilege"
 	"github.com/cockroachdb/cockroach/pkg/util/contextutil"
 	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/log/eventpb"
@@ -207,6 +212,16 @@ func retrieveSessionInitInfoWithCache(
 	databaseName string,
 ) (aInfo sessioninit.AuthInfo, settingsEntries []sessioninit.SettingsCacheEntry, err error) {
 	if err = func() (retErr error) {
+		makePlanner := func(opName string) (interface{}, func()) {
+			return NewInternalPlanner(
+				opName,
+				execCfg.DB.NewTxn(ctx, opName),
+				username.RootUserName(),
+				&MemoryMetrics{},
+				execCfg,
+				sessiondatapb.SessionData{},
+			)
+		}
 		aInfo, retErr = execCfg.SessionInitCache.GetAuthInfo(
 			ctx,
 			execCfg.Settings,
@@ -215,6 +230,7 @@ func retrieveSessionInitInfoWithCache(
 			execCfg.CollectionFactory,
 			userName,
 			retrieveAuthInfo,
+			makePlanner,
 		)
 		if retErr != nil {
 			return retErr
@@ -243,7 +259,11 @@ func retrieveSessionInitInfoWithCache(
 }
 
 func retrieveAuthInfo(
-	ctx context.Context, ie sqlutil.InternalExecutor, user username.SQLUsername,
+	ctx context.Context,
+	ie sqlutil.InternalExecutor,
+	user username.SQLUsername,
+	makePlanner func(opName string) (interface{}, func()),
+	settings *cluster.Settings,
 ) (aInfo sessioninit.AuthInfo, retErr error) {
 	// Use fully qualified table name to avoid looking up "".system.users.
 	// We use a nil txn as login is not tied to any transaction state, and
@@ -297,6 +317,23 @@ func retrieveAuthInfo(
 	aInfo.CanLoginSQL = true
 	aInfo.CanLoginDBConsole = true
 	var ok bool
+
+	// Check the NOSQLLOGIN system privilege to see if the user can log in
+	// via SQL.
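+	// retrieveAuthInfo has no planner of its own, so we construct a
+	// short-lived internal planner solely to obtain an AuthorizationAccessor
+	// for this check; the cleanup callback returned below releases it.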
+	planner, cleanup := makePlanner("check-privilege")
+	defer cleanup()
+	aa := planner.(AuthorizationAccessor)
+	hasAdmin, err := aa.HasAdminRole(ctx)
+	if err != nil {
+		return aInfo, err
+	}
+	if !hasAdmin {
+		if settings.Version.IsActive(ctx, clusterversion.SystemPrivilegesTable) {
+			if noSQLLogin := aa.CheckPrivilegeForUser(ctx, syntheticprivilege.GlobalPrivilegeObject, privilege.NOSQLLOGIN, user) == nil; noSQLLogin {
+				aInfo.CanLoginSQL = false
+			}
+		}
+	}
+
 	for ok, err = roleOptsIt.Next(ctx); ok; ok, err = roleOptsIt.Next(ctx) {
 		row := roleOptsIt.Cur()
 		option := string(tree.MustBeDString(row[0]))
@@ -305,7 +342,9 @@ func retrieveAuthInfo(
 			aInfo.CanLoginSQL = false
 			aInfo.CanLoginDBConsole = false
 		}
-		if option == "NOSQLLOGIN" {
+		// If the user did not have the NOSQLLOGIN system privilege but has
+		// the equivalent role option, set the flag to false.
+		if option == "NOSQLLOGIN" && aInfo.CanLoginSQL {
 			aInfo.CanLoginSQL = false
 		}

From ae851faa2f25c6bcccb9101451538e18c2341459 Mon Sep 17 00:00:00 2001
From: Jay
Date: Wed, 10 Aug 2022 18:47:17 -0400
Subject: [PATCH 2/2] ccl/sqlproxyccl: ensure that connections cannot be transferred before init

Related to #80446. In #80446, we updated the connection tracker to
track server assignments instead of forwarders. This also meant that a
connection could start transferring before the forwarder was resumed
for the first time, breaking the TransferConnection invariant that the
processors must be resumed before it is called.

This commit fixes that issue by introducing a new isInitialized flag to
the forwarder, which only gets set to true once run returns. Attempting
to transfer a connection with isInitialized=false will return an error.
This should fix the flakes that we've been seeing on CI.

Release note: None

Release justification: sqlproxy bug fix. This ensures that we don't
resume the processors mid connection transfer, which would cause
unexpected issues on the client's end. Note that this situation is
rare, since it requires forwarder.run and forwarder.TransferConnection
to interleave in just the right way.
---
 pkg/ccl/sqlproxyccl/conn_migration.go         | 12 ++---
 pkg/ccl/sqlproxyccl/conn_migration_test.go    |  2 +
 pkg/ccl/sqlproxyccl/forwarder.go              | 45 +++++++++++++------
 pkg/ccl/sqlproxyccl/forwarder_test.go         | 16 +++++++
 pkg/ccl/sqlproxyccl/proxy_handler_test.go     | 23 +---------
 .../tenantdirsvr/test_static_directory_svr.go |  2 +-
 6 files changed, 59 insertions(+), 41 deletions(-)

diff --git a/pkg/ccl/sqlproxyccl/conn_migration.go b/pkg/ccl/sqlproxyccl/conn_migration.go
index 1706acc0067c..61ed4cf2b361 100644
--- a/pkg/ccl/sqlproxyccl/conn_migration.go
+++ b/pkg/ccl/sqlproxyccl/conn_migration.go
@@ -78,7 +78,7 @@ func (f *forwarder) tryBeginTransfer() (started bool, cleanupFn func()) {
 	defer f.mu.Unlock()
 
 	// Forwarder hasn't been initialized.
-	if !f.isInitializedLocked() {
+	if !f.mu.isInitialized {
 		return false, nil
 	}
 
@@ -120,9 +120,9 @@ var errTransferCannotStart = errors.New("transfer cannot be started")
 // where the forwarder is not in a state that is eligible for a connection
 // migration.
 //
-// NOTE: If the forwarder hasn't been closed, runTransfer has an invariant
+// NOTE: If the forwarder hasn't been closed, TransferConnection has an invariant
 // where the processors have been resumed prior to calling this method. When
-// runTransfer returns, it is guaranteed that processors will either be
+// TransferConnection returns, it is guaranteed that processors will either be
 // re-resumed, or the forwarder will be closed (in the case of a non-recoverable
 // error).
 //
@@ -145,7 +145,7 @@ func (f *forwarder) TransferConnection() (retErr error) {
 	// Create a transfer context, and timeout handler which gets triggered
 	// whenever the context expires. We have to close the forwarder because
 	// the transfer may be blocked on I/O, and the only way for now is to close
-	// the connections. This then allow runTransfer to return and cleanup.
+	// the connections. This then allows TransferConnection to return and clean up.
 	ctx, cancel := newTransferContext(f.ctx)
 	defer cancel()
@@ -177,8 +177,8 @@ func (f *forwarder) TransferConnection() (retErr error) {
 	latencyDur := timeutil.Since(tBegin)
 	f.metrics.ConnMigrationAttemptedLatency.RecordValue(latencyDur.Nanoseconds())
 
-	// When runTransfer returns, it's either the forwarder has been closed,
-	// or the procesors have been resumed.
+	// When TransferConnection returns, either the forwarder has been
+	// closed, or the processors have been resumed.
 	if !ctx.isRecoverable() {
 		log.Infof(logCtx, "transfer failed: connection closed, latency=%v, err=%v", latencyDur, retErr)
 		f.metrics.ConnMigrationErrorFatalCount.Inc(1)
diff --git a/pkg/ccl/sqlproxyccl/conn_migration_test.go b/pkg/ccl/sqlproxyccl/conn_migration_test.go
index 93781d060d7f..09ba67742b77 100644
--- a/pkg/ccl/sqlproxyccl/conn_migration_test.go
+++ b/pkg/ccl/sqlproxyccl/conn_migration_test.go
@@ -91,6 +91,7 @@ func TestForwarder_tryBeginTransfer(t *testing.T) {
 		f := &forwarder{}
 		f.mu.request = &processor{}
 		f.mu.response = &processor{}
+		f.mu.isInitialized = true
 
 		started, cleanupFn := f.tryBeginTransfer()
 		require.False(t, started)
@@ -107,6 +108,7 @@ func TestForwarder_tryBeginTransfer(t *testing.T) {
 		f := &forwarder{}
 		f.mu.request = &processor{}
 		f.mu.response = &processor{}
+		f.mu.isInitialized = true
 
 		started, cleanupFn := f.tryBeginTransfer()
 		require.True(t, started)
diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go
index 641af896463f..88067604dd32 100644
--- a/pkg/ccl/sqlproxyccl/forwarder.go
+++ b/pkg/ccl/sqlproxyccl/forwarder.go
@@ -71,6 +71,17 @@ type forwarder struct {
 	mu struct {
 		syncutil.Mutex
 
+		// isInitialized indicates that the forwarder has been initialized.
+		//
+		// TODO(jaylim-crl): This prevents the connection from being transferred
+		// before we fully resume the processors (because the balancer now
+		// tracks assignments instead of forwarders). If we don't do this, there
+		// could be a situation where we resume the processors mid transfer. One
+		// alternative idea is to replace both isInitialized and isTransferring
+		// with a lock, which is held by the owner of the forwarder (e.g. main
+		// thread, or connection migrator thread).
+		isInitialized bool
+
 		// isTransferring indicates that a connection migration is in progress.
 		isTransferring bool
 
@@ -154,7 +165,7 @@ func newForwarder(
 //
 // run can only be called once throughout the lifetime of the forwarder.
 func (f *forwarder) run(clientConn net.Conn, serverConn net.Conn) error {
-	initialize := func() error {
+	setup := func() error {
 		f.mu.Lock()
 		defer f.mu.Unlock()
 
@@ -165,8 +176,9 @@ func (f *forwarder) run(clientConn net.Conn, serverConn net.Conn) error {
 			return f.ctx.Err()
 		}
 
-		// Run can only be called once.
-		if f.isInitializedLocked() {
+		// Run can only be called once. If lastUpdated has already been set
+		// (i.e. non-zero), it must be the case that run has already been called.
+		if !f.mu.activity.lastUpdated.IsZero() {
 			return errors.AssertionFailedf("forwarder has already been started")
 		}
 
@@ -185,10 +197,23 @@ func (f *forwarder) run(clientConn net.Conn, serverConn net.Conn) error {
 		f.mu.activity.lastUpdated = f.timeSource.Now()
 		return nil
 	}
-	if err := initialize(); err != nil {
-		return err
+	markInitialized := func() {
+		f.mu.Lock()
+		defer f.mu.Unlock()
+		f.mu.isInitialized = true
+	}
+
+	if err := setup(); err != nil {
+		return errors.Wrap(err, "setting up forwarder")
 	}
-	return f.resumeProcessors()
+
+	if err := f.resumeProcessors(); err != nil {
+		return errors.Wrap(err, "resuming processors")
+	}
+
+	// Mark the forwarder as initialized; the connection is now ready for a transfer.
+	markInitialized()
+	return nil
 }
 
 // Context returns the context associated with the forwarder.
@@ -237,7 +262,7 @@ func (f *forwarder) IsIdle() (idle bool) {
 	defer f.mu.Unlock()
 
 	// If the forwarder hasn't been initialized, it is considered active.
-	if !f.isInitializedLocked() {
+	if !f.mu.isInitialized {
 		return false
 	}
 
@@ -270,12 +295,6 @@ func (f *forwarder) IsIdle() (idle bool) {
 	return now.Sub(f.mu.activity.lastUpdated) >= idleTimeout
 }
 
-// isInitializedLocked returns true if the forwarder has been initialized
-// through Run, or false otherwise.
-func (f *forwarder) isInitializedLocked() bool {
-	return f.mu.request != nil && f.mu.response != nil
-}
-
 // resumeProcessors starts both the request and response processors
 // asynchronously. The forwarder will be closed if any of the processors
 // return an error while resuming. This is idempotent as resume() will return
diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go
index 2f6d9ba46887..88059e9fd2bc 100644
--- a/pkg/ccl/sqlproxyccl/forwarder_test.go
+++ b/pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -44,6 +44,12 @@ func TestForward(t *testing.T) {
 		err := f.run(p1, p2)
 		require.NoError(t, err)
 
+		func() {
+			f.mu.Lock()
+			defer f.mu.Unlock()
+			require.True(t, f.mu.isInitialized)
+		}()
+
 		// Close the connection right away to simulate processor error.
 		p1.Close()
 
@@ -77,6 +83,11 @@ func TestForward(t *testing.T) {
 		require.NoError(t, err)
 		require.Nil(t, f.ctx.Err())
 		require.False(t, f.IsIdle())
+		func() {
+			f.mu.Lock()
+			defer f.mu.Unlock()
+			require.True(t, f.mu.isInitialized)
+		}()
 
 		f.mu.Lock()
 		requestProc := f.mu.request
@@ -217,6 +228,11 @@ func TestForward(t *testing.T) {
 		require.NoError(t, err)
 		require.Nil(t, f.ctx.Err())
 		require.False(t, f.IsIdle())
+		func() {
+			f.mu.Lock()
+			defer f.mu.Unlock()
+			require.True(t, f.mu.isInitialized)
+		}()
 
 		f.mu.Lock()
 		responseProc := f.mu.response
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go
index 15c499a1354a..53c865b674c7 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go
@@ -837,11 +837,6 @@ func TestConnectionRebalancingDisabled(t *testing.T) {
 	const podCount = 2
 	tenantID := serverutils.TestTenantID()
 	tenants := startTestTenantPods(ctx, t, s, tenantID, podCount, base.TestingKnobs{})
-	defer func() {
-		for _, tenant := range tenants {
-			tenant.Stopper().Stop(ctx)
-		}
-	}()
 
 	// Register one SQL pod in the directory server.
tds := tenantdirsvr.NewTestStaticDirectoryServer(s.Stopper(), nil /* timeSource */) @@ -934,11 +929,6 @@ func TestCancelQuery(t *testing.T) { }, } tenants := startTestTenantPods(ctx, t, s, tenantID, podCount, tenantKnobs) - defer func() { - for _, tenant := range tenants { - tenant.Stopper().Stop(ctx) - } - }() // Use a custom time source for testing. t0 := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC) @@ -1272,11 +1262,6 @@ func TestPodWatcher(t *testing.T) { const podCount = 4 tenantID := serverutils.TestTenantID() tenants := startTestTenantPods(ctx, t, s, tenantID, podCount, base.TestingKnobs{}) - defer func() { - for _, tenant := range tenants { - tenant.Stopper().Stop(ctx) - } - }() // Register only 3 SQL pods in the directory server. We will add the 4th // once the watcher has been established. @@ -1739,11 +1724,6 @@ func TestCurConnCountMetric(t *testing.T) { // Start a single SQL pod. tenantID := serverutils.TestTenantID() tenants := startTestTenantPods(ctx, t, s, tenantID, 1, base.TestingKnobs{}) - defer func() { - for _, tenant := range tenants { - tenant.Stopper().Stop(ctx) - } - }() // Register the SQL pod in the directory server. tds := tenantdirsvr.NewTestStaticDirectoryServer(s.Stopper(), nil /* timeSource */) @@ -2295,7 +2275,8 @@ func queryAddr(ctx context.Context, t *testing.T, db queryer) string { // startTestTenantPods starts count SQL pods for the given tenant, and returns // a list of tenant servers. Note that a default admin testuser with the -// password hunter2 will be created. +// password hunter2 will be created. The test tenants will automatically be +// stopped once the server's stopper (from ts) is stopped. func startTestTenantPods( ctx context.Context, t *testing.T, diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go index a2df19ca21dc..0d7396de2f8a 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go @@ -344,7 +344,7 @@ func (d *TestStaticDirectoryServer) RemovePod(tenantID roachpb.TenantID, podAddr } // Start starts the test directory server using an in-memory listener. This -// returns an error if the server cannot be started. If the sevrer has already +// returns an error if the server cannot be started. If the server has already // been started, this is a no-op. func (d *TestStaticDirectoryServer) Start(ctx context.Context) error { d.process.Lock()
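
As a reference for the invariant the second patch enforces, here is a
minimal, self-contained Go sketch of the initialization gate. The names
mirror the forwarder but are illustrative simplifications, not the actual
sqlproxyccl types: transfers are refused until run has completed.

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errTransferCannotStart = errors.New("transfer cannot be started")

// forwarder models the sqlproxy forwarder's initialization gate: a
// transfer may only begin after run has resumed the processors.
type forwarder struct {
	mu struct {
		sync.Mutex
		isInitialized  bool // set only once run completes
		isTransferring bool // a migration is in progress
	}
}

// run stands in for resuming the request/response processors; it flips
// isInitialized only after that work has finished.
func (f *forwarder) run() {
	// ... resume processors here ...
	f.mu.Lock()
	defer f.mu.Unlock()
	f.mu.isInitialized = true
}

// tryBeginTransfer refuses to start when the forwarder is uninitialized
// or a transfer is already in flight, mirroring the patched check.
func (f *forwarder) tryBeginTransfer() (started bool, cleanup func()) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if !f.mu.isInitialized || f.mu.isTransferring {
		return false, nil
	}
	f.mu.isTransferring = true
	return true, func() {
		f.mu.Lock()
		defer f.mu.Unlock()
		f.mu.isTransferring = false
	}
}

func (f *forwarder) TransferConnection() error {
	started, cleanup := f.tryBeginTransfer()
	if !started {
		return errTransferCannotStart
	}
	defer cleanup()
	// ... perform the actual migration ...
	return nil
}

func main() {
	f := &forwarder{}
	fmt.Println(f.TransferConnection()) // transfer cannot be started
	f.run()
	fmt.Println(f.TransferConnection()) // <nil>
}

The design point worth noting: the flag is flipped only after
resumeProcessors succeeds, so a transfer can never observe processors that
were set up but not yet resumed.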