From 406135134cf90848ad97cdcc0bbf62663f6b1eaa Mon Sep 17 00:00:00 2001
From: rosstimothy <39066650+rosstimothy@users.noreply.github.com>
Date: Thu, 3 Mar 2022 12:04:31 -0500
Subject: [PATCH] Better Semaphore Lease Contention Handling (#10666)

Add the same retry logic from `AcquireSemaphore` to `CancelSemaphoreLease`
to handle contention. Without the retry it is possible for a cancellation
to fail and for the lease in question to remain held until it expires. If
the number of cancellations that fail is >= `max_connections`, the user is
effectively locked out until the leases expire.

Fixes #10363
---
 lib/services/local/presence.go      |  89 ++++++++++++---------
 lib/services/local/presence_test.go | 117 +++++++++++++++++++---------
 2 files changed, 136 insertions(+), 70 deletions(-)

diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go
index 4e4b37cd2b8c7..b2d4f8b8f3b5d 100644
--- a/lib/services/local/presence.go
+++ b/lib/services/local/presence.go
@@ -798,6 +798,13 @@ func (s *PresenceService) DeleteAllRemoteClusters() error {
 	return trace.Wrap(err)
 }
 
+// this combination of backoff parameters leads to worst-case total time spent
+// in backoff between 1200ms and 2400ms depending on jitter. tests are in
+// place to verify that this is sufficient to resolve a 20-lease contention
+// event, which is worse than should ever occur in practice.
+const baseBackoff = time.Millisecond * 300
+const leaseRetryAttempts int64 = 6
+
 // AcquireSemaphore attempts to acquire the specified semaphore. AcquireSemaphore will automatically handle
 // retry on contention. If the semaphore has already reached MaxLeases, or there is too much contention,
 // a LimitExceeded error is returned (contention in this context means concurrent attempts to update the
@@ -805,13 +812,6 @@ func (s *PresenceService) DeleteAllRemoteClusters() error {
 // is the only semaphore method that handles retries internally. This is because this method both blocks
 // user-facing operations, and contains multiple different potential contention points.
 func (s *PresenceService) AcquireSemaphore(ctx context.Context, req types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) {
-	// this combination of backoff parameters leads to worst-case total time spent
-	// in backoff between 1200ms and 2400ms depending on jitter. tests are in
-	// place to verify that this is sufficient to resolve a 20-lease contention
-	// event, which is worse than should ever occur in practice.
-	const baseBackoff = time.Millisecond * 300
-	const acquireAttempts int64 = 6
-
 	if err := req.Check(); err != nil {
 		return nil, trace.Wrap(err)
 	}
@@ -826,7 +826,7 @@ func (s *PresenceService) AcquireSemaphore(ctx context.Context, req types.Acquir
 	key := backend.Key(semaphoresPrefix, req.SemaphoreKind, req.SemaphoreName)
 
 Acquire:
-	for i := int64(0); i < acquireAttempts; i++ {
+	for i := int64(0); i < leaseRetryAttempts; i++ {
 		if i > 0 {
 			// Not our first attempt, apply backoff. If we knew that we were only in
 			// contention with one other acquire attempt we could retry immediately
@@ -997,40 +997,59 @@ func (s *PresenceService) CancelSemaphoreLease(ctx context.Context, lease types.
 		return trace.BadParameter("the lease %v has expired at %v", lease.LeaseID, lease.Expires)
 	}
 
-	key := backend.Key(semaphoresPrefix, lease.SemaphoreKind, lease.SemaphoreName)
-	item, err := s.Get(ctx, key)
-	if err != nil {
-		return trace.Wrap(err)
-	}
+	for i := int64(0); i < leaseRetryAttempts; i++ {
+		if i > 0 {
+			// Not our first attempt, apply backoff. If we knew that we were only in
+			// contention with one other cancel attempt we could retry immediately
+			// since we got here because some other attempt *succeeded*. It is safer,
+			// however, to assume that we are under high contention and attempt to
+			// spread out retries via random backoff.
+			select {
+			case <-time.After(s.jitter(baseBackoff * time.Duration(i))):
+			case <-ctx.Done():
+				return trace.Wrap(ctx.Err())
+			}
+		}
 
-	sem, err := services.UnmarshalSemaphore(item.Value)
-	if err != nil {
-		return trace.Wrap(err)
-	}
+		key := backend.Key(semaphoresPrefix, lease.SemaphoreKind, lease.SemaphoreName)
+		item, err := s.Get(ctx, key)
+		if err != nil {
+			return trace.Wrap(err)
+		}
 
-	if err := sem.Cancel(lease); err != nil {
-		return trace.Wrap(err)
-	}
+		sem, err := services.UnmarshalSemaphore(item.Value)
+		if err != nil {
+			return trace.Wrap(err)
+		}
 
-	newValue, err := services.MarshalSemaphore(sem)
-	if err != nil {
-		return trace.Wrap(err)
-	}
+		if err := sem.Cancel(lease); err != nil {
+			return trace.Wrap(err)
+		}
 
-	newItem := backend.Item{
-		Key:     key,
-		Value:   newValue,
-		Expires: sem.Expiry(),
-	}
+		newValue, err := services.MarshalSemaphore(sem)
+		if err != nil {
+			return trace.Wrap(err)
+		}
 
-	_, err = s.CompareAndSwap(ctx, *item, newItem)
-	if err != nil {
-		if trace.IsCompareFailed(err) {
-			return trace.CompareFailed("semaphore %v/%v has been concurrently updated, try again", sem.GetSubKind(), sem.GetName())
+		newItem := backend.Item{
+			Key:     key,
+			Value:   newValue,
+			Expires: sem.Expiry(),
+		}
+
+		_, err = s.CompareAndSwap(ctx, *item, newItem)
+		switch {
+		case err == nil:
+			return nil
+		case trace.IsCompareFailed(err):
+			// semaphore was concurrently updated
+			continue
+		default:
+			return trace.Wrap(err)
 		}
-		return trace.Wrap(err)
 	}
-	return nil
+
+	return trace.LimitExceeded("too much contention on semaphore %s/%s", lease.SemaphoreKind, lease.SemaphoreName)
 }
 
 // GetSemaphores returns all semaphores matching the supplied filter.
diff --git a/lib/services/local/presence_test.go b/lib/services/local/presence_test.go
index ae9a152cea679..88c3e91592d0a 100644
--- a/lib/services/local/presence_test.go
+++ b/lib/services/local/presence_test.go
@@ -27,7 +27,6 @@ import (
 	"github.com/google/uuid"
 	"github.com/jonboulle/clockwork"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/check.v1"
 
 	"github.com/gravitational/teleport/api/client/proto"
 	apidefaults "github.com/gravitational/teleport/api/defaults"
@@ -40,26 +39,14 @@ import (
 	"github.com/gravitational/trace"
 )
 
-type PresenceSuite struct {
-	bk backend.Backend
-}
-
-var _ = check.Suite(&PresenceSuite{})
-
-func (s *PresenceSuite) SetUpTest(c *check.C) {
-	var err error
-
-	s.bk, err = lite.New(context.TODO(), backend.Params{"path": c.MkDir()})
-	c.Assert(err, check.IsNil)
-}
+func TestTrustedClusterCRUD(t *testing.T) {
+	ctx := context.Background()
 
-func (s *PresenceSuite) TearDownTest(c *check.C) {
-	c.Assert(s.bk.Close(), check.IsNil)
-}
+	bk, err := lite.New(ctx, backend.Params{"path": t.TempDir()})
+	require.NoError(t, err)
+	t.Cleanup(func() { require.NoError(t, bk.Close()) })
 
-func (s *PresenceSuite) TestTrustedClusterCRUD(c *check.C) {
-	ctx := context.Background()
-	presenceBackend := NewPresenceService(s.bk)
+	presenceBackend := NewPresenceService(bk)
 
 	tc, err := types.NewTrustedCluster("foo", types.TrustedClusterSpecV2{
 		Enabled:              true,
@@ -68,7 +55,7 @@ func (s *PresenceSuite) TestTrustedClusterCRUD(c *check.C) {
 		ProxyAddress:         "quux",
 		ReverseTunnelAddress: "quuz",
 	})
-	c.Assert(err, check.IsNil)
+	require.NoError(t, err)
 
 	// we just insert this one for get all
 	stc, err := types.NewTrustedCluster("bar", types.TrustedClusterSpecV2{
@@ -78,37 +65,37 @@ func (s *PresenceSuite) TestTrustedClusterCRUD(c *check.C) {
 		ProxyAddress:         "quuz",
 		ReverseTunnelAddress: "corge",
 	})
-	c.Assert(err, check.IsNil)
+	require.NoError(t, err)
 
 	// create trusted clusters
 	_, err = presenceBackend.UpsertTrustedCluster(ctx, tc)
-	c.Assert(err, check.IsNil)
+	require.NoError(t, err)
 	_, err = presenceBackend.UpsertTrustedCluster(ctx, stc)
-	c.Assert(err, check.IsNil)
+	require.NoError(t, err)
 
 	// get trusted cluster make sure it's correct
 	gotTC, err := presenceBackend.GetTrustedCluster(ctx, "foo")
-	c.Assert(err, check.IsNil)
-	c.Assert(gotTC.GetName(), check.Equals, "foo")
-	c.Assert(gotTC.GetEnabled(), check.Equals, true)
-	c.Assert(gotTC.GetRoles(), check.DeepEquals, []string{"bar", "baz"})
-	c.Assert(gotTC.GetToken(), check.Equals, "qux")
-	c.Assert(gotTC.GetProxyAddress(), check.Equals, "quux")
-	c.Assert(gotTC.GetReverseTunnelAddress(), check.Equals, "quuz")
+	require.NoError(t, err)
+	require.Equal(t, "foo", gotTC.GetName())
+	require.True(t, gotTC.GetEnabled())
+	require.EqualValues(t, []string{"bar", "baz"}, gotTC.GetRoles())
+	require.Equal(t, "qux", gotTC.GetToken())
+	require.Equal(t, "quux", gotTC.GetProxyAddress())
+	require.Equal(t, "quuz", gotTC.GetReverseTunnelAddress())
 
 	// get all clusters
 	allTC, err := presenceBackend.GetTrustedClusters(ctx)
-	c.Assert(err, check.IsNil)
-	c.Assert(allTC, check.HasLen, 2)
+	require.NoError(t, err)
+	require.Len(t, allTC, 2)
 
 	// delete cluster
 	err = presenceBackend.DeleteTrustedCluster(ctx, "foo")
-	c.Assert(err, check.IsNil)
+	require.NoError(t, err)
 
 	// make sure it's really gone
 	_, err = presenceBackend.GetTrustedCluster(ctx, "foo")
-	c.Assert(err, check.NotNil)
-	c.Assert(trace.IsNotFound(err), check.Equals, true)
+	require.Error(t, err)
+	require.ErrorIs(t, err, trace.NotFound("key /trustedclusters/foo is not found"))
 }
 
 // TestApplicationServersCRUD verifies backend operations on app servers.
@@ -1130,3 +1117,63 @@ func TestFakePaginate_TotalCount(t *testing.T) {
 		require.Equal(t, 3, resp.TotalCount)
 	})
 }
+
+func TestPresenceService_CancelSemaphoreLease(t *testing.T) {
+	ctx := context.Background()
+	bk, err := lite.New(ctx, backend.Params{"path": t.TempDir()})
+	require.NoError(t, err)
+	t.Cleanup(func() { require.NoError(t, bk.Close()) })
+	presence := NewPresenceService(bk)
+
+	maxLeases := 5
+	leases := make([]*types.SemaphoreLease, maxLeases)
+
+	// Acquire max number of leases
+	request := types.AcquireSemaphoreRequest{
+		SemaphoreKind: "test",
+		SemaphoreName: "test",
+		MaxLeases:     int64(maxLeases),
+		Expires:       time.Now().Add(time.Hour),
+		Holder:        "test",
+	}
+	for i := range leases {
+		lease, err := presence.AcquireSemaphore(ctx, request)
+		require.NoError(t, err)
+		require.NotNil(t, lease)
+
+		leases[i] = lease
+	}
+
+	// Validate a semaphore exists with the correct number of leases
+	semaphores, err := presence.GetSemaphores(ctx, types.SemaphoreFilter{
+		SemaphoreKind: "test",
+		SemaphoreName: "test",
+	})
+	require.NoError(t, err)
+	require.Len(t, semaphores, 1)
+	require.Len(t, semaphores[0].LeaseRefs(), maxLeases)
+
+	// Cancel the leases concurrently and ensure that all
+	// cancellations are honored
+	errCh := make(chan error, maxLeases)
+	for _, l := range leases {
+		l := l
+		go func() {
+			errCh <- presence.CancelSemaphoreLease(ctx, *l)
+		}()
+	}
+
+	for i := 0; i < maxLeases; i++ {
+		err := <-errCh
+		require.NoError(t, err)
+	}
+
+	// Validate the semaphore still exists but all leases were removed
+	semaphores, err = presence.GetSemaphores(ctx, types.SemaphoreFilter{
+		SemaphoreKind: "test",
+		SemaphoreName: "test",
+	})
+	require.NoError(t, err)
+	require.Len(t, semaphores, 1)
+	require.Empty(t, semaphores[0].LeaseRefs())
+}
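
The snippet below, which is not part of the diff, is a minimal self-contained sketch of the compare-and-swap retry pattern that `AcquireSemaphore` already used and that this patch extends to `CancelSemaphoreLease`: when the CAS loses the race, back off with jitter and try again up to a fixed number of attempts. The helper names (`casWithRetry`, `jitter`) and the in-memory atomic counter standing in for the backend are illustrative only; the real implementation is the one in `lib/services/local/presence.go` above.

package main

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"sync/atomic"
	"time"
)

const (
	baseBackoff   = 300 * time.Millisecond
	retryAttempts = 6
)

var errTooMuchContention = errors.New("too much contention")

// jitter returns a random duration in [0, d) so that concurrent retries
// spread out instead of colliding again on the next attempt.
func jitter(d time.Duration) time.Duration {
	if d <= 0 {
		return 0
	}
	return time.Duration(rand.Int63n(int64(d)))
}

// casWithRetry keeps applying update via compare-and-swap until its write
// lands or it runs out of attempts, backing off between attempts.
func casWithRetry(ctx context.Context, val *int64, update func(old int64) int64) error {
	for i := 0; i < retryAttempts; i++ {
		if i > 0 {
			// Lost a race on the previous attempt: assume high contention
			// and wait a jittered, linearly growing backoff before retrying.
			select {
			case <-time.After(jitter(baseBackoff * time.Duration(i))):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		old := atomic.LoadInt64(val)
		if atomic.CompareAndSwapInt64(val, old, update(old)) {
			return nil // our write landed
		}
		// CAS failed: someone else updated the value concurrently, try again.
	}
	return errTooMuchContention
}

func main() {
	ctx := context.Background()
	var leases int64 = 5

	// Five concurrent "cancellations" all decrement the same counter; with
	// the retry loop each of them eventually succeeds despite the contention.
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := casWithRetry(ctx, &leases, func(old int64) int64 { return old - 1 }); err != nil {
				fmt.Println("cancel failed:", err)
			}
		}()
	}
	wg.Wait()
	fmt.Println("remaining leases:", atomic.LoadInt64(&leases)) // expect 0
}

As in the patch, the backoff grows linearly with the attempt number and is jittered, and exhausting the attempts surfaces a contention error to the caller rather than silently leaving the lease held.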