diff --git a/management/server/account.go b/management/server/account.go index b2b23dcb9a6..4648c00cdd6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -136,7 +136,7 @@ type AccountManager interface { GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error + CancelPeerRoutines(ctx context.Context, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) @@ -1858,6 +1858,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + // acquiring peer write lock here is ok since we only modify peer information that is supplied by the + // peer itself which can't be modified by API, and it only happens after an account read lock is acquired + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { @@ -1868,8 +1873,6 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1889,8 +1892,13 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey return peer, netMap, postureChecks, nil } -func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error { - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key) +func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peerPubKey string) error { + // acquiring peer write lock here is ok since we only modify peer information that is supplied by the + // peer itself which can't be modified by API, and it only happens after an account read lock is acquired + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() + + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { return status.Errorf(status.Unauthenticated, "peer not registered") @@ -1900,17 +1908,15 @@ func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *n accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peer.Key) - defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account) + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) if err != nil { - log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err) + log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return nil diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 170e72dd0e6..4a12a5c3eda 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -236,7 +236,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer * func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.CancelPeerRoutines(ctx, peer) + _ = s.accountManager.CancelPeerRoutines(ctx, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 25bcdfcee71..1adf9a2d63d 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -112,7 +112,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey st return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error { +func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peerPubKey string) error { // TODO implement me panic("implement me") }